torchtextclassifiers 1.0.2.tar.gz → 1.0.4.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/PKG-INFO +2 -2
  2. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/README.md +1 -1
  3. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/pyproject.toml +1 -1
  4. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/components/__init__.py +1 -0
  5. torchtextclassifiers-1.0.4/torchTextClassifiers/model/components/text_embedder.py +401 -0
  6. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/lightning.py +1 -0
  7. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/model.py +52 -11
  8. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/torchTextClassifiers.py +59 -23
  9. torchtextclassifiers-1.0.2/torchTextClassifiers/model/components/text_embedder.py +0 -223
  10. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/__init__.py +0 -0
  11. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/dataset/__init__.py +0 -0
  12. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/dataset/dataset.py +0 -0
  13. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/__init__.py +0 -0
  14. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/components/attention.py +0 -0
  15. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/components/categorical_var_net.py +0 -0
  16. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/components/classification_head.py +0 -0
  17. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/tokenizers/WordPiece.py +0 -0
  18. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/tokenizers/__init__.py +0 -0
  19. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/tokenizers/base.py +0 -0
  20. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/tokenizers/ngram.py +0 -0
  21. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/utilities/__init__.py +0 -0
  22. {torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/utilities/plot_explainability.py +0 -0

{torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/PKG-INFO +2 -2
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: torchtextclassifiers
- Version: 1.0.2
+ Version: 1.0.4
  Summary: A text classification toolkit to easily build, train and evaluate deep learning text classifiers using PyTorch.
  Keywords: fastText,text classification,NLP,automatic coding,deep learning
  Author: Cédric Couralet, Meilame Tayebjee
@@ -49,7 +49,7 @@ A unified, extensible framework for text classification with categorical variabl
  ```bash
  # Clone the repository
  git clone https://github.com/InseeFrLab/torchTextClassifiers.git
- cd torchtextClassifiers
+ cd torchTextClassifiers

  # Install with uv (recommended)
  uv sync

{torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/README.md +1 -1
@@ -23,7 +23,7 @@ A unified, extensible framework for text classification with categorical variabl
  ```bash
  # Clone the repository
  git clone https://github.com/InseeFrLab/torchTextClassifiers.git
- cd torchtextClassifiers
+ cd torchTextClassifiers

  # Install with uv (recommended)
  uv sync

{torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/pyproject.toml +1 -1
@@ -18,7 +18,7 @@ dependencies = [
  "pytorch-lightning>=2.4.0",
  ]
  requires-python = ">=3.11"
- version="1.0.2"
+ version="1.0.4"


  [dependency-groups]

{torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/components/__init__.py +1 -0
@@ -8,5 +8,6 @@ from .categorical_var_net import (
  CategoricalVariableNet as CategoricalVariableNet,
  )
  from .classification_head import ClassificationHead as ClassificationHead
+ from .text_embedder import LabelAttentionConfig as LabelAttentionConfig
  from .text_embedder import TextEmbedder as TextEmbedder
  from .text_embedder import TextEmbedderConfig as TextEmbedderConfig

torchtextclassifiers-1.0.4/torchTextClassifiers/model/components/text_embedder.py +401 -0 (new file)
@@ -0,0 +1,401 @@
+ import math
+ from dataclasses import dataclass
+ from typing import Dict, Optional
+
+ import torch
+ import torch.nn as nn
+ from torch.nn import functional as F
+
+ from torchTextClassifiers.model.components.attention import AttentionConfig, Block, norm
+
+
+ @dataclass
+ class LabelAttentionConfig:
+ n_head: int
+ num_classes: int
+
+
+ @dataclass
+ class TextEmbedderConfig:
+ vocab_size: int
+ embedding_dim: int
+ padding_idx: int
+ attention_config: Optional[AttentionConfig] = None
+ label_attention_config: Optional[LabelAttentionConfig] = None
+
+
+ class TextEmbedder(nn.Module):
+ def __init__(self, text_embedder_config: TextEmbedderConfig):
+ super().__init__()
+
+ self.config = text_embedder_config
+
+ self.attention_config = text_embedder_config.attention_config
+ if isinstance(self.attention_config, dict):
+ self.attention_config = AttentionConfig(**self.attention_config)
+
+ # Normalize label_attention_config: allow dicts and convert them to LabelAttentionConfig
+ self.label_attention_config = text_embedder_config.label_attention_config
+ if isinstance(self.label_attention_config, dict):
+ self.label_attention_config = LabelAttentionConfig(**self.label_attention_config)
+ # Keep self.config in sync so downstream components (e.g., LabelAttentionClassifier)
+ # always see a LabelAttentionConfig instance rather than a raw dict.
+ self.config.label_attention_config = self.label_attention_config
+
+ self.enable_label_attention = self.label_attention_config is not None
+ if self.enable_label_attention:
+ self.label_attention_module = LabelAttentionClassifier(self.config)
+
+ self.vocab_size = text_embedder_config.vocab_size
+ self.embedding_dim = text_embedder_config.embedding_dim
+ self.padding_idx = text_embedder_config.padding_idx
+
+ self.embedding_layer = nn.Embedding(
+ embedding_dim=self.embedding_dim,
+ num_embeddings=self.vocab_size,
+ padding_idx=self.padding_idx,
+ )
+
+ if self.attention_config is not None:
+ self.attention_config.n_embd = text_embedder_config.embedding_dim
+ self.transformer = nn.ModuleDict(
+ {
+ "h": nn.ModuleList(
+ [
+ Block(self.attention_config, layer_idx)
+ for layer_idx in range(self.attention_config.n_layers)
+ ]
+ ),
+ }
+ )
+
+ head_dim = self.attention_config.n_embd // self.attention_config.n_head
+
+ if head_dim * self.attention_config.n_head != self.attention_config.n_embd:
+ raise ValueError("embedding_dim must be divisible by n_head.")
+
+ if self.attention_config.positional_encoding:
+ if head_dim % 2 != 0:
+ raise ValueError(
+ "embedding_dim / n_head must be even for rotary positional embeddings."
+ )
+
+ if self.attention_config.sequence_len is None:
+ raise ValueError(
+ "sequence_len must be specified in AttentionConfig when positional_encoding is True."
+ )
+
+ self.rotary_seq_len = self.attention_config.sequence_len * 10
+ cos, sin = self._precompute_rotary_embeddings(
+ seq_len=self.rotary_seq_len, head_dim=head_dim
+ )
+
+ self.register_buffer(
+ "cos", cos, persistent=False
+ ) # persistent=False means it's not saved to the checkpoint
+ self.register_buffer("sin", sin, persistent=False)
+
+ def init_weights(self):
+ self.apply(self._init_weights)
+
+ # zero out c_proj weights in all blocks
+ if self.attention_config is not None:
+ for block in self.transformer.h:
+ torch.nn.init.zeros_(block.mlp.c_proj.weight)
+ torch.nn.init.zeros_(block.attn.c_proj.weight)
+ # init the rotary embeddings
+ head_dim = self.attention_config.n_embd // self.attention_config.n_head
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
+ self.cos, self.sin = cos, sin
+ # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
+ if self.embedding_layer.weight.device.type == "cuda":
+ self.embedding_layer.to(dtype=torch.bfloat16)
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Linear):
+ # https://arxiv.org/pdf/2310.17813
+ fan_out = module.weight.size(0)
+ fan_in = module.weight.size(1)
+ std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
+ torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+ if module.bias is not None:
+ torch.nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Embedding):
+ torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ return_label_attention_matrix: bool = False,
+ ) -> Dict[str, Optional[torch.Tensor]]:
+ """Converts input token IDs to their corresponding embeddings.
+
+ Args:
+ input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized
+ attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
+ return_label_attention_matrix (bool): Whether to return the label attention matrix.
+
+ Returns:
+ dict: A dictionary with the following keys:
+
+ - "sentence_embedding" (torch.Tensor): Text embeddings of shape
+ (batch_size, embedding_dim) if ``self.enable_label_attention`` is False,
+ else (batch_size, num_classes, embedding_dim), where ``num_classes``
+ is the number of label classes.
+
+ - "label_attention_matrix" (Optional[torch.Tensor]): Label attention
+ matrix of shape (batch_size, n_head, num_classes, seq_len) if
+ ``return_label_attention_matrix`` is True and label attention is
+ enabled, otherwise ``None``. The dimensions correspond to
+ (batch_size, attention heads, label classes, sequence length).
+ """
+
+ encoded_text = input_ids # clearer name
+ if encoded_text.dtype != torch.long:
+ encoded_text = encoded_text.to(torch.long)
+
+ batch_size, seq_len = encoded_text.shape
+ batch_size_check, seq_len_check = attention_mask.shape
+
+ if batch_size != batch_size_check or seq_len != seq_len_check:
+ raise ValueError(
+ f"Input IDs and attention mask must have the same batch size and sequence length. "
+ f"Got input_ids shape {encoded_text.shape} and attention_mask shape {attention_mask.shape}."
+ )
+
+ token_embeddings = self.embedding_layer(
+ encoded_text
+ ) # (batch_size, seq_len, embedding_dim)
+
+ token_embeddings = norm(token_embeddings)
+
+ if self.attention_config is not None:
+ if self.attention_config.positional_encoding:
+ cos_sin = self.cos[:, :seq_len], self.sin[:, :seq_len]
+ else:
+ cos_sin = None
+
+ for block in self.transformer.h:
+ token_embeddings = block(token_embeddings, cos_sin)
+
+ token_embeddings = norm(token_embeddings)
+
+ out = self._get_sentence_embedding(
+ token_embeddings=token_embeddings,
+ attention_mask=attention_mask,
+ return_label_attention_matrix=return_label_attention_matrix,
+ )
+
+ text_embedding = out["sentence_embedding"]
+ label_attention_matrix = out["label_attention_matrix"]
+ return {
+ "sentence_embedding": text_embedding,
+ "label_attention_matrix": label_attention_matrix,
+ }
+
+ def _get_sentence_embedding(
+ self,
+ token_embeddings: torch.Tensor,
+ attention_mask: torch.Tensor,
+ return_label_attention_matrix: bool = False,
+ ) -> Dict[str, Optional[torch.Tensor]]:
+ """
+ Compute sentence embedding from embedded tokens - "remove" second dimension.
+
+ Args (output from dataset collate_fn):
+ token_embeddings (torch.Tensor[Long]), shape (batch_size, seq_len, embedding_dim): Tokenized + padded text
+ attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
+ return_label_attention_matrix (bool): Whether to compute and return the label attention matrix
+ Returns:
+ Dict[str, Optional[torch.Tensor]]: A dictionary containing:
+ - 'sentence_embedding': Sentence embeddings, shape (batch_size, embedding_dim) or (batch_size, n_labels, embedding_dim) if label attention is enabled
+ - 'label_attention_matrix': Attention matrix if label attention is enabled and return_label_attention_matrix is True, otherwise None
+ """
+
+ # average over non-pad token embeddings
+ # attention mask has 1 for non-pad tokens and 0 for pad token positions
+
+ # mask pad-tokens
+
+ if self.attention_config is not None:
+ if self.attention_config.aggregation_method is not None: # default is "mean"
+ if self.attention_config.aggregation_method == "first":
+ return {
+ "sentence_embedding": token_embeddings[:, 0, :],
+ "label_attention_matrix": None,
+ }
+ elif self.attention_config.aggregation_method == "last":
+ lengths = attention_mask.sum(dim=1).clamp(min=1) # last non-pad token index + 1
+ return {
+ "sentence_embedding": token_embeddings[
+ torch.arange(token_embeddings.size(0)),
+ lengths - 1,
+ :,
+ ],
+ "label_attention_matrix": None,
+ }
+ else:
+ if self.attention_config.aggregation_method != "mean":
+ raise ValueError(
+ f"Unknown aggregation method: {self.attention_config.aggregation_method}. Supported methods are 'mean', 'first', 'last'."
+ )
+
+ assert self.attention_config is None or self.attention_config.aggregation_method == "mean"
+
+ if self.enable_label_attention:
+ label_attention_result = self.label_attention_module(
+ token_embeddings,
+ attention_mask=attention_mask,
+ compute_attention_matrix=return_label_attention_matrix,
+ )
+ sentence_embedding = label_attention_result[
+ "sentence_embedding"
+ ] # (bs, n_labels, d_embed), so classifier needs to be a (d_embed, 1) matrix
+ label_attention_matrix = label_attention_result["attention_matrix"]
+
+ else: # sentence embedding = mean of (non-pad) token embeddings
+ mask = attention_mask.unsqueeze(-1).float() # (batch_size, seq_len, 1)
+ masked_embeddings = token_embeddings * mask # (batch_size, seq_len, embedding_dim)
+ sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp(
+ min=1.0
+ ) # avoid division by zero
+
+ sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0)
+ label_attention_matrix = None
+
+ return {
+ "sentence_embedding": sentence_embedding,
+ "label_attention_matrix": label_attention_matrix,
+ }
+
+ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
+ # autodetect the device from model embeddings
+ if device is None:
+ device = next(self.parameters()).device
+
+ # stride the channels
+ channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
+ inv_freq = 1.0 / (base ** (channel_range / head_dim))
+ # stride the time steps
+ t = torch.arange(seq_len, dtype=torch.float32, device=device)
+ # calculate the rotation frequencies at each (time, channel) pair
+ freqs = torch.outer(t, inv_freq)
+ cos, sin = freqs.cos(), freqs.sin()
+ cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
+ cos, sin = (
+ cos[None, :, None, :],
+ sin[None, :, None, :],
+ ) # add batch and head dims for later broadcasting
+
+ return cos, sin
+
+
+ class LabelAttentionClassifier(nn.Module):
+ """
+ A head for aggregating token embeddings into label-specific sentence embeddings using cross-attention mechanism.
+ Labels are queries that attend over token embeddings (keys and values) to produce label-specific embeddings.
+
+ """
+
+ def __init__(self, config: TextEmbedderConfig):
+ super().__init__()
+
+ label_attention_config = config.label_attention_config
+ self.embedding_dim = config.embedding_dim
+ self.num_classes = label_attention_config.num_classes
+ self.n_head = label_attention_config.n_head
+
+ # Validate head configuration
+ self.head_dim = self.embedding_dim // self.n_head
+
+ if self.head_dim * self.n_head != self.embedding_dim:
+ raise ValueError(
+ f"embedding_dim ({self.embedding_dim}) must be divisible by n_head ({self.n_head}). "
+ f"Got head_dim = {self.head_dim} with remainder {self.embedding_dim % self.n_head}"
+ )
+
+ self.label_embeds = nn.Embedding(self.num_classes, self.embedding_dim)
+
+ self.c_q = nn.Linear(self.embedding_dim, self.n_head * self.head_dim, bias=False)
+ self.c_k = nn.Linear(self.embedding_dim, self.n_head * self.head_dim, bias=False)
+ self.c_v = nn.Linear(self.embedding_dim, self.n_head * self.head_dim, bias=False)
+ self.c_proj = nn.Linear(self.embedding_dim, self.embedding_dim, bias=False)
+
+ def forward(
+ self,
+ token_embeddings,
+ attention_mask: Optional[torch.Tensor] = None,
+ compute_attention_matrix: Optional[bool] = False,
+ ):
+ """
+ Args:
+ token_embeddings (torch.Tensor), shape (batch, seq_len, d_model): Embedded tokens from the text input.
+ attention_mask (torch.Tensor, optional), shape (batch, seq_len): Attention mask indicating non-pad tokens (1 for real tokens, 0 for padding).
+ compute_attention_matrix (bool): Whether to compute and return the attention matrix.
+ Returns:
+ dict: {
+ "sentence_embedding": torch.Tensor, shape (batch, num_classes, d_model): Label-specific sentence embeddings.
+ "attention_matrix": Optional[torch.Tensor], shape (batch, n_head, num_classes, seq_len): Attention weights if compute_attention_matrix is True, else None.
+ }
+
+ """
+ B, T, C = token_embeddings.size()
+ if isinstance(compute_attention_matrix, torch.Tensor):
+ compute_attention_matrix = compute_attention_matrix[0].item()
+ compute_attention_matrix = bool(compute_attention_matrix)
+
+ # 1. Create label indices [0, 1, ..., C-1] for the whole batch
+ label_indices = torch.arange(
+ self.num_classes, dtype=torch.long, device=token_embeddings.device
+ ).expand(B, -1)
+
+ all_label_embeddings = self.label_embeds(
+ label_indices
+ ) # Shape: [batch, num_classes, d_model]
+ all_label_embeddings = norm(all_label_embeddings)
+
+ q = self.c_q(all_label_embeddings).view(B, self.num_classes, self.n_head, self.head_dim)
+ k = self.c_k(token_embeddings).view(B, T, self.n_head, self.head_dim)
+ v = self.c_v(token_embeddings).view(B, T, self.n_head, self.head_dim)
+
+ q, k = norm(q), norm(k) # QK norm
+ q, k, v = (
+ q.transpose(1, 2),
+ k.transpose(1, 2),
+ v.transpose(1, 2),
+ ) # make head be batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
+
+ # Prepare attention mask for scaled_dot_product_attention
+ # attention_mask: (B, T) with 1 for real tokens, 0 for padding
+ # scaled_dot_product_attention expects attn_mask: (B, H, Q, K) or broadcastable shape
+ # where True means "mask out" (ignore), False means "attend to"
+ attn_mask = None
+ if attention_mask is not None:
+ # Convert: 0 (padding) -> True (mask out), 1 (real) -> False (attend to)
+ attn_mask = attention_mask == 0 # (B, T)
+ # Expand to (B, 1, 1, T) for broadcasting across heads and queries
+ attn_mask = attn_mask.unsqueeze(1).unsqueeze(2) # (B, 1, 1, T)
+
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)
+
+ # Re-assemble the heads side by side and project back to residual stream
+ y = y.transpose(1, 2).contiguous().view(B, self.num_classes, -1) # (bs, n_labels, d_model)
+ y = self.c_proj(y)
+
+ attention_matrix = None
+ if compute_attention_matrix:
+ # Compute attention scores
+ # size (B, n_head, n_labels, seq_len)
+ attention_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)
+
+ # Apply mask to attention scores before softmax
+ if attention_mask is not None:
+ # attn_mask is already in the right shape: (B, 1, 1, T)
+ # We need to apply it to scores of shape (B, n_head, n_labels, T)
+ # Set masked positions to -inf so they become 0 after softmax
+ attention_scores = attention_scores.masked_fill(attn_mask, float("-inf"))
+
+ attention_matrix = torch.softmax(attention_scores, dim=-1)
+
+ return {"sentence_embedding": y, "attention_matrix": attention_matrix}

{torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/lightning.py +1 -0
@@ -102,6 +102,7 @@ class TextClassificationModule(pl.LightningModule):
  targets = batch["labels"]

  outputs = self.forward(batch)
+
  loss = self.loss(outputs, targets)
  self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True, sync_dist=True)


{torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/model/model.py +52 -11
@@ -1,12 +1,12 @@
- """FastText model components.
+ """TextClassification model components.

  This module contains the PyTorch model, Lightning module, and dataset classes
- for FastText classification. Consolidates what was previously in pytorch_model.py,
+ for text classification. Consolidates what was previously in pytorch_model.py,
  lightning_module.py, and dataset.py.
  """

  import logging
- from typing import Annotated, Optional
+ from typing import Annotated, Optional, Union

  import torch
  from torch import nn
@@ -17,6 +17,7 @@ from torchTextClassifiers.model.components import (
  ClassificationHead,
  TextEmbedder,
  )
+ from torchTextClassifiers.model.components.attention import norm

  logger = logging.getLogger(__name__)

@@ -67,8 +68,6 @@ class TextClassificationModel(nn.Module):

  self._validate_component_connections()

- self.num_classes = self.classification_head.num_classes
-
  torch.nn.init.zeros_(self.classification_head.net.weight)
  if self.text_embedder is not None:
  self.text_embedder.init_weights()
@@ -98,6 +97,17 @@
  raise ValueError(
  "Classification head input dimension does not match expected dimension from text embedder and categorical variable net."
  )
+ if self.text_embedder.enable_label_attention:
+ self.enable_label_attention = True
+ if self.classification_head.num_classes != 1:
+ raise ValueError(
+ "Label attention is enabled. TextEmbedder outputs a (num_classes, embedding_dim) tensor, so the ClassificationHead should have an output dimension of 1."
+ )
+ # if enable_label_attention is True, label_attention_config exists - and contains num_classes necessarily
+ self.num_classes = self.text_embedder.config.label_attention_config.num_classes
+ else:
+ self.enable_label_attention = False
+ self.num_classes = self.classification_head.num_classes
  else:
  logger.warning(
  "⚠️ No text embedder provided; assuming input text is already embedded or vectorized. Take care that the classification head input dimension matches the input text dimension."
@@ -108,8 +118,9 @@ class TextClassificationModel(nn.Module):
  input_ids: Annotated[torch.Tensor, "batch seq_len"],
  attention_mask: Annotated[torch.Tensor, "batch seq_len"],
  categorical_vars: Annotated[torch.Tensor, "batch num_cats"],
+ return_label_attention_matrix: bool = False,
  **kwargs,
- ) -> torch.Tensor:
+ ) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
  """
  Memory-efficient forward pass implementation.

@@ -117,35 +128,65 @@
  input_ids (torch.Tensor[Long]), shape (batch_size, seq_len): Tokenized + padded text
  attention_mask (torch.Tensor[int]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
  categorical_vars (torch.Tensor[Long]): Additional categorical features, (batch_size, num_categorical_features)
+ return_label_attention_matrix (bool): If True, returns a dict with logits and label_attention_matrix

  Returns:
- torch.Tensor: Model output scores for each class - shape (batch_size, num_classes)
- Raw, not softmaxed.
+ Union[torch.Tensor, dict[str, torch.Tensor]]:
+ - If return_label_attention_matrix is False: torch.Tensor of shape (batch_size, num_classes)
+ containing raw logits (not softmaxed)
+ - If return_label_attention_matrix is True: dict with keys:
+ - "logits": torch.Tensor of shape (batch_size, num_classes)
+ - "label_attention_matrix": torch.Tensor of shape (batch_size, num_classes, seq_len)
  """
  encoded_text = input_ids # clearer name
+ label_attention_matrix = None
  if self.text_embedder is None:
  x_text = encoded_text.float()
+ if return_label_attention_matrix:
+ raise ValueError(
+ "return_label_attention_matrix=True requires a text_embedder with label attention enabled"
+ )
  else:
- x_text = self.text_embedder(input_ids=encoded_text, attention_mask=attention_mask)
+ text_embed_output = self.text_embedder(
+ input_ids=encoded_text,
+ attention_mask=attention_mask,
+ return_label_attention_matrix=return_label_attention_matrix,
+ )
+ x_text = text_embed_output["sentence_embedding"]
+ if isinstance(return_label_attention_matrix, torch.Tensor):
+ return_label_attention_matrix = return_label_attention_matrix[0].item()
+ if return_label_attention_matrix:
+ label_attention_matrix = text_embed_output["label_attention_matrix"]

  if self.categorical_variable_net:
  x_cat = self.categorical_variable_net(categorical_vars)

+ if self.enable_label_attention:
+ # x_text is (batch_size, num_classes, embedding_dim)
+ # x_cat is (batch_size, cat_embedding_dim)
+ # We need to expand x_cat to (batch_size, num_classes, cat_embedding_dim)
+ # x_cat will be appended to x_text along the last dimension for each class
+ x_cat = x_cat.unsqueeze(1).expand(-1, self.num_classes, -1)
+
  if (
  self.categorical_variable_net.forward_type
  == CategoricalForwardType.AVERAGE_AND_CONCAT
  or self.categorical_variable_net.forward_type
  == CategoricalForwardType.CONCATENATE_ALL
  ):
- x_combined = torch.cat((x_text, x_cat), dim=1)
+ x_combined = torch.cat((x_text, x_cat), dim=-1)
  else:
  assert (
  self.categorical_variable_net.forward_type == CategoricalForwardType.SUM_TO_TEXT
  )
+
  x_combined = x_text + x_cat
  else:
  x_combined = x_text

- logits = self.classification_head(x_combined)
+ logits = self.classification_head(norm(x_combined)).squeeze(-1)
+
+ if return_label_attention_matrix:
+ return {"logits": logits, "label_attention_matrix": label_attention_matrix}

  return logits

{torchtextclassifiers-1.0.2 → torchtextclassifiers-1.0.4}/torchTextClassifiers/torchTextClassifiers.py +59 -23
@@ -29,6 +29,7 @@ from torchTextClassifiers.model.components import (
  CategoricalForwardType,
  CategoricalVariableNet,
  ClassificationHead,
+ LabelAttentionConfig,
  TextEmbedder,
  TextEmbedderConfig,
  )
@@ -53,6 +54,7 @@ class ModelConfig:
  categorical_embedding_dims: Optional[Union[List[int], int]] = None
  num_classes: Optional[int] = None
  attention_config: Optional[AttentionConfig] = None
+ label_attention_config: Optional[LabelAttentionConfig] = None

  def to_dict(self) -> Dict[str, Any]:
  return asdict(self)
@@ -140,6 +142,7 @@ class torchTextClassifiers:
  self.embedding_dim = model_config.embedding_dim
  self.categorical_vocabulary_sizes = model_config.categorical_vocabulary_sizes
  self.num_classes = model_config.num_classes
+ self.enable_label_attention = model_config.label_attention_config is not None

  if self.tokenizer.output_vectorized:
  self.text_embedder = None
@@ -153,6 +156,7 @@ class torchTextClassifiers:
  embedding_dim=self.embedding_dim,
  padding_idx=tokenizer.padding_idx,
  attention_config=model_config.attention_config,
+ label_attention_config=model_config.label_attention_config,
  )
  self.text_embedder = TextEmbedder(
  text_embedder_config=text_embedder_config,
@@ -174,7 +178,9 @@ class torchTextClassifiers:

  self.classification_head = ClassificationHead(
  input_dim=classif_head_input_dim,
- num_classes=model_config.num_classes,
+ num_classes=1
+ if self.enable_label_attention
+ else model_config.num_classes, # output dim is 1 when using label attention, because embeddings are (num_classes, embedding_dim)
  )

  self.pytorch_model = TextClassificationModel(
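
The `num_classes=1` branch above is a deliberate design choice: under label attention the embedder already emits a `(num_classes, embedding_dim)` block per sentence, so the classification head only needs to score each label embedding with a single output unit, and `TextClassificationModel.forward` squeezes the trailing dimension back into `(batch_size, num_classes)` logits. A shape-only sketch, using a plain `nn.Linear` as a hypothetical stand-in for `ClassificationHead` (consistent with the `classification_head.net.weight` initialization seen in model.py, though the real head may do more):

```python
# Shape-only sketch of why the head gets num_classes=1 under label attention.
# nn.Linear here is a hypothetical stand-in for ClassificationHead.
import torch
import torch.nn as nn

batch_size, num_classes, embedding_dim = 8, 5, 32

# With label attention, the embedder outputs one embedding per (sentence, label) pair.
x_text = torch.randn(batch_size, num_classes, embedding_dim)

head = nn.Linear(embedding_dim, 1)  # output dimension 1, as configured above
logits = head(x_text).squeeze(-1)   # (8, 5, 1) -> (8, 5), i.e. (batch_size, num_classes)
print(logits.shape)                 # torch.Size([8, 5])
```
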
@@ -486,13 +492,15 @@ class torchTextClassifiers:
  self,
  X_test: np.ndarray,
  top_k=1,
- explain=False,
+ explain_with_label_attention: bool = False,
+ explain_with_captum=False,
  ):
  """
  Args:
  X_test (np.ndarray): input data to predict on, shape (N,d) where the first column is text and the rest are categorical variables
  top_k (int): for each sentence, return the top_k most likely predictions (default: 1)
- explain (bool): launch gradient integration to have an explanation of the prediction (default: False)
+ explain_with_label_attention (bool): if enabled, use attention matrix labels x tokens to have an explanation of the prediction (default: False)
+ explain_with_captum (bool): launch gradient integration with Captum for explanation (default: False)

  Returns: A dictionary containing the following fields:
  - predictions (torch.Tensor, shape (len(text), top_k)): A tensor containing the top_k most likely codes to the query.
@@ -501,6 +509,7 @@
  - attributions (torch.Tensor, shape (len(text), top_k, seq_len)): A tensor containing the attributions for each token in the text.
  """

+ explain = explain_with_label_attention or explain_with_captum
  if explain:
  return_offsets_mapping = True # to be passed to the tokenizer
  return_word_ids = True
@@ -509,13 +518,19 @@
  "Explainability is not supported when the tokenizer outputs vectorized text directly. Please use a tokenizer that outputs token IDs."
  )
  else:
- if not HAS_CAPTUM:
- raise ImportError(
- "Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'."
- )
- lig = LayerIntegratedGradients(
- self.pytorch_model, self.pytorch_model.text_embedder.embedding_layer
- ) # initialize a Captum layer gradient integrator
+ if explain_with_captum:
+ if not HAS_CAPTUM:
+ raise ImportError(
+ "Captum is not installed and is required for explainability. Run 'pip install/uv add torchFastText[explainability]'."
+ )
+ lig = LayerIntegratedGradients(
+ self.pytorch_model, self.pytorch_model.text_embedder.embedding_layer
+ ) # initialize a Captum layer gradient integrator
+ if explain_with_label_attention:
+ if not self.enable_label_attention:
+ raise RuntimeError(
+ "Label attention explainability is enabled, but the model was not configured with label attention. Please enable label attention in the model configuration during initialization and retrain."
+ )
  else:
  return_offsets_mapping = False
  return_word_ids = False
@@ -547,9 +562,19 @@
  else:
  categorical_vars = torch.empty((encoded_text.shape[0], 0), dtype=torch.float32)

- pred = self.pytorch_model(
- encoded_text, attention_mask, categorical_vars
+ model_output = self.pytorch_model(
+ encoded_text,
+ attention_mask,
+ categorical_vars,
+ return_label_attention_matrix=explain_with_label_attention,
  ) # forward pass, contains the prediction scores (len(text), num_classes)
+ pred = (
+ model_output["logits"] if explain_with_label_attention else model_output
+ ) # (batch_size, num_classes)
+
+ label_attention_matrix = (
+ model_output["label_attention_matrix"] if explain_with_label_attention else None
+ )

  label_scores = pred.detach().cpu().softmax(dim=1) # convert to probabilities

@@ -559,21 +584,28 @@
  confidence = torch.round(label_scores_topk.values, decimals=2) # and their scores

  if explain:
- all_attributions = []
- for k in range(top_k):
- attributions = lig.attribute(
- (encoded_text, attention_mask, categorical_vars),
- target=torch.Tensor(predictions[:, k]).long(),
- ) # (batch_size, seq_len)
- attributions = attributions.sum(dim=-1)
- all_attributions.append(attributions.detach().cpu())
-
- all_attributions = torch.stack(all_attributions, dim=1) # (batch_size, top_k, seq_len)
+ if explain_with_captum:
+ # Captum explanations
+ captum_attributions = []
+ for k in range(top_k):
+ attributions = lig.attribute(
+ (encoded_text, attention_mask, categorical_vars),
+ target=torch.Tensor(predictions[:, k]).long(),
+ ) # (batch_size, seq_len)
+ attributions = attributions.sum(dim=-1)
+ captum_attributions.append(attributions.detach().cpu())
+
+ captum_attributions = torch.stack(
+ captum_attributions, dim=1
+ ) # (batch_size, top_k, seq_len)
+ else:
+ captum_attributions = None

  return {
  "prediction": predictions,
  "confidence": confidence,
- "attributions": all_attributions,
+ "captum_attributions": captum_attributions,
+ "label_attention_attributions": label_attention_matrix,
  "offset_mapping": tokenize_output.offset_mapping,
  "word_ids": tokenize_output.word_ids,
  }
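
For reference, the reworked `predict` splits the old `explain` flag into `explain_with_label_attention` and `explain_with_captum`, and the returned dictionary now carries `captum_attributions` and `label_attention_attributions` separately. A hedged usage sketch follows; it assumes `clf` is an already trained `torchTextClassifiers` instance configured with `label_attention_config`, and that `X_test` follows the documented `(N, d)` layout with text in the first column.

```python
# Hedged usage sketch for the new explainability flags. `clf` is assumed to be an
# already-trained torchTextClassifiers instance built with label attention enabled.
import numpy as np

X_test = np.array([["some input text to classify"]])  # illustrative (N, d) input, d = 1

result = clf.predict(
    X_test,
    top_k=3,
    explain_with_label_attention=True,  # reuse the labels x tokens attention matrix
    explain_with_captum=False,          # Captum integrated gradients stay optional
)

print(result["prediction"].shape)  # (1, 3): top_k most likely classes per input row
print(result["confidence"].shape)  # (1, 3): rounded probabilities for those classes
# Labels-by-tokens attention weights (see the TextEmbedder docstring for the exact shape);
# result["captum_attributions"] is None here because explain_with_captum=False.
attributions = result["label_attention_attributions"]
```
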
@@ -665,6 +697,10 @@ class torchTextClassifiers:

  # Reconstruct model_config
  model_config = ModelConfig.from_dict(metadata["model_config"])
+ if isinstance(model_config.label_attention_config, dict):
+ model_config.label_attention_config = LabelAttentionConfig(
+ **model_config.label_attention_config
+ )

  # Create instance
  instance = cls(

torchtextclassifiers-1.0.2/torchTextClassifiers/model/components/text_embedder.py +0 -223 (removed)
@@ -1,223 +0,0 @@
- import math
- from dataclasses import dataclass
- from typing import Optional
-
- import torch
- from torch import nn
-
- from torchTextClassifiers.model.components.attention import AttentionConfig, Block, norm
-
-
- @dataclass
- class TextEmbedderConfig:
- vocab_size: int
- embedding_dim: int
- padding_idx: int
- attention_config: Optional[AttentionConfig] = None
-
-
- class TextEmbedder(nn.Module):
- def __init__(self, text_embedder_config: TextEmbedderConfig):
- super().__init__()
-
- self.config = text_embedder_config
-
- self.attention_config = text_embedder_config.attention_config
- if isinstance(self.attention_config, dict):
- self.attention_config = AttentionConfig(**self.attention_config)
-
- if self.attention_config is not None:
- self.attention_config.n_embd = text_embedder_config.embedding_dim
-
- self.vocab_size = text_embedder_config.vocab_size
- self.embedding_dim = text_embedder_config.embedding_dim
- self.padding_idx = text_embedder_config.padding_idx
-
- self.embedding_layer = nn.Embedding(
- embedding_dim=self.embedding_dim,
- num_embeddings=self.vocab_size,
- padding_idx=self.padding_idx,
- )
-
- if self.attention_config is not None:
- self.transformer = nn.ModuleDict(
- {
- "h": nn.ModuleList(
- [
- Block(self.attention_config, layer_idx)
- for layer_idx in range(self.attention_config.n_layers)
- ]
- ),
- }
- )
-
- head_dim = self.attention_config.n_embd // self.attention_config.n_head
-
- if head_dim * self.attention_config.n_head != self.attention_config.n_embd:
- raise ValueError("embedding_dim must be divisible by n_head.")
-
- if self.attention_config.positional_encoding:
- if head_dim % 2 != 0:
- raise ValueError(
- "embedding_dim / n_head must be even for rotary positional embeddings."
- )
-
- if self.attention_config.sequence_len is None:
- raise ValueError(
- "sequence_len must be specified in AttentionConfig when positional_encoding is True."
- )
-
- self.rotary_seq_len = self.attention_config.sequence_len * 10
- cos, sin = self._precompute_rotary_embeddings(
- seq_len=self.rotary_seq_len, head_dim=head_dim
- )
-
- self.register_buffer(
- "cos", cos, persistent=False
- ) # persistent=False means it's not saved to the checkpoint
- self.register_buffer("sin", sin, persistent=False)
-
- def init_weights(self):
- self.apply(self._init_weights)
-
- # zero out c_proj weights in all blocks
- if self.attention_config is not None:
- for block in self.transformer.h:
- torch.nn.init.zeros_(block.mlp.c_proj.weight)
- torch.nn.init.zeros_(block.attn.c_proj.weight)
- # init the rotary embeddings
- head_dim = self.attention_config.n_embd // self.attention_config.n_head
- cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
- self.cos, self.sin = cos, sin
- # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
- if self.embedding_layer.weight.device.type == "cuda":
- self.embedding_layer.to(dtype=torch.bfloat16)
-
- def _init_weights(self, module):
- if isinstance(module, nn.Linear):
- # https://arxiv.org/pdf/2310.17813
- fan_out = module.weight.size(0)
- fan_in = module.weight.size(1)
- std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
- torch.nn.init.normal_(module.weight, mean=0.0, std=std)
- if module.bias is not None:
- torch.nn.init.zeros_(module.bias)
- elif isinstance(module, nn.Embedding):
- torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
-
- def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
- """Converts input token IDs to their corresponding embeddings."""
-
- encoded_text = input_ids # clearer name
- if encoded_text.dtype != torch.long:
- encoded_text = encoded_text.to(torch.long)
-
- batch_size, seq_len = encoded_text.shape
- batch_size_check, seq_len_check = attention_mask.shape
-
- if batch_size != batch_size_check or seq_len != seq_len_check:
- raise ValueError(
- f"Input IDs and attention mask must have the same batch size and sequence length. "
- f"Got input_ids shape {encoded_text.shape} and attention_mask shape {attention_mask.shape}."
- )
-
- token_embeddings = self.embedding_layer(
- encoded_text
- ) # (batch_size, seq_len, embedding_dim)
-
- token_embeddings = norm(token_embeddings)
-
- if self.attention_config is not None:
- if self.attention_config.positional_encoding:
- cos_sin = self.cos[:, :seq_len], self.sin[:, :seq_len]
- else:
- cos_sin = None
-
- for block in self.transformer.h:
- token_embeddings = block(token_embeddings, cos_sin)
-
- token_embeddings = norm(token_embeddings)
-
- text_embedding = self._get_sentence_embedding(
- token_embeddings=token_embeddings, attention_mask=attention_mask
- )
-
- return text_embedding
-
- def _get_sentence_embedding(
- self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
- ) -> torch.Tensor:
- """
- Compute sentence embedding from embedded tokens - "remove" second dimension.
-
- Args (output from dataset collate_fn):
- token_embeddings (torch.Tensor[Long]), shape (batch_size, seq_len, embedding_dim): Tokenized + padded text
- attention_mask (torch.Tensor[Long]), shape (batch_size, seq_len): Attention mask indicating non-pad tokens
- Returns:
- torch.Tensor: Sentence embeddings, shape (batch_size, embedding_dim)
- """
-
- # average over non-pad token embeddings
- # attention mask has 1 for non-pad tokens and 0 for pad token positions
-
- # mask pad-tokens
-
- if self.attention_config is not None:
- if self.attention_config.aggregation_method is not None:
- if self.attention_config.aggregation_method == "first":
- return token_embeddings[:, 0, :]
- elif self.attention_config.aggregation_method == "last":
- lengths = attention_mask.sum(dim=1).clamp(min=1) # last non-pad token index + 1
- return token_embeddings[
- torch.arange(token_embeddings.size(0)),
- lengths - 1,
- :,
- ]
- else:
- if self.attention_config.aggregation_method != "mean":
- raise ValueError(
- f"Unknown aggregation method: {self.attention_config.aggregation_method}. Supported methods are 'mean', 'first', 'last'."
- )
-
- assert self.attention_config is None or self.attention_config.aggregation_method == "mean"
-
- mask = attention_mask.unsqueeze(-1).float() # (batch_size, seq_len, 1)
- masked_embeddings = token_embeddings * mask # (batch_size, seq_len, embedding_dim)
-
- sentence_embedding = masked_embeddings.sum(dim=1) / mask.sum(dim=1).clamp(
- min=1.0
- ) # avoid division by zero
-
- sentence_embedding = torch.nan_to_num(sentence_embedding, 0.0)
-
- return sentence_embedding
-
- def __call__(self, *args, **kwargs):
- out = super().__call__(*args, **kwargs)
- if out.dim() != 2:
- raise ValueError(
- f"Output of {self.__class__.__name__}.forward must be 2D "
- f"(got shape {tuple(out.shape)})"
- )
- return out
-
- def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
- # autodetect the device from model embeddings
- if device is None:
- device = next(self.parameters()).device
-
- # stride the channels
- channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
- inv_freq = 1.0 / (base ** (channel_range / head_dim))
- # stride the time steps
- t = torch.arange(seq_len, dtype=torch.float32, device=device)
- # calculate the rotation frequencies at each (time, channel) pair
- freqs = torch.outer(t, inv_freq)
- cos, sin = freqs.cos(), freqs.sin()
- cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
- cos, sin = (
- cos[None, :, None, :],
- sin[None, :, None, :],
- ) # add batch and head dims for later broadcasting
-
- return cos, sin