torchtextclassifiers 0.0.1__tar.gz → 1.0.0__tar.gz

This diff shows the changes between publicly released versions of the package, as they appear in the supported public registries, and is provided for informational purposes only.
Files changed (37)
  1. torchtextclassifiers-1.0.0/PKG-INFO +87 -0
  2. torchtextclassifiers-1.0.0/README.md +61 -0
  3. {torchtextclassifiers-0.0.1 → torchtextclassifiers-1.0.0}/pyproject.toml +25 -17
  4. torchtextclassifiers-1.0.0/torchTextClassifiers/__init__.py +32 -0
  5. torchtextclassifiers-1.0.0/torchTextClassifiers/dataset/__init__.py +1 -0
  6. torchtextclassifiers-1.0.0/torchTextClassifiers/dataset/dataset.py +152 -0
  7. torchtextclassifiers-1.0.0/torchTextClassifiers/model/__init__.py +2 -0
  8. torchtextclassifiers-1.0.0/torchTextClassifiers/model/components/__init__.py +12 -0
  9. torchtextclassifiers-1.0.0/torchTextClassifiers/model/components/attention.py +126 -0
  10. torchtextclassifiers-1.0.0/torchTextClassifiers/model/components/categorical_var_net.py +128 -0
  11. torchtextclassifiers-1.0.0/torchTextClassifiers/model/components/classification_head.py +61 -0
  12. torchtextclassifiers-1.0.0/torchTextClassifiers/model/components/text_embedder.py +220 -0
  13. torchtextclassifiers-1.0.0/torchTextClassifiers/model/lightning.py +170 -0
  14. torchtextclassifiers-1.0.0/torchTextClassifiers/model/model.py +151 -0
  15. torchtextclassifiers-1.0.0/torchTextClassifiers/tokenizers/WordPiece.py +92 -0
  16. torchtextclassifiers-1.0.0/torchTextClassifiers/tokenizers/__init__.py +10 -0
  17. torchtextclassifiers-1.0.0/torchTextClassifiers/tokenizers/base.py +205 -0
  18. torchtextclassifiers-1.0.0/torchTextClassifiers/tokenizers/ngram.py +472 -0
  19. torchtextclassifiers-1.0.0/torchTextClassifiers/torchTextClassifiers.py +596 -0
  20. torchtextclassifiers-1.0.0/torchTextClassifiers/utilities/__init__.py +0 -0
  21. torchtextclassifiers-1.0.0/torchTextClassifiers/utilities/plot_explainability.py +184 -0
  22. torchtextclassifiers-0.0.1/PKG-INFO +0 -187
  23. torchtextclassifiers-0.0.1/README.md +0 -165
  24. torchtextclassifiers-0.0.1/torchTextClassifiers/__init__.py +0 -68
  25. torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/base.py +0 -83
  26. torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/fasttext/__init__.py +0 -25
  27. torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/fasttext/core.py +0 -269
  28. torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/fasttext/model.py +0 -752
  29. torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/fasttext/tokenizer.py +0 -346
  30. torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/fasttext/wrapper.py +0 -216
  31. torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/simple_text_classifier.py +0 -191
  32. torchtextclassifiers-0.0.1/torchTextClassifiers/factories.py +0 -34
  33. torchtextclassifiers-0.0.1/torchTextClassifiers/torchTextClassifiers.py +0 -509
  34. torchtextclassifiers-0.0.1/torchTextClassifiers/utilities/__init__.py +0 -3
  35. torchtextclassifiers-0.0.1/torchTextClassifiers/utilities/checkers.py +0 -108
  36. torchtextclassifiers-0.0.1/torchTextClassifiers/utilities/preprocess.py +0 -82
  37. torchtextclassifiers-0.0.1/torchTextClassifiers/utilities/utils.py +0 -346
torchtextclassifiers-1.0.0/PKG-INFO
@@ -0,0 +1,87 @@
+ Metadata-Version: 2.3
+ Name: torchtextclassifiers
+ Version: 1.0.0
+ Summary: A text classification toolkit to easily build, train and evaluate deep learning text classifiers using PyTorch.
+ Keywords: fastText,text classification,NLP,automatic coding,deep learning
+ Author: Cédric Couralet, Meilame Tayebjee
+ Author-email: Cédric Couralet <cedric.couralet@insee.fr>, Meilame Tayebjee <meilame.tayebjee@insee.fr>
+ Classifier: Programming Language :: Python :: 3
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: OS Independent
+ Requires-Dist: numpy>=1.26.4
+ Requires-Dist: pytorch-lightning>=2.4.0
+ Requires-Dist: unidecode ; extra == 'explainability'
+ Requires-Dist: nltk ; extra == 'explainability'
+ Requires-Dist: captum ; extra == 'explainability'
+ Requires-Dist: tokenizers>=0.22.1 ; extra == 'huggingface'
+ Requires-Dist: transformers>=4.57.1 ; extra == 'huggingface'
+ Requires-Dist: datasets>=4.3.0 ; extra == 'huggingface'
+ Requires-Dist: unidecode ; extra == 'preprocess'
+ Requires-Dist: nltk ; extra == 'preprocess'
+ Requires-Python: >=3.11
+ Provides-Extra: explainability
+ Provides-Extra: huggingface
+ Provides-Extra: preprocess
+ Description-Content-Type: text/markdown
+
+ # torchTextClassifiers
+
+ [![Documentation](https://img.shields.io/badge/docs-latest-blue.svg)](https://inseefrlab.github.io/torchTextClassifiers/)
+
+ A unified, extensible framework for text classification with categorical variables, built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
+
+ ## 🚀 Features
+
+ - **Complex input support**: Handle text data alongside categorical variables seamlessly.
+ - **Unified yet highly customizable**:
+   - Use any tokenizer from HuggingFace or fastText's original ngram tokenizer.
+   - Manipulate the components (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) to easily create custom architectures, including **self-attention**. All of them are `torch.nn.Module`s!
+   - The `TextClassificationModel` class combines these components and can be extended for custom behavior.
+ - **Multiclass / multilabel classification**: Support for both multiclass (exactly one label is true) and multilabel (several labels can be true) tasks.
+ - **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging.
+ - **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code:
+   - The `torchTextClassifiers` wrapper class orchestrates the tokenizer and the model for you.
+ - **Additional features**: Explainability using Captum.
+
+ ## 📦 Installation
+
+ ```bash
+ # Clone the repository
+ git clone https://github.com/InseeFrLab/torchTextClassifiers.git
+ cd torchTextClassifiers
+
+ # Install with uv (recommended)
+ uv sync
+
+ # Or install with pip
+ pip install -e .
+ ```
+
+ ## 📖 Documentation
+
+ Full documentation is available at **https://inseefrlab.github.io/torchTextClassifiers/**.
+
+ The documentation includes:
+
+ - **Getting Started**: Installation and quick start guide
+ - **Architecture**: Understanding the 3-layer design
+ - **Tutorials**: Step-by-step guides for different use cases
+ - **API Reference**: Complete API documentation
+
+ ## 📝 Usage
+
+ Check out the [notebook](notebooks/example.ipynb) for a quick start.
+
+ ## 📚 Examples
+
+ See the [examples/](examples/) directory for:
+
+ - Basic text classification
+ - Multi-class classification
+ - Mixed features (text + categorical)
+ - Advanced training configurations
+ - Prediction and explainability
+
+ ## 📄 License
+
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
torchtextclassifiers-1.0.0/README.md
@@ -0,0 +1,61 @@
+ # torchTextClassifiers
+
+ [![Documentation](https://img.shields.io/badge/docs-latest-blue.svg)](https://inseefrlab.github.io/torchTextClassifiers/)
+
+ A unified, extensible framework for text classification with categorical variables, built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
+
+ ## 🚀 Features
+
+ - **Complex input support**: Handle text data alongside categorical variables seamlessly.
+ - **Unified yet highly customizable**:
+   - Use any tokenizer from HuggingFace or fastText's original ngram tokenizer.
+   - Manipulate the components (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) to easily create custom architectures, including **self-attention**. All of them are `torch.nn.Module`s!
+   - The `TextClassificationModel` class combines these components and can be extended for custom behavior.
+ - **Multiclass / multilabel classification**: Support for both multiclass (exactly one label is true) and multilabel (several labels can be true) tasks.
+ - **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging.
+ - **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code:
+   - The `torchTextClassifiers` wrapper class orchestrates the tokenizer and the model for you.
+ - **Additional features**: Explainability using Captum.
+
+ ## 📦 Installation
+
+ ```bash
+ # Clone the repository
+ git clone https://github.com/InseeFrLab/torchTextClassifiers.git
+ cd torchTextClassifiers
+
+ # Install with uv (recommended)
+ uv sync
+
+ # Or install with pip
+ pip install -e .
+ ```
+
+ ## 📖 Documentation
+
+ Full documentation is available at **https://inseefrlab.github.io/torchTextClassifiers/**.
+
+ The documentation includes:
+
+ - **Getting Started**: Installation and quick start guide
+ - **Architecture**: Understanding the 3-layer design
+ - **Tutorials**: Step-by-step guides for different use cases
+ - **API Reference**: Complete API documentation
+
+ ## 📝 Usage
+
+ Check out the [notebook](notebooks/example.ipynb) for a quick start.
+
+ ## 📚 Examples
+
+ See the [examples/](examples/) directory for:
+
+ - Basic text classification
+ - Multi-class classification
+ - Mixed features (text + categorical)
+ - Advanced training configurations
+ - Prediction and explainability
+
+ ## 📄 License
+
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
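To make the feature list above concrete, the sketch below shows the import surface this release exposes. Only the import paths are grounded in this diff (they appear in the `__init__.py` files further down); the wrapper's constructor arguments and train/predict methods live in `torchTextClassifiers/torchTextClassifiers.py` (+596 lines, listed above but not reproduced here), so treat anything beyond these imports as hypothetical and see the linked notebook for real usage.

```python
# Import paths grounded in the __init__.py files shown later in this diff.
from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers
from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule
from torchTextClassifiers.model.components import (
    CategoricalVariableNet,
    ClassificationHead,
    TextEmbedder,
)
```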
{torchtextclassifiers-0.0.1 → torchtextclassifiers-1.0.0}/pyproject.toml
@@ -1,11 +1,9 @@
  [project]
  name = "torchtextclassifiers"
- description = "An implementation of the https://github.com/facebookresearch/fastText supervised learning algorithm for text classification using Pytorch."
+ description = "A text classification toolkit to easily build, train and evaluate deep learning text classifiers using PyTorch."
  authors = [
- { name = "Tom Seimandi", email = "tom.seimandi@gmail.com" },
- { name = "Julien Pramil", email = "julien.pramil@insee.fr" },
- { name = "Meilame Tayebjee", email = "meilame.tayebjee@insee.fr" },
  { name = "Cédric Couralet", email = "cedric.couralet@insee.fr" },
+ { name = "Meilame Tayebjee", email = "meilame.tayebjee@insee.fr" },
  ]
  readme = "README.md"
  repository = "https://github.com/InseeFrLab/torchTextClassifiers"
@@ -20,35 +18,48 @@ dependencies = [
  "pytorch-lightning>=2.4.0",
  ]
  requires-python = ">=3.11"
- version="0.0.1"
+ version="1.0.0"


  [dependency-groups]
  dev = [
- "pytest >=8.1.1,<9",
+ "pytest >=9.0.1,<10",
  "pandas",
  "scikit-learn",
  "nltk",
  "unidecode",
  "captum",
- "pyarrow"
+ "pyarrow",
+ "pre-commit>=4.3.0",
+ "ruff>=0.14.3",
+ "ipywidgets>=8.1.8",
  ]
  docs = [
- "sphinx>=5.0.0",
- "sphinx-rtd-theme>=1.2.0",
- "sphinx-autodoc-typehints>=1.19.0",
- "sphinxcontrib-napoleon>=0.7",
+ "sphinx>=8.1.0",
+ "pydata-sphinx-theme>=0.16.0",
+ "sphinx-autodoc-typehints>=2.0.0",
  "sphinx-copybutton>=0.5.0",
- "myst-parser>=0.18.0",
- "sphinx-design>=0.3.0"
+ "myst-parser>=4.0.0",
+ "sphinx-design>=0.6.0",
+ "nbsphinx>=0.9.0",
+ "ipython>=8.0.0",
+ "pandoc>=2.0.0",
+ "linkify-it-py>=2.0.0",
+ "sphinxcontrib-images>=1.0.1"
  ]

  [project.optional-dependencies]
  explainability = ["unidecode", "nltk", "captum"]
  preprocess = ["unidecode", "nltk"]
+ huggingface = [
+ "tokenizers>=0.22.1",
+ "transformers>=4.57.1",
+ "datasets>=4.3.0",
+ ]
+

  [build-system]
- requires = ["uv_build>=0.8.3,<0.9.0"]
+ requires = ["uv_build>=0.9.3,<0.10.0"]
  build-backend = "uv_build"

  [tool.ruff]
@@ -58,6 +69,3 @@ line-length = 100
  [tool.uv.build-backend]
  module-name="torchTextClassifiers"
  module-root = ""
-
-
-
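The `[project.optional-dependencies]` table above makes HuggingFace support opt-in. Here is a small, hedged sketch of how downstream code could detect whether the `huggingface` extra's modules are importable; the module names come from the table, but the check itself is illustrative and not part of the package:

```python
# Illustrative check, not package code: probe for the modules pulled in by
# the "huggingface" extra declared in pyproject.toml above.
import importlib.util

HF_MODULES = ("tokenizers", "transformers", "datasets")
have_huggingface = all(importlib.util.find_spec(m) is not None for m in HF_MODULES)
print("huggingface extra available:", have_huggingface)
```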
torchtextclassifiers-1.0.0/torchTextClassifiers/__init__.py
@@ -0,0 +1,32 @@
+ """torchTextClassifiers: A unified framework for text classification.
+
+ This package provides a generic, extensible framework for building and training
+ different types of text classifiers. It currently supports FastText classifiers
+ with a clean API for building, training, and inference.
+
+ Key Features:
+ - Unified API for different classifier types
+ - Built-in support for FastText classifiers
+ - PyTorch Lightning integration for training
+ - Extensible architecture for adding new classifier types
+ - Support for both text-only and mixed text/categorical features
+ """
+
+ from .torchTextClassifiers import (
+     ModelConfig as ModelConfig,
+ )
+ from .torchTextClassifiers import (
+     TrainingConfig as TrainingConfig,
+ )
+ from .torchTextClassifiers import (
+     torchTextClassifiers as torchTextClassifiers,
+ )
+
+ __all__ = [
+     "torchTextClassifiers",
+     "ModelConfig",
+     "TrainingConfig",
+ ]
+
+ __version__ = "1.0.0"
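A quick sanity check of the public surface defined above:

```python
import torchTextClassifiers

print(torchTextClassifiers.__version__)  # -> 1.0.0
print(torchTextClassifiers.__all__)  # -> ['torchTextClassifiers', 'ModelConfig', 'TrainingConfig']
```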
torchtextclassifiers-1.0.0/torchTextClassifiers/dataset/__init__.py
@@ -0,0 +1 @@
+ from .dataset import TextClassificationDataset as TextClassificationDataset
torchtextclassifiers-1.0.0/torchTextClassifiers/dataset/dataset.py
@@ -0,0 +1,152 @@
+ import logging
+ import os
+ from typing import List, Union
+
+ import numpy as np
+ import torch
+ from torch.utils.data import DataLoader, Dataset
+
+ from torchTextClassifiers.tokenizers import BaseTokenizer
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ logger = logging.getLogger(__name__)
+
+
+ class TextClassificationDataset(Dataset):
+     def __init__(
+         self,
+         texts: List[str],
+         categorical_variables: Union[List[List[int]], np.ndarray, None],
+         tokenizer: BaseTokenizer,
+         labels: Union[List[int], List[List[int]], np.ndarray, None] = None,
+         ragged_multilabel: bool = False,
+     ):
+         if hasattr(tokenizer, "trained") and not tokenizer.trained:
+             raise RuntimeError(
+                 f"Tokenizer {type(tokenizer)} must be trained before creating dataset."
+             )
+
+         self.texts = texts
+         self.categorical_variables = categorical_variables
+         self.tokenizer = tokenizer
+         self.labels = labels
+         self.ragged_multilabel = ragged_multilabel
+
+         if self.ragged_multilabel and self.labels is not None:
+             max_value = int(max(max(row) for row in labels if row))
+             self.num_classes = max_value + 1
+
+             if max_value == 1:
+                 try:
+                     np.asarray(labels)  # raises ValueError if rows have different lengths
+                     logger.critical(
+                         "ragged_multilabel set to True but max label value is 1 and all samples have the same number of labels. "
+                         "If your labels are already one-hot encoded, set ragged_multilabel to False; otherwise computations are likely to be wrong."
+                     )
+                 except ValueError:
+                     logger.warning(
+                         "ragged_multilabel set to True but max label value is 1. If your labels are already one-hot encoded, "
+                         "set ragged_multilabel to False; otherwise computations are likely to be wrong."
+                     )
+
+     def __len__(self):
+         return len(self.texts)
+
+     def __getitem__(self, idx):
+         return (
+             str(self.texts[idx]),
+             (
+                 self.categorical_variables[idx]
+                 if self.categorical_variables is not None
+                 else None
+             ),
+             self.labels[idx] if self.labels is not None else None,
+         )
+
+     def collate_fn(self, batch):
+         text, categorical_vars, labels = zip(*batch)
+
+         if self.labels is not None:
+             if self.ragged_multilabel:
+                 # Pad labels to the max length in the batch
+                 labels_padded = torch.nn.utils.rnn.pad_sequence(
+                     [torch.tensor(label) for label in labels],
+                     batch_first=True,
+                     padding_value=-1,  # use impossible class
+                 ).int()
+
+                 # Multi-hot target over all known classes (num_classes is set in __init__)
+                 labels_tensor = torch.zeros(labels_padded.size(0), self.num_classes).float()
+                 mask = labels_padded != -1
+
+                 batch_size = labels_padded.size(0)
+                 rows = torch.arange(batch_size).unsqueeze(1).expand_as(labels_padded)[mask]
+                 cols = labels_padded[mask]
+
+                 labels_tensor[rows, cols] = 1
+             else:
+                 labels_tensor = torch.tensor(labels)
+         else:
+             labels_tensor = None
+
+         tokenize_output = self.tokenizer.tokenize(list(text))
+
+         if self.categorical_variables is not None:
+             categorical_tensors = torch.stack(
+                 [torch.tensor(cat_var, dtype=torch.float32) for cat_var in categorical_vars]
+             )
+         else:
+             categorical_tensors = None
+
+         return {
+             "input_ids": tokenize_output.input_ids,
+             "attention_mask": tokenize_output.attention_mask,
+             "categorical_vars": categorical_tensors,
+             "labels": labels_tensor,
+         }
+
+     def create_dataloader(
+         self,
+         batch_size: int,
+         shuffle: bool = False,
+         drop_last: bool = False,
+         num_workers: int = max((os.cpu_count() or 1) - 1, 0),
+         pin_memory: bool = False,
+         persistent_workers: bool = True,
+         **kwargs,
+     ):
+         # persistent_workers requires num_workers > 0
+         if num_workers == 0:
+             persistent_workers = False
+
+         return DataLoader(
+             dataset=self,
+             batch_size=batch_size,
+             collate_fn=self.collate_fn,
+             shuffle=shuffle,
+             drop_last=drop_last,
+             pin_memory=pin_memory,
+             num_workers=num_workers,
+             persistent_workers=persistent_workers,
+             **kwargs,
+         )
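The scatter in `collate_fn` is the heart of the ragged-multilabel path, so here is a minimal, self-contained sketch of the same padding-and-scatter trick outside the class (the example labels and class count are made up for illustration):

```python
# Turn ragged label lists into a multi-hot matrix, as collate_fn does above.
import torch

labels = [[0, 2], [1], [0, 3, 4]]  # ragged: one list of class indices per sample
num_classes = 5

padded = torch.nn.utils.rnn.pad_sequence(
    [torch.tensor(lbl) for lbl in labels], batch_first=True, padding_value=-1
)  # shape (3, 3); -1 marks padding, an impossible class
multi_hot = torch.zeros(padded.size(0), num_classes)
mask = padded != -1
rows = torch.arange(padded.size(0)).unsqueeze(1).expand_as(padded)[mask]
multi_hot[rows, padded[mask]] = 1.0
print(multi_hot)
# tensor([[1., 0., 1., 0., 0.],
#         [0., 1., 0., 0., 0.],
#         [1., 0., 0., 1., 1.]])
```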
torchtextclassifiers-1.0.0/torchTextClassifiers/model/__init__.py
@@ -0,0 +1,2 @@
+ from .lightning import TextClassificationModule as TextClassificationModule
+ from .model import TextClassificationModel as TextClassificationModel
torchtextclassifiers-1.0.0/torchTextClassifiers/model/components/__init__.py
@@ -0,0 +1,12 @@
+ from .attention import (
+     AttentionConfig as AttentionConfig,
+ )
+ from .categorical_var_net import (
+     CategoricalForwardType as CategoricalForwardType,
+ )
+ from .categorical_var_net import (
+     CategoricalVariableNet as CategoricalVariableNet,
+ )
+ from .classification_head import ClassificationHead as ClassificationHead
+ from .text_embedder import TextEmbedder as TextEmbedder
+ from .text_embedder import TextEmbedderConfig as TextEmbedderConfig
torchtextclassifiers-1.0.0/torchTextClassifiers/model/components/attention.py
@@ -0,0 +1,126 @@
+ """Largely inspired by Andrej Karpathy's nanochat, see https://github.com/karpathy/nanochat/blob/master/nanochat/gpt.py"""
+
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ ### Some utils used in text_embedder.py for the attention blocks ###
+
+
+ def apply_rotary_emb(x, cos, sin):
+     assert x.ndim == 4  # multihead attention
+
+     d = x.shape[3] // 2
+     x1, x2 = x[..., :d], x[..., d:]  # split the last dim into two halves
+     y1 = x1 * cos + x2 * sin  # rotate pairs of dims
+     y2 = x1 * (-sin) + x2 * cos
+     out = torch.cat([y1, y2], 3)  # re-assemble
+     out = out.to(x.dtype)  # ensure input/output dtypes match
+     return out
+
+
+ def norm(x):
+     # Purely functional rmsnorm with no learnable params
+     return F.rms_norm(x, (x.size(-1),))
+
+
+ #### Config #####
+ @dataclass
+ class AttentionConfig:
+     n_layers: int
+     n_head: int
+     n_kv_head: int
+     sequence_len: Optional[int] = None
+     positional_encoding: bool = True
+     aggregation_method: str = "mean"  # or 'last', or 'first'
+     # NB: the layers below also read config.n_embd, which is not declared here;
+     # presumably it is attached to the config by text_embedder.py (not shown in this diff).
+
+
+ #### Attention Block #####
+
+ # Composed of SelfAttentionLayer and MLP with residual connections
+
+
+ class Block(nn.Module):
+     def __init__(self, config: AttentionConfig, layer_idx: int):
+         super().__init__()
+
+         self.layer_idx = layer_idx
+         self.attn = SelfAttentionLayer(config, layer_idx)
+         self.mlp = MLP(config)
+
+     def forward(self, x, cos_sin):
+         # Pre-norm residual connections around attention and MLP
+         x = x + self.attn(norm(x), cos_sin)
+         x = x + self.mlp(norm(x))
+         return x
+
+
+ ##### Components of the Block #####
+
+
+ class SelfAttentionLayer(nn.Module):
+     def __init__(self, config: AttentionConfig, layer_idx):
+         super().__init__()
+         self.layer_idx = layer_idx
+         self.n_head = config.n_head
+         self.n_kv_head = config.n_kv_head
+         self.enable_gqa = (
+             self.n_head != self.n_kv_head
+         )  # Grouped Query Attention (GQA): key/value heads are duplicated to match query heads if desired
+         self.n_embd = config.n_embd
+         self.head_dim = self.n_embd // self.n_head
+         assert self.n_embd % self.n_head == 0
+         assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
+         self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
+         self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+         self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
+         self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
+
+         self.apply_positional_encoding = config.positional_encoding
+
+     def forward(self, x, cos_sin=None):
+         B, T, C = x.size()
+
+         # Project the input to get queries, keys, and values
+         q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
+         k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
+         v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
+
+         if self.apply_positional_encoding:
+             assert cos_sin is not None, "Rotary embeddings require precomputed cos/sin tensors"
+             cos, sin = cos_sin
+             q, k = (
+                 apply_rotary_emb(q, cos, sin),
+                 apply_rotary_emb(k, cos, sin),
+             )  # QK rotary embedding
+
+         q, k = norm(q), norm(k)  # QK norm
+         q, k, v = (
+             q.transpose(1, 2),
+             k.transpose(1, 2),
+             v.transpose(1, 2),
+         )  # make head the batch dim, i.e. (B, T, H, D) -> (B, H, T, D)
+
+         # is_causal=False for non-autoregressive models (BERT-like)
+         y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=self.enable_gqa)
+
+         # Re-assemble the heads side by side and project back to the residual stream
+         y = y.transpose(1, 2).contiguous().view(B, T, -1)
+         y = self.c_proj(y)
+
+         return y
+
+
+ class MLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
+         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
+
+     def forward(self, x):
+         # ReLU-squared nonlinearity, as in nanochat
+         x = self.c_fc(x)
+         x = F.relu(x).square()
+         x = self.c_proj(x)
+         return x
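`apply_rotary_emb` expects precomputed `cos`/`sin` tensors, but the precomputation lives in `text_embedder.py`, which this diff lists without showing. The sketch below is therefore an assumption about their construction (the standard RoPE recipe, shaped to broadcast over `(batch, seq, heads, head_dim // 2)`), paired with a copy of the function above so it runs standalone:

```python
import torch

def apply_rotary_emb(x, cos, sin):  # mirrors the function in attention.py above
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return torch.cat([x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos], dim=3).to(x.dtype)

def precompute_rotary(seq_len: int, head_dim: int, base: float = 10000.0):
    # Assumed precomputation (standard RoPE); the package's actual version is not shown.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (T, D/2)
    # shape (1, T, 1, D/2) so cos/sin broadcast over (B, T, H, D/2)
    return freqs.cos()[None, :, None, :], freqs.sin()[None, :, None, :]

q = torch.randn(2, 8, 4, 16)  # (batch, seq, heads, head_dim)
cos, sin = precompute_rotary(seq_len=8, head_dim=16)
q_rot = apply_rotary_emb(q, cos, sin)
assert q_rot.shape == q.shape
# Each 2-D pair is rotated, so per-position vector norms are preserved.
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)
```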
torchtextclassifiers-1.0.0/torchTextClassifiers/model/components/categorical_var_net.py
@@ -0,0 +1,128 @@
+ from enum import Enum
+ from typing import List, Optional, Union
+
+ import torch
+ from torch import nn
+
+
+ class CategoricalForwardType(Enum):
+     SUM_TO_TEXT = "EMBEDDING_SUM_TO_TEXT"
+     AVERAGE_AND_CONCAT = "EMBEDDING_AVERAGE_AND_CONCAT"
+     CONCATENATE_ALL = "EMBEDDING_CONCATENATE_ALL"
+
+
+ class CategoricalVariableNet(nn.Module):
+     def __init__(
+         self,
+         categorical_vocabulary_sizes: List[int],
+         categorical_embedding_dims: Optional[Union[List[int], int]] = None,
+         text_embedding_dim: Optional[int] = None,
+     ):
+         super().__init__()
+
+         self.categorical_vocabulary_sizes = categorical_vocabulary_sizes
+         self.categorical_embedding_dims = categorical_embedding_dims
+         self.text_embedding_dim = text_embedding_dim
+
+         self._validate_categorical_inputs()
+         assert isinstance(
+             self.forward_type, CategoricalForwardType
+         ), "forward_type must be set after validation"
+         assert isinstance(self.output_dim, int), "output_dim must be set as int after validation"
+
+         self.categorical_embedding_layers = {}
+
+         for var_idx, num_rows in enumerate(self.categorical_vocabulary_sizes):
+             emb_layer = nn.Embedding(
+                 num_embeddings=num_rows,
+                 embedding_dim=self.categorical_embedding_dims[var_idx],
+             )
+             self.categorical_embedding_layers[var_idx] = emb_layer
+             # setattr registers the layer as a submodule so its parameters are tracked
+             setattr(self, f"categorical_embedding_{var_idx}", emb_layer)
+
+     def forward(self, categorical_vars_tensor: torch.Tensor) -> torch.Tensor:
+         cat_embeds = self._get_cat_embeds(categorical_vars_tensor)
+         if self.forward_type == CategoricalForwardType.SUM_TO_TEXT:
+             x_combined = torch.stack(cat_embeds, dim=0).sum(dim=0)  # (bs, text_embed_dim)
+         elif self.forward_type == CategoricalForwardType.AVERAGE_AND_CONCAT:
+             x_combined = torch.stack(cat_embeds, dim=0).mean(dim=0)  # (bs, embed_dim)
+         elif self.forward_type == CategoricalForwardType.CONCATENATE_ALL:
+             x_combined = torch.cat(cat_embeds, dim=1)  # (bs, sum of all cat embed dims)
+         else:
+             raise ValueError(f"Unknown forward type: {self.forward_type}")
+
+         assert (
+             x_combined.dim() == 2
+         ), "Output combined tensor must be 2-dimensional (batch_size, embed_dim)"
+         assert x_combined.size(1) == self.output_dim
+
+         return x_combined
+
+     def _get_cat_embeds(self, categorical_vars_tensor: torch.Tensor):
+         if categorical_vars_tensor.dtype != torch.long:
+             categorical_vars_tensor = categorical_vars_tensor.to(torch.long)
+         cat_embeds = []
+
+         for i, embed_layer in self.categorical_embedding_layers.items():
+             cat_var_tensor = categorical_vars_tensor[:, i]
+
+             # Check that categorical values are within the valid range
+             vocab_size = embed_layer.num_embeddings
+             max_val = cat_var_tensor.max().item()
+             min_val = cat_var_tensor.min().item()
+
+             if max_val >= vocab_size or min_val < 0:
+                 raise ValueError(
+                     f"Categorical feature {i}: value range [{min_val}, {max_val}] exceeds vocabulary size {vocab_size}."
+                 )
+
+             cat_embed = embed_layer(cat_var_tensor)
+             if cat_embed.dim() > 2:
+                 cat_embed = cat_embed.squeeze(1)
+             cat_embeds.append(cat_embed)
+
+         return cat_embeds
+
+     def _validate_categorical_inputs(self):
+         categorical_vocabulary_sizes = self.categorical_vocabulary_sizes
+         categorical_embedding_dims = self.categorical_embedding_dims
+
+         if not isinstance(categorical_vocabulary_sizes, list):
+             raise TypeError("categorical_vocabulary_sizes must be a list of int")
+
+         if isinstance(categorical_embedding_dims, list):
+             if len(categorical_vocabulary_sizes) != len(categorical_embedding_dims):
+                 raise ValueError(
+                     "Categorical vocabulary sizes and their embedding dimensions must have the same length"
+                 )
+
+         num_categorical_features = len(categorical_vocabulary_sizes)
+
+         # Normalize the embedding dims into one list entry per feature
+         if categorical_embedding_dims is not None:
+             if isinstance(categorical_embedding_dims, int):
+                 # A single shared dim: embeddings are averaged, then concatenated to the text embedding
+                 self.forward_type = CategoricalForwardType.AVERAGE_AND_CONCAT
+                 self.output_dim = categorical_embedding_dims
+                 categorical_embedding_dims = [categorical_embedding_dims] * num_categorical_features
+             elif isinstance(categorical_embedding_dims, list):
+                 # Per-feature dims: embeddings are concatenated side by side
+                 self.forward_type = CategoricalForwardType.CONCATENATE_ALL
+                 self.output_dim = sum(categorical_embedding_dims)
+             else:
+                 raise TypeError("categorical_embedding_dims must be an int, a list of int or None")
+         else:
+             if self.text_embedding_dim is None:
+                 raise ValueError(
+                     "If categorical_embedding_dims is None, text_embedding_dim must be provided"
+                 )
+             # No dims given: embed every feature to the text dimension and sum onto the text embedding
+             self.forward_type = CategoricalForwardType.SUM_TO_TEXT
+             self.output_dim = self.text_embedding_dim
+             categorical_embedding_dims = [self.text_embedding_dim] * num_categorical_features
+
+         assert isinstance(
+             categorical_embedding_dims, list
+         ), "categorical_embedding_dims must be a list of int at this point"
+
+         self.categorical_vocabulary_sizes = categorical_vocabulary_sizes
+         self.categorical_embedding_dims = categorical_embedding_dims
+         self.num_categorical_features = num_categorical_features
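The validation logic above yields three distinct forward behaviors. Here is a short sketch exercising all three, using the re-export path shown in `components/__init__.py` (the vocabulary sizes and dims are arbitrary examples):

```python
import torch
from torchTextClassifiers.model.components import CategoricalVariableNet

cats = torch.tensor([[1, 0], [2, 3]])  # batch of 2 samples, 2 categorical features

# 1. SUM_TO_TEXT: no dims given -> each feature embeds to the text dimension;
#    the forward pass sums the embeddings (meant to be added to the text embedding).
net = CategoricalVariableNet([5, 4], categorical_embedding_dims=None, text_embedding_dim=8)
assert net(cats).shape == (2, 8)

# 2. AVERAGE_AND_CONCAT: a single int -> one shared dim; embeddings are averaged.
net = CategoricalVariableNet([5, 4], categorical_embedding_dims=6)
assert net(cats).shape == (2, 6)

# 3. CONCATENATE_ALL: a list -> per-feature dims; embeddings are concatenated.
net = CategoricalVariableNet([5, 4], categorical_embedding_dims=[3, 5])
assert net(cats).shape == (2, 3 + 5)
```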