torchtextclassifiers 0.0.1__tar.gz → 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchtextclassifiers-0.1.0/PKG-INFO +73 -0
- torchtextclassifiers-0.1.0/README.md +47 -0
- {torchtextclassifiers-0.0.1 → torchtextclassifiers-0.1.0}/pyproject.toml +14 -10
- torchtextclassifiers-0.1.0/torchTextClassifiers/__init__.py +32 -0
- torchtextclassifiers-0.1.0/torchTextClassifiers/dataset/__init__.py +1 -0
- torchtextclassifiers-0.1.0/torchTextClassifiers/dataset/dataset.py +114 -0
- torchtextclassifiers-0.1.0/torchTextClassifiers/model/__init__.py +2 -0
- torchtextclassifiers-0.1.0/torchTextClassifiers/model/components/__init__.py +12 -0
- torchtextclassifiers-0.1.0/torchTextClassifiers/model/components/attention.py +126 -0
- torchtextclassifiers-0.1.0/torchTextClassifiers/model/components/categorical_var_net.py +128 -0
- torchtextclassifiers-0.1.0/torchTextClassifiers/model/components/classification_head.py +43 -0
- torchtextclassifiers-0.1.0/torchTextClassifiers/model/components/text_embedder.py +220 -0
- torchtextclassifiers-0.1.0/torchTextClassifiers/model/lightning.py +166 -0
- torchtextclassifiers-0.1.0/torchTextClassifiers/model/model.py +151 -0
- torchtextclassifiers-0.1.0/torchTextClassifiers/tokenizers/WordPiece.py +92 -0
- torchtextclassifiers-0.1.0/torchTextClassifiers/tokenizers/__init__.py +10 -0
- torchtextclassifiers-0.1.0/torchTextClassifiers/tokenizers/base.py +205 -0
- torchtextclassifiers-0.1.0/torchTextClassifiers/tokenizers/ngram.py +472 -0
- torchtextclassifiers-0.1.0/torchTextClassifiers/torchTextClassifiers.py +567 -0
- torchtextclassifiers-0.1.0/torchTextClassifiers/utilities/__init__.py +0 -0
- torchtextclassifiers-0.1.0/torchTextClassifiers/utilities/plot_explainability.py +184 -0
- torchtextclassifiers-0.0.1/PKG-INFO +0 -187
- torchtextclassifiers-0.0.1/README.md +0 -165
- torchtextclassifiers-0.0.1/torchTextClassifiers/__init__.py +0 -68
- torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/base.py +0 -83
- torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/fasttext/__init__.py +0 -25
- torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/fasttext/core.py +0 -269
- torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/fasttext/model.py +0 -752
- torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/fasttext/tokenizer.py +0 -346
- torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/fasttext/wrapper.py +0 -216
- torchtextclassifiers-0.0.1/torchTextClassifiers/classifiers/simple_text_classifier.py +0 -191
- torchtextclassifiers-0.0.1/torchTextClassifiers/factories.py +0 -34
- torchtextclassifiers-0.0.1/torchTextClassifiers/torchTextClassifiers.py +0 -509
- torchtextclassifiers-0.0.1/torchTextClassifiers/utilities/__init__.py +0 -3
- torchtextclassifiers-0.0.1/torchTextClassifiers/utilities/checkers.py +0 -108
- torchtextclassifiers-0.0.1/torchTextClassifiers/utilities/preprocess.py +0 -82
- torchtextclassifiers-0.0.1/torchTextClassifiers/utilities/utils.py +0 -346
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: torchtextclassifiers
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A text classification toolkit to easily build, train and evaluate deep learning text classifiers using PyTorch.
|
|
5
|
+
Keywords: fastText,text classification,NLP,automatic coding,deep learning
|
|
6
|
+
Author: Cédric Couralet, Meilame Tayebjee
|
|
7
|
+
Author-email: Cédric Couralet <cedric.couralet@insee.fr>, Meilame Tayebjee <meilame.tayebjee@insee.fr>
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Requires-Dist: numpy>=1.26.4
|
|
12
|
+
Requires-Dist: pytorch-lightning>=2.4.0
|
|
13
|
+
Requires-Dist: unidecode ; extra == 'explainability'
|
|
14
|
+
Requires-Dist: nltk ; extra == 'explainability'
|
|
15
|
+
Requires-Dist: captum ; extra == 'explainability'
|
|
16
|
+
Requires-Dist: tokenizers>=0.22.1 ; extra == 'huggingface'
|
|
17
|
+
Requires-Dist: transformers>=4.57.1 ; extra == 'huggingface'
|
|
18
|
+
Requires-Dist: datasets>=4.3.0 ; extra == 'huggingface'
|
|
19
|
+
Requires-Dist: unidecode ; extra == 'preprocess'
|
|
20
|
+
Requires-Dist: nltk ; extra == 'preprocess'
|
|
21
|
+
Requires-Python: >=3.11
|
|
22
|
+
Provides-Extra: explainability
|
|
23
|
+
Provides-Extra: huggingface
|
|
24
|
+
Provides-Extra: preprocess
|
|
25
|
+
Description-Content-Type: text/markdown
|
|
26
|
+
|
|
27
|
+
# torchTextClassifiers
|
|
28
|
+
|
|
29
|
+
A unified, extensible framework for text classification with categorical variables built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
|
|
30
|
+
|
|
31
|
+
## 🚀 Features
|
|
32
|
+
|
|
33
|
+
- **Mixed input support**: Handle text data alongside categorical variables seamlessly.
|
|
34
|
+
- **Unified yet highly customizable**:
|
|
35
|
+
- Use any tokenizer from HuggingFace or the original fastText's ngram tokenizer.
|
|
36
|
+
- Manipulate the components (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) to easily create custom architectures, including **self-attention**. All of them are `torch.nn.Module`s!
|
|
37
|
+
- The `TextClassificationModel` class combines these components and can be extended for custom behavior.
|
|
38
|
+
- **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging
|
|
39
|
+
- **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code:
|
|
40
|
+
- The `torchTextClassifiers` wrapper class orchestrates the tokenizer and the model for you
|
|
41
|
+
- **Additional features**: explainability using Captum
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
## 📦 Installation
|
|
45
|
+
|
|
46
|
+
```bash
|
|
47
|
+
# Clone the repository
|
|
48
|
+
git clone https://github.com/InseeFrLab/torchTextClassifiers.git
|
|
49
|
+
cd torchTextClassifiers
|
|
50
|
+
|
|
51
|
+
# Install with uv (recommended)
|
|
52
|
+
uv sync
|
|
53
|
+
|
|
54
|
+
# Or install with pip
|
|
55
|
+
pip install -e .
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
## 📝 Usage
|
|
59
|
+
|
|
60
|
+
Check out the [notebook](notebooks/example.ipynb) for a quick start.
|
|
61
|
+
|
|
62
|
+
## 📚 Examples
|
|
63
|
+
|
|
64
|
+
See the [examples/](examples/) directory for:
|
|
65
|
+
- Basic text classification
|
|
66
|
+
- Multi-class classification
|
|
67
|
+
- Mixed features (text + categorical)
|
|
68
|
+
- Advanced training configurations
|
|
69
|
+
- Prediction and explainability
|
|
70
|
+
|
|
71
|
+
## 📄 License
|
|
72
|
+
|
|
73
|
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# torchTextClassifiers
|
|
2
|
+
|
|
3
|
+
A unified, extensible framework for text classification with categorical variables built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
|
|
4
|
+
|
|
5
|
+
## 🚀 Features
|
|
6
|
+
|
|
7
|
+
- **Mixed input support**: Handle text data alongside categorical variables seamlessly.
|
|
8
|
+
- **Unified yet highly customizable**:
|
|
9
|
+
- Use any tokenizer from HuggingFace or the original fastText's ngram tokenizer.
|
|
10
|
+
- Manipulate the components (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) to easily create custom architectures, including **self-attention**. All of them are `torch.nn.Module`s!
|
|
11
|
+
- The `TextClassificationModel` class combines these components and can be extended for custom behavior.
|
|
12
|
+
- **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging
|
|
13
|
+
- **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code:
|
|
14
|
+
- The `torchTextClassifiers` wrapper class orchestrates the tokenizer and the model for you
|
|
15
|
+
- **Additional features**: explainability using Captum
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
## 📦 Installation
|
|
19
|
+
|
|
20
|
+
```bash
|
|
21
|
+
# Clone the repository
|
|
22
|
+
git clone https://github.com/InseeFrLab/torchTextClassifiers.git
|
|
23
|
+
cd torchTextClassifiers
|
|
24
|
+
|
|
25
|
+
# Install with uv (recommended)
|
|
26
|
+
uv sync
|
|
27
|
+
|
|
28
|
+
# Or install with pip
|
|
29
|
+
pip install -e .
|
|
30
|
+
```
|
|
31
|
+
|
|
32
|
+
## 📝 Usage
|
|
33
|
+
|
|
34
|
+
Check out the [notebook](notebooks/example.ipynb) for a quick start.
|
|
35
|
+
|
|
36
|
+
## 📚 Examples
|
|
37
|
+
|
|
38
|
+
See the [examples/](examples/) directory for:
|
|
39
|
+
- Basic text classification
|
|
40
|
+
- Multi-class classification
|
|
41
|
+
- Mixed features (text + categorical)
|
|
42
|
+
- Advanced training configurations
|
|
43
|
+
- Prediction and explainability
|
|
44
|
+
|
|
45
|
+
## 📄 License
|
|
46
|
+
|
|
47
|
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
|
@@ -1,11 +1,9 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "torchtextclassifiers"
|
|
3
|
-
description = "
|
|
3
|
+
description = "A text classification toolkit to easily build, train and evaluate deep learning text classifiers using PyTorch."
|
|
4
4
|
authors = [
|
|
5
|
-
{ name = "Tom Seimandi", email = "tom.seimandi@gmail.com" },
|
|
6
|
-
{ name = "Julien Pramil", email = "julien.pramil@insee.fr" },
|
|
7
|
-
{ name = "Meilame Tayebjee", email = "meilame.tayebjee@insee.fr" },
|
|
8
5
|
{ name = "Cédric Couralet", email = "cedric.couralet@insee.fr" },
|
|
6
|
+
{ name = "Meilame Tayebjee", email = "meilame.tayebjee@insee.fr" },
|
|
9
7
|
]
|
|
10
8
|
readme = "README.md"
|
|
11
9
|
repository = "https://github.com/InseeFrLab/torchTextClassifiers"
|
|
@@ -20,7 +18,7 @@ dependencies = [
|
|
|
20
18
|
"pytorch-lightning>=2.4.0",
|
|
21
19
|
]
|
|
22
20
|
requires-python = ">=3.11"
|
|
23
|
-
version="0.0
|
|
21
|
+
version="0.1.0"
|
|
24
22
|
|
|
25
23
|
|
|
26
24
|
[dependency-groups]
|
|
@@ -31,7 +29,10 @@ dev = [
|
|
|
31
29
|
"nltk",
|
|
32
30
|
"unidecode",
|
|
33
31
|
"captum",
|
|
34
|
-
"pyarrow"
|
|
32
|
+
"pyarrow",
|
|
33
|
+
"pre-commit>=4.3.0",
|
|
34
|
+
"ruff>=0.14.3",
|
|
35
|
+
"ipywidgets>=8.1.8",
|
|
35
36
|
]
|
|
36
37
|
docs = [
|
|
37
38
|
"sphinx>=5.0.0",
|
|
@@ -46,9 +47,15 @@ docs = [
|
|
|
46
47
|
[project.optional-dependencies]
|
|
47
48
|
explainability = ["unidecode", "nltk", "captum"]
|
|
48
49
|
preprocess = ["unidecode", "nltk"]
|
|
50
|
+
huggingface = [
|
|
51
|
+
"tokenizers>=0.22.1",
|
|
52
|
+
"transformers>=4.57.1",
|
|
53
|
+
"datasets>=4.3.0",
|
|
54
|
+
]
|
|
55
|
+
|
|
49
56
|
|
|
50
57
|
[build-system]
|
|
51
|
-
requires = ["uv_build>=0.
|
|
58
|
+
requires = ["uv_build>=0.9.3,<0.10.0"]
|
|
52
59
|
build-backend = "uv_build"
|
|
53
60
|
|
|
54
61
|
[tool.ruff]
|
|
@@ -58,6 +65,3 @@ line-length = 100
|
|
|
58
65
|
[tool.uv.build-backend]
|
|
59
66
|
module-name="torchTextClassifiers"
|
|
60
67
|
module-root = ""
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
"""torchTextClassifiers: A unified framework for text classification.
|
|
2
|
+
|
|
3
|
+
This package provides a generic, extensible framework for building and training
|
|
4
|
+
different types of text classifiers. It currently supports FastText classifiers
|
|
5
|
+
with a clean API for building, training, and inference.
|
|
6
|
+
|
|
7
|
+
Key Features:
|
|
8
|
+
- Unified API for different classifier types
|
|
9
|
+
- Built-in support for FastText classifiers
|
|
10
|
+
- PyTorch Lightning integration for training
|
|
11
|
+
- Extensible architecture for adding new classifier types
|
|
12
|
+
- Support for both text-only and mixed text/categorical features
|
|
13
|
+
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from .torchTextClassifiers import (
|
|
17
|
+
ModelConfig as ModelConfig,
|
|
18
|
+
)
|
|
19
|
+
from .torchTextClassifiers import (
|
|
20
|
+
TrainingConfig as TrainingConfig,
|
|
21
|
+
)
|
|
22
|
+
from .torchTextClassifiers import (
|
|
23
|
+
torchTextClassifiers as torchTextClassifiers,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
# Names re-exported at package level; this is what ``from torchTextClassifiers import *``
# brings into scope.
__all__ = ["torchTextClassifiers", "ModelConfig", "TrainingConfig"]
|
|
31
|
+
|
|
32
|
+
# Keep in sync with the ``version`` field of pyproject.toml.
__version__ = "0.1.0"
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .dataset import TextClassificationDataset as TextClassificationDataset
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import List, Union
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
from torch.utils.data import DataLoader, Dataset
|
|
7
|
+
|
|
8
|
+
from torchTextClassifiers.tokenizers import BaseTokenizer
|
|
9
|
+
|
|
10
|
+
# Hugging Face `tokenizers` warns when a process forks (e.g. DataLoader workers) after a
# tokenizer has been used. Default the switch to "false", but do not clobber a value the
# user has set explicitly in their environment.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TextClassificationDataset(Dataset):
    """Map-style dataset pairing raw texts with optional categorical features and labels.

    Items are returned untokenized; tokenization happens once per batch in
    :meth:`collate_fn`.

    Args:
        texts: Raw input texts, one per sample.
        categorical_variables: One row of categorical codes per text, or ``None`` when
            there are no categorical features.
        tokenizer: Object exposing ``tokenize(list_of_str)`` whose result has
            ``input_ids`` and ``attention_mask`` attributes. If it has a ``trained``
            attribute, that attribute must be truthy.
        labels: Optional integer class label per text (``None`` for inference).

    Raises:
        RuntimeError: If the tokenizer reports that it has not been trained.
        ValueError: If ``categorical_variables`` or ``labels`` do not have exactly one
            entry per text.
    """

    def __init__(
        self,
        texts: List[str],
        categorical_variables: Union[List[List[int]], np.ndarray, None],
        tokenizer: "BaseTokenizer",
        labels: Union[List[int], None] = None,
    ):
        if hasattr(tokenizer, "trained") and not tokenizer.trained:
            raise RuntimeError(
                f"Tokenizer {type(tokenizer)} must be trained before creating dataset."
            )

        # A length mismatch would otherwise surface later as an IndexError in
        # __getitem__, or silently drop trailing rows.
        n_texts = len(texts)
        if categorical_variables is not None and len(categorical_variables) != n_texts:
            raise ValueError(
                f"categorical_variables has {len(categorical_variables)} rows "
                f"but there are {n_texts} texts."
            )
        if labels is not None and len(labels) != n_texts:
            raise ValueError(f"labels has {len(labels)} entries but there are {n_texts} texts.")

        self.texts = texts
        self.categorical_variables = categorical_variables
        self.tokenizer = tokenizer
        self.labels = labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``(text, categorical_row_or_None, label_or_None)`` for sample ``idx``."""
        text = str(self.texts[idx])
        categorical = (
            self.categorical_variables[idx] if self.categorical_variables is not None else None
        )
        label = self.labels[idx] if self.labels is not None else None
        return text, categorical, label

    def collate_fn(self, batch):
        """Tokenize a list of ``__getitem__`` tuples and stack them into batch tensors.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (from the tokenizer),
            ``categorical_vars`` (float32 tensor, one row per sample, or ``None``) and
            ``labels`` (int64 tensor, or ``None``).
        """
        texts, categorical_rows, labels = zip(*batch)

        labels_tensor = torch.tensor(labels, dtype=torch.long) if self.labels is not None else None

        tokenize_output = self.tokenizer.tokenize(list(texts))

        if self.categorical_variables is not None:
            # NOTE: codes are carried as float32 (exact only up to 2**24);
            # CategoricalVariableNet casts them back to int64 before embedding.
            categorical_tensors = torch.stack(
                [torch.tensor(row, dtype=torch.float32) for row in categorical_rows]
            )
        else:
            categorical_tensors = None

        return {
            "input_ids": tokenize_output.input_ids,
            "attention_mask": tokenize_output.attention_mask,
            "categorical_vars": categorical_tensors,
            "labels": labels_tensor,
        }

    def create_dataloader(
        self,
        batch_size: int,
        shuffle: bool = False,
        drop_last: bool = False,
        num_workers: Union[int, None] = None,
        pin_memory: bool = False,
        persistent_workers: bool = True,
        **kwargs,
    ):
        """Build a ``DataLoader`` over this dataset using :meth:`collate_fn`.

        Args:
            num_workers: Worker processes. ``None`` (default) means "CPU count minus
                one". This is resolved at call time because ``os.cpu_count()`` may
                return ``None``.
            persistent_workers: Forced to ``False`` when ``num_workers`` is 0, since
                ``DataLoader`` requires worker processes for it.
            **kwargs: Forwarded to ``torch.utils.data.DataLoader``.
        """
        if num_workers is None:
            num_workers = max((os.cpu_count() or 1) - 1, 0)

        # persistent_workers requires num_workers > 0
        if num_workers == 0:
            persistent_workers = False

        return DataLoader(
            dataset=self,
            batch_size=batch_size,
            collate_fn=self.collate_fn,
            shuffle=shuffle,
            drop_last=drop_last,
            pin_memory=pin_memory,
            num_workers=num_workers,
            persistent_workers=persistent_workers,
            **kwargs,
        )
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from .attention import (
|
|
2
|
+
AttentionConfig as AttentionConfig,
|
|
3
|
+
)
|
|
4
|
+
from .categorical_var_net import (
|
|
5
|
+
CategoricalForwardType as CategoricalForwardType,
|
|
6
|
+
)
|
|
7
|
+
from .categorical_var_net import (
|
|
8
|
+
CategoricalVariableNet as CategoricalVariableNet,
|
|
9
|
+
)
|
|
10
|
+
from .classification_head import ClassificationHead as ClassificationHead
|
|
11
|
+
from .text_embedder import TextEmbedder as TextEmbedder
|
|
12
|
+
from .text_embedder import TextEmbedderConfig as TextEmbedderConfig
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""Largely inspired from Andrej Karpathy's nanochat, see here https://github.com/karpathy/nanochat/blob/master/nanochat/gpt.py"""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
import torch.nn.functional as F
|
|
9
|
+
|
|
10
|
+
### Some utils used in text_embedder.py for the attention blocks ###
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last dimension of ``x`` by angles given as ``cos``/``sin``.

    ``x`` must be 4-D (the multi-head layout used by the attention layer). ``cos`` and
    ``sin`` must broadcast against one half of the last dimension. The result is cast
    back to ``x.dtype``.
    """
    assert x.ndim == 4  # multihead attention

    half = x.shape[3] // 2
    # Split the last dimension into two halves and rotate them as pairs.
    first, second = x[..., :half], x[..., half:]
    rotated = torch.cat(
        [
            first * cos + second * sin,
            first * (-sin) + second * cos,
        ],
        3,
    )
    return rotated.to(x.dtype)  # keep input/output dtypes identical
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def norm(x):
|
|
26
|
+
# Purely functional rmsnorm with no learnable params
|
|
27
|
+
return F.rms_norm(x, (x.size(-1),))
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
#### Config #####
|
|
31
|
+
@dataclass
class AttentionConfig:
    """Hyper-parameters for the self-attention components defined in this module.

    Attributes:
        n_layers: Number of attention layers (consumed outside this module; presumably
            the number of ``Block`` instances stacked by the text embedder).
        n_head: Number of query heads; must divide ``n_embd``.
        n_kv_head: Number of key/value heads. It must divide ``n_head``; using fewer
            than ``n_head`` enables grouped-query attention.
        sequence_len: Optional sequence length. NOTE(review): not read in this module;
            presumably used to size positional tables elsewhere — confirm.
        positional_encoding: Whether rotary embeddings are applied to queries and keys.
        aggregation_method: Token pooling strategy, "mean", "last" or "first".
            NOTE(review): pooling is implemented outside this module — confirm values.
        n_embd: Model width read by ``SelfAttentionLayer`` and ``MLP``. It defaults to
            ``None`` so existing constructors keep working, and must be set to the
            embedding dimension before those layers are built.
    """

    n_layers: int
    n_head: int
    n_kv_head: int
    sequence_len: Optional[int] = None
    positional_encoding: bool = True
    aggregation_method: str = "mean"  # or 'last', or 'first'
    n_embd: Optional[int] = None
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
#### Attention Block #####
|
|
42
|
+
|
|
43
|
+
# Composed of SelfAttentionLayer and MLP with residual connections
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class Block(nn.Module):
    """Pre-norm layer: self-attention then MLP, each added back onto its input.

    Args:
        config: Attention hyper-parameters, forwarded to both sub-layers.
        layer_idx: Position of this layer in the stack; stored here and on the attention.
    """

    def __init__(self, config: AttentionConfig, layer_idx: int):
        super().__init__()

        self.layer_idx = layer_idx
        self.attn = SelfAttentionLayer(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, cos_sin):
        """Apply both residual branches; each sub-layer sees an RMS-normalized input."""
        after_attention = x + self.attn(norm(x), cos_sin)
        return after_attention + self.mlp(norm(after_attention))
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
##### Components of the Block #####
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class SelfAttentionLayer(nn.Module):
    """Bidirectional multi-head self-attention with optional rotary embeddings and QK-norm.

    Supports grouped-query attention when ``n_kv_head < n_head``.

    Args:
        config: Provides ``n_head``, ``n_kv_head``, ``n_embd`` and ``positional_encoding``.
        layer_idx: Index of this layer in the stack (stored for reference).

    Raises:
        ValueError: If ``config.n_embd`` is missing or ``None``.
        AssertionError: If ``n_embd`` is not divisible by ``n_head``, or ``n_kv_head``
            does not divide ``n_head``.
    """

    def __init__(self, config: "AttentionConfig", layer_idx):
        super().__init__()
        n_embd = getattr(config, "n_embd", None)
        if n_embd is None:
            raise ValueError(
                "config.n_embd must be set to the embedding dimension before building "
                "attention layers."
            )

        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        # Group Query Attention (GQA): SDPA duplicates key/value heads to match query heads.
        self.enable_gqa = self.n_head != self.n_kv_head
        self.n_embd = n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0, (
            f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
        )
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0, (
            f"n_kv_head ({self.n_kv_head}) must divide n_head ({self.n_head})"
        )
        # Creation order is kept stable: it fixes parameter order and random init.
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

        self.apply_positional_encoding = config.positional_encoding

    def forward(self, x, cos_sin=None):
        """Attend over ``x`` of shape ``(B, T, n_embd)`` and return a tensor of the same shape.

        ``cos_sin`` is a ``(cos, sin)`` pair and is required when positional encoding is
        enabled. Presumably it is precomputed by the caller to broadcast over
        ``(B, T, H, head_dim // 2)`` — confirm against the text embedder.
        """
        B, T, _ = x.size()

        # Project the input to get queries, keys, and values: (B, T, heads, head_dim)
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)

        if self.apply_positional_encoding:
            assert cos_sin is not None, "Rotary embeddings require precomputed cos/sin tensors"
            cos, sin = cos_sin
            q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)

        q, k = norm(q), norm(k)  # QK norm
        # Make heads the batch-like dim: (B, T, H, D) -> (B, H, T, D)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        # is_causal=False for non-autoregressive models (BERT-like).
        # NOTE(review): no attn_mask is passed, so padded positions (if any) take part in
        # attention — confirm this is intended.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=self.enable_gqa)

        # Re-assemble the heads side by side and project back to the residual stream
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        return self.c_proj(y)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class MLP(nn.Module):
    """Feed-forward sub-layer: widen 4x, squared-ReLU activation, project back to ``n_embd``.

    Args:
        config: Any object exposing ``n_embd`` (input and output width).
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden, bias=False)
        self.c_proj = nn.Linear(hidden, width, bias=False)

    def forward(self, x):
        activated = torch.square(F.relu(self.c_fc(x)))
        return self.c_proj(activated)
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import List, Optional, Union
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch import nn
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class CategoricalForwardType(Enum):
    """How :class:`CategoricalVariableNet` combines its per-variable embeddings.

    The mode is derived from ``categorical_embedding_dims``:

    * ``SUM_TO_TEXT``: dims is ``None``. Each variable is embedded with
      ``text_embedding_dim`` and the embeddings are summed.
    * ``AVERAGE_AND_CONCAT``: dims is an ``int``. All variables share that size and
      the embeddings are averaged. NOTE(review): the name suggests the average is then
      concatenated with the text representation downstream; that is not done here.
    * ``CONCATENATE_ALL``: dims is a list. Each variable gets its own size and the
      embeddings are concatenated.
    """

    SUM_TO_TEXT = "EMBEDDING_SUM_TO_TEXT"
    AVERAGE_AND_CONCAT = "EMBEDDING_AVERAGE_AND_CONCAT"
    CONCATENATE_ALL = "EMBEDDING_CONCATENATE_ALL"


class CategoricalVariableNet(nn.Module):
    """Embed each categorical variable and combine the embeddings into one vector per sample.

    Args:
        categorical_vocabulary_sizes: Number of distinct values of each variable. Codes
            must lie in ``[0, vocab_size)``.
        categorical_embedding_dims: ``None`` (every variable uses ``text_embedding_dim``;
            embeddings are summed), an ``int`` (shared size; averaged) or a list with one
            size per variable (concatenated).
        text_embedding_dim: Required when ``categorical_embedding_dims`` is ``None``.

    Attributes:
        forward_type: The :class:`CategoricalForwardType` derived from the arguments.
        output_dim: Width of the 2-D tensor returned by :meth:`forward`.

    Raises:
        TypeError: If the vocabulary sizes are not a list, or the embedding dims have an
            unsupported type.
        ValueError: If the list lengths disagree, or ``text_embedding_dim`` is missing
            when required.
    """

    def __init__(
        self,
        categorical_vocabulary_sizes: List[int],
        categorical_embedding_dims: Optional[Union[List[int], int]] = None,
        text_embedding_dim: Optional[int] = None,
    ):
        super().__init__()

        self.categorical_vocabulary_sizes = categorical_vocabulary_sizes
        self.categorical_embedding_dims = categorical_embedding_dims
        self.text_embedding_dim = text_embedding_dim

        self._validate_categorical_inputs()
        assert isinstance(
            self.forward_type, CategoricalForwardType
        ), "forward_type must be set after validation"
        assert isinstance(self.output_dim, int), "output_dim must be set as int after validation"

        # A plain dict keeps the index -> layer mapping. The layers are registered as
        # submodules through setattr, which keeps the state_dict keys
        # "categorical_embedding_{i}.weight".
        self.categorical_embedding_layers = {}

        for var_idx, num_rows in enumerate(self.categorical_vocabulary_sizes):
            emb_layer = nn.Embedding(
                num_embeddings=num_rows,
                embedding_dim=self.categorical_embedding_dims[var_idx],
            )
            self.categorical_embedding_layers[var_idx] = emb_layer
            setattr(self, f"categorical_embedding_{var_idx}", emb_layer)

    def forward(self, categorical_vars_tensor: torch.Tensor) -> torch.Tensor:
        """Map ``(batch, n_vars)`` categorical codes to a ``(batch, output_dim)`` tensor."""
        cat_embeds = self._get_cat_embeds(categorical_vars_tensor)
        if self.forward_type == CategoricalForwardType.SUM_TO_TEXT:
            x_combined = torch.stack(cat_embeds, dim=0).sum(dim=0)  # (bs, text_embed_dim)
        elif self.forward_type == CategoricalForwardType.AVERAGE_AND_CONCAT:
            x_combined = torch.stack(cat_embeds, dim=0).mean(dim=0)  # (bs, embed_dim)
        elif self.forward_type == CategoricalForwardType.CONCATENATE_ALL:
            x_combined = torch.cat(cat_embeds, dim=1)  # (bs, sum of all cat embed dims)
        else:
            raise ValueError(f"Unknown forward type: {self.forward_type}")

        assert (
            x_combined.dim() == 2
        ), "Output combined tensor must be 2-dimensional (batch_size, embed_dim)"
        assert x_combined.size(1) == self.output_dim

        return x_combined

    def _get_cat_embeds(self, categorical_vars_tensor: torch.Tensor):
        """Return one ``(batch, dim_i)`` embedding per variable, after range-checking codes."""
        if categorical_vars_tensor.dtype != torch.long:
            categorical_vars_tensor = categorical_vars_tensor.to(torch.long)

        # Fail with an explicit message instead of an IndexError from the column slicing.
        if (
            categorical_vars_tensor.dim() < 2
            or categorical_vars_tensor.size(1) < self.num_categorical_features
        ):
            raise ValueError(
                f"Expected a (batch_size, {self.num_categorical_features}) tensor of "
                f"categorical codes, got shape {tuple(categorical_vars_tensor.shape)}."
            )

        cat_embeds = []

        for i, embed_layer in self.categorical_embedding_layers.items():
            cat_var_tensor = categorical_vars_tensor[:, i]

            # A single combined check needs one device->host sync and also handles empty
            # batches. min/max are only computed to build the error message.
            vocab_size = embed_layer.num_embeddings
            if ((cat_var_tensor < 0) | (cat_var_tensor >= vocab_size)).any():
                min_val = cat_var_tensor.min().item()
                max_val = cat_var_tensor.max().item()
                raise ValueError(
                    f"Categorical feature {i}: values range [{min_val}, {max_val}] exceed vocabulary size {vocab_size}."
                )

            cat_embed = embed_layer(cat_var_tensor)
            if cat_embed.dim() > 2:
                cat_embed = cat_embed.squeeze(1)
            cat_embeds.append(cat_embed)

        return cat_embeds

    def _validate_categorical_inputs(self):
        """Check the constructor arguments and set ``forward_type``, ``output_dim``,
        ``num_categorical_features`` and a per-variable ``categorical_embedding_dims`` list."""
        categorical_vocabulary_sizes = self.categorical_vocabulary_sizes
        categorical_embedding_dims = self.categorical_embedding_dims

        if not isinstance(categorical_vocabulary_sizes, list):
            raise TypeError("categorical_vocabulary_sizes must be a list of int")

        if isinstance(categorical_embedding_dims, list):
            if len(categorical_vocabulary_sizes) != len(categorical_embedding_dims):
                raise ValueError(
                    "Categorical vocabulary sizes and their embedding dimensions must have the same length"
                )

        num_categorical_features = len(categorical_vocabulary_sizes)

        # "Transform" embedding dims into one entry per variable
        if categorical_embedding_dims is not None:
            if isinstance(categorical_embedding_dims, int):
                self.forward_type = CategoricalForwardType.AVERAGE_AND_CONCAT
                self.output_dim = categorical_embedding_dims
                categorical_embedding_dims = [categorical_embedding_dims] * num_categorical_features

            elif isinstance(categorical_embedding_dims, list):
                self.forward_type = CategoricalForwardType.CONCATENATE_ALL
                self.output_dim = sum(categorical_embedding_dims)
            else:
                raise TypeError("categorical_embedding_dims must be an int, a list of int or None")
        else:
            if self.text_embedding_dim is None:
                raise ValueError(
                    "If categorical_embedding_dims is None, text_embedding_dim must be provided"
                )
            self.forward_type = CategoricalForwardType.SUM_TO_TEXT
            self.output_dim = self.text_embedding_dim
            categorical_embedding_dims = [self.text_embedding_dim] * num_categorical_features

        assert (
            isinstance(categorical_embedding_dims, list) or categorical_embedding_dims is None
        ), "categorical_embedding_dims must be a list of int at this point"

        self.categorical_vocabulary_sizes = categorical_vocabulary_sizes
        self.categorical_embedding_dims = categorical_embedding_dims
        self.num_categorical_features = num_categorical_features
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class ClassificationHead(nn.Module):
    """Final projection from the combined representation to class logits.

    Either pass ``input_dim`` and ``num_classes`` to get a single ``nn.Linear``, or pass
    a custom ``net``. For a custom ``net``, ``input_dim``/``num_classes`` come from its
    ``in_features``/``out_features`` attributes when it has them (e.g. ``nn.Linear``).
    Otherwise they are inferred from the first and last ``nn.Linear`` it contains (e.g.
    an ``nn.Sequential`` MLP).

    Args:
        input_dim: Input width, used only when ``net`` is ``None``.
        num_classes: Number of output classes, used only when ``net`` is ``None``.
        net: Optional custom module; takes precedence over ``input_dim``/``num_classes``.

    Raises:
        AssertionError: If ``net`` is ``None`` and ``input_dim`` or ``num_classes`` is missing.
        ValueError: If ``net`` has no ``in_features``/``out_features`` attributes and
            contains no ``nn.Linear`` layer.
    """

    def __init__(
        self,
        input_dim: Optional[int] = None,
        num_classes: Optional[int] = None,
        net: Optional[nn.Module] = None,
    ):
        super().__init__()
        if net is not None:
            self.net = net
            if hasattr(net, "in_features") and hasattr(net, "out_features"):
                self.input_dim = net.in_features
                self.num_classes = net.out_features
            else:
                # e.g. nn.Sequential, which exposes no in_features/out_features itself.
                self.input_dim, self.num_classes = self._get_linear_input_output_dims(net)
        else:
            assert (
                input_dim is not None and num_classes is not None
            ), "Either net or both input_dim and num_classes must be provided."
            self.net = nn.Linear(input_dim, num_classes)
            self.input_dim, self.num_classes = self._get_linear_input_output_dims(self.net)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

    @staticmethod
    def _get_linear_input_output_dims(module: nn.Module):
        """
        Returns (input_dim, output_dim) for any module containing Linear layers.

        Uses the first and last ``nn.Linear`` in ``module.modules()`` (registration)
        order, which suits Linear, Sequential, or nested models built in forward order.
        """
        # Collect all Linear layers recursively
        linears = [m for m in module.modules() if isinstance(m, nn.Linear)]

        if not linears:
            raise ValueError("No Linear layers found in the given module.")

        input_dim = linears[0].in_features
        output_dim = linears[-1].out_features
        return input_dim, output_dim
|