deep-cross-attention 0.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,36 @@
1
# This workflow builds the Python package and uploads it to PyPI with Twine
# whenever a GitHub release is published.
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

# The job only needs to read the repository; the upload itself authenticates
# with the PyPI API token secret.
permissions:
  contents: read

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
      # checkout@v2 / setup-python@v2 target the retired Node 12 runtime,
      # so use the current major versions.
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.x'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install build
      - name: Build package
        run: python -m build
      - name: Publish package
        # pinned to a full commit SHA rather than a movable tag
        uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
        with:
          user: __token__
          password: ${{ secrets.PYPI_API_TOKEN }}
@@ -0,0 +1,171 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # PyPI configuration file
171
+ .pypirc
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Phil Wang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,103 @@
1
+ Metadata-Version: 2.4
2
+ Name: deep-cross-attention
3
+ Version: 0.0.2
4
+ Summary: Deep Cross Attention Language Model
5
+ Project-URL: Homepage, https://pypi.org/project/deep-cross-attention/
6
+ Project-URL: Repository, https://github.com/lucidrains/deep-cross-attention
7
+ Author-email: Phil Wang <lucidrains@gmail.com>
8
+ License: MIT License
9
+
10
+ Copyright (c) 2025 Phil Wang
11
+
12
+ Permission is hereby granted, free of charge, to any person obtaining a copy
13
+ of this software and associated documentation files (the "Software"), to deal
14
+ in the Software without restriction, including without limitation the rights
15
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16
+ copies of the Software, and to permit persons to whom the Software is
17
+ furnished to do so, subject to the following conditions:
18
+
19
+ The above copyright notice and this permission notice shall be included in all
20
+ copies or substantial portions of the Software.
21
+
22
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28
+ SOFTWARE.
29
+ License-File: LICENSE
30
+ Keywords: artificial intelligence,deep learning,residuals,transformers
31
+ Classifier: Development Status :: 4 - Beta
32
+ Classifier: Intended Audience :: Developers
33
+ Classifier: License :: OSI Approved :: MIT License
34
+ Classifier: Programming Language :: Python :: 3.9
35
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
+ Requires-Python: >=3.9
37
+ Requires-Dist: einops>=0.8.0
38
+ Requires-Dist: rotary-embedding-torch
39
+ Requires-Dist: torch>=2.3
40
+ Provides-Extra: examples
41
+ Requires-Dist: tqdm; extra == 'examples'
42
+ Provides-Extra: test
43
+ Requires-Dist: pytest; extra == 'test'
44
+ Description-Content-Type: text/markdown
45
+
46
+ <img src="./fig4.png" width="400px"></img>
47
+
48
+ ## Deep Cross Attention
49
+
50
+ Implementation of the proposed [DeepCrossAttention](https://arxiv.org/abs/2502.06785) by [Mike Heddes](https://www.mikeheddes.nl/) while at Google Research, in PyTorch
51
+
52
+ My analysis is that, although I still prefer [Hyper Connections](https://arxiv.org/abs/2409.19606), they have an important idea here that I have been trying concurrently: mainly, that the queries, keys, and values can be [routed from different layers](https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py#L1226) of the past
53
+
54
+ ## Install
55
+
56
+ ```bash
57
+ $ pip install deep-cross-attention
58
+ ```
59
+
60
+ ## Usage
61
+
62
+ ```python
63
+ import torch
64
+ from deep_cross_attention import DCAGPT
65
+
66
+ gpt = DCAGPT(
67
+ num_tokens = 256,
68
+ dim = 512,
69
+ depth = 6,
70
+ heads = 8,
71
+ dim_head = 64,
72
+ past_layers_k = 2
73
+ )
74
+
75
+ ids = torch.randint(0, 256, (2, 4096))
76
+
77
+ logits = gpt(ids) # (2, 4096, 256)
78
+ ```
79
+
80
+ ## Example
81
+
82
+ First
83
+
84
+ ```bash
85
+ $ pip install .[examples]
86
+ ```
87
+
88
+ Next
89
+
90
+ ```bash
91
+ $ python train.py
92
+ ```
93
+
94
+ ## Citations
95
+
96
+ ```bibtex
97
+ @inproceedings{Heddes2025DeepCrossAttentionST,
98
+ title = {DeepCrossAttention: Supercharging Transformer Residual Connections},
99
+ author = {Mike Heddes and Adel Javanmard and Kyriakos Axiotis and Gang Fu and MohammadHossein Bateni and Vahab S. Mirrokni},
100
+ year = {2025},
101
+ url = {https://api.semanticscholar.org/CorpusID:276250576}
102
+ }
103
+ ```
@@ -0,0 +1,58 @@
1
+ <img src="./fig4.png" width="400px"></img>
2
+
3
+ ## Deep Cross Attention
4
+
5
+ Implementation of the proposed [DeepCrossAttention](https://arxiv.org/abs/2502.06785) by [Mike Heddes](https://www.mikeheddes.nl/) while at Google Research, in PyTorch
6
+
7
+ My analysis is that, although I still prefer [Hyper Connections](https://arxiv.org/abs/2409.19606), they have an important idea here that I have been trying concurrently: mainly, that the queries, keys, and values can be [routed from different layers](https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py#L1226) of the past
8
+
9
+ ## Install
10
+
11
+ ```bash
12
+ $ pip install deep-cross-attention
13
+ ```
14
+
15
+ ## Usage
16
+
17
+ ```python
18
+ import torch
19
+ from deep_cross_attention import DCAGPT
20
+
21
+ gpt = DCAGPT(
22
+ num_tokens = 256,
23
+ dim = 512,
24
+ depth = 6,
25
+ heads = 8,
26
+ dim_head = 64,
27
+ past_layers_k = 2
28
+ )
29
+
30
+ ids = torch.randint(0, 256, (2, 4096))
31
+
32
+ logits = gpt(ids) # (2, 4096, 256)
33
+ ```
34
+
35
+ ## Example
36
+
37
+ First
38
+
39
+ ```bash
40
+ $ pip install .[examples]
41
+ ```
42
+
43
+ Next
44
+
45
+ ```bash
46
+ $ python train.py
47
+ ```
48
+
49
+ ## Citations
50
+
51
+ ```bibtex
52
+ @inproceedings{Heddes2025DeepCrossAttentionST,
53
+ title = {DeepCrossAttention: Supercharging Transformer Residual Connections},
54
+ author = {Mike Heddes and Adel Javanmard and Kyriakos Axiotis and Gang Fu and MohammadHossein Bateni and Vahab S. Mirrokni},
55
+ year = {2025},
56
+ url = {https://api.semanticscholar.org/CorpusID:276250576}
57
+ }
58
+ ```
@@ -0,0 +1,3 @@
1
+ # Data source
2
+
3
+ The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
@@ -0,0 +1,2 @@
1
+ from deep_cross_attention.gpt import GPT
2
+ from deep_cross_attention.dca_gpt import DCAGPT
@@ -0,0 +1,230 @@
1
+ import torch
2
+ from torch import nn, cat, stack
3
+ import torch.nn.functional as F
4
+ from torch.nn import Module, ModuleList, Linear, RMSNorm
5
+
6
+ from einops import rearrange, einsum
7
+ from einops.layers.torch import Rearrange
8
+
9
+ from rotary_embedding_torch import RotaryEmbedding
10
+
11
+ # functions
12
+
13
def exists(v):
    """Return True when `v` is anything other than None (falsy values such as 0 or '' count as existing)."""
    return not (v is None)
15
+
16
def default(v, d):
    """Return `v` unless it is None, in which case fall back to `d`."""
    if exists(v):
        return v

    return d
18
+
19
+ # attention
20
+
21
class Attention(Module):
    """Causal multi-head attention whose queries, keys and values come from separate inputs.

    Each of the three inputs is RMS-normalized and linearly projected on its own,
    rotary embeddings are applied to the queries and keys, and the attended
    values are projected back to `dim`.
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8
    ):
        """
        dim      - feature dimension of the three inputs and of the output
        dim_head - feature dimension of a single attention head
        heads    - number of attention heads
        """
        super().__init__()

        # not referenced in `forward` (every projection below has its own norm)

        self.norm = RMSNorm(dim)

        self.heads = heads
        dim_inner = dim_head * heads

        self.rotary_embed = RotaryEmbedding(dim_head)

        # one pre-normalized, bias-free projection per attention input

        def norm_then_project():
            return nn.Sequential(
                RMSNorm(dim),
                nn.Linear(dim, dim_inner, bias = False)
            )

        self.to_q = norm_then_project()
        self.to_k = norm_then_project()
        self.to_v = norm_then_project()

        # (batch, seq, heads * dim_head) <-> (batch, heads, seq, dim_head)

        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

        self.to_out = nn.Linear(dim_inner, dim, bias = False)

    def forward(
        self,
        q_input,
        k_input,
        v_input
    ):
        """Attend from `q_input` over `k_input` / `v_input` under a causal mask and return the output projection."""

        projected = (
            self.to_q(q_input),
            self.to_k(k_input),
            self.to_v(v_input)
        )

        queries, keys, values = (self.split_heads(t) for t in projected)

        # relative positions

        queries, keys = self.rotary_embed.rotate_queries_with_cached_keys(queries, keys)

        # attention branch

        attended = F.scaled_dot_product_attention(
            queries, keys, values,
            is_causal = True
        )

        return self.to_out(self.merge_heads(attended))
72
+
73
+ # feedforward
74
+
75
def FeedForward(dim, expansion_factor = 4.):
    """Build a two-layer MLP: Linear(dim -> hidden), GELU, Linear(hidden -> dim).

    The hidden width is `dim * expansion_factor` truncated to an integer.
    """
    hidden_width = int(dim * expansion_factor)

    layers = [
        Linear(dim, hidden_width),
        nn.GELU(),
        Linear(hidden_width, dim),
    ]

    return nn.Sequential(*layers)
83
+
84
+ # GRNv3
85
+ # the input dependent one lines up with all the literature, and the winning solution for hyper connections (dynamic)
86
+
87
+ class GRN(Module):
88
+ def __init__(
89
+ self,
90
+ dim,
91
+ ):
92
+ super().__init__()
93
+
94
+ self.to_aggregate = nn.Sequential(
95
+ RMSNorm(dim),
96
+ Linear(dim, 1),
97
+ nn.ReLU(),
98
+ )
99
+
100
+ nn.init.zeros_(self.to_aggregate[-2].weight)
101
+
102
+ def forward(
103
+ self,
104
+ tokens_across_depth # Float['depth b n d']
105
+ ):
106
+ aggregate = self.to_aggregate(tokens_across_depth)
107
+
108
+ return (tokens_across_depth * aggregate).sum(dim = 0)
109
+
110
+ # DCA Decoder Block
111
+
112
class DCABlock(Module):
    """One DCA decoder block: three depth aggregators feeding attention, followed by a feedforward.

    The stack of earlier layer outputs is reduced three times by independent GRNs,
    giving separate query, key and value inputs for the attention layer.
    """

    def __init__(
        self,
        dim,
        dim_head = 64,
        heads = 8,
        ff_expansion_factor = 4.
    ):
        """
        dim                 - model feature dimension
        dim_head            - feature dimension of each attention head
        heads               - number of attention heads
        ff_expansion_factor - hidden width of the feedforward, as a multiple of `dim`
        """
        super().__init__()

        # independent aggregators for the query, key and value inputs

        self.q_grn = GRN(dim)
        self.k_grn = GRN(dim)
        self.v_grn = GRN(dim)

        self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads)

        self.pre_ff_norm = RMSNorm(dim)

        self.ff = FeedForward(dim = dim, expansion_factor = ff_expansion_factor)

    def forward(
        self,
        tokens_across_depth # Float['depth b n d']
    ):
        """Run the block on the stacked earlier layer outputs and return this layer's output."""

        q_input = self.q_grn(tokens_across_depth)
        k_input = self.k_grn(tokens_across_depth)
        v_input = self.v_grn(tokens_across_depth)

        attn_out = self.attn(q_input, k_input, v_input)

        # the aggregated query input is the residual added ahead of the feedforward norm

        ff_out = self.ff(self.pre_ff_norm(attn_out + q_input))

        # the block output is feedforward output plus attention output

        return ff_out + attn_out
147
+
148
+ # classes
149
+
150
class DCAGPT(Module):
    """Token-level language model made of a stack of DCA blocks.

    Every block receives a stack of earlier layer outputs (starting with the token
    embeddings). A final GRN pools all layer outputs before the norm and the
    projection to logits.
    """

    def __init__(
        self,
        num_tokens,
        dim,
        depth,
        past_layers_k = 2,
        dim_head = 64,
        heads = 8,
        ff_expansion_factor = 4.
    ):
        """
        num_tokens          - vocabulary size
        dim                 - model feature dimension
        depth               - number of DCA blocks
        past_layers_k       - `k`: once at least 2 * k layer outputs exist, a block only sees the first k and last k
        dim_head            - feature dimension of each attention head
        heads               - number of attention heads
        ff_expansion_factor - hidden width of each feedforward, as a multiple of `dim`
        """
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)

        # the `k` hyperparameter, which seems to refer to subsampling which layers to include, for efficiency
        # the paper appears to use not only the last k layers but also the first k, and mentions pooling of
        # intermediate layers - only the first and last k are used here for now

        self.past_layers_k = past_layers_k

        # the proposed DCA blocks

        self.dca_blocks = ModuleList([
            DCABlock(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                ff_expansion_factor = ff_expansion_factor
            )
            for _ in range(depth)
        ])

        # norm and logits

        self.final_grn = GRN(dim)

        self.norm = RMSNorm(dim)
        self.to_logits = Linear(dim, num_tokens, bias = False)

    def forward(
        self,
        ids,
        return_loss = False
    ):
        """Return logits for `ids`, or the next-token cross entropy loss when `return_loss` is True.

        With `return_loss`, `ids[:, :-1]` is the model input and `ids[:, 1:]` the labels.
        """
        k = self.past_layers_k # k in paper

        labels = None

        if return_loss:
            ids, labels = ids[:, :-1], ids[:, 1:]

        # layer outputs so far, the token embeddings being the first entry

        layer_outputs = [self.token_emb(ids)]

        for dca_block in self.dca_blocks:

            stacked = stack(layer_outputs)

            # with at least 2 * k layers available, keep only the first k and the last k

            if stacked.shape[0] >= (k * 2):
                stacked = cat((stacked[:k], stacked[-k:]))

            # the block output joins the history for the next iteration

            layer_outputs.append(dca_block(stacked))

        pooled_tokens = self.final_grn(stack(layer_outputs))

        logits = self.to_logits(self.norm(pooled_tokens))

        if return_loss:
            # cross entropy wants the class axis second
            return F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)

        return logits
Binary file
@@ -0,0 +1,61 @@
1
+ [project]
2
+ name = "deep-cross-attention"
3
+ version = "0.0.2"
4
+ description = "Deep Cross Attention Language Model"
5
+ authors = [
6
+ { name = "Phil Wang", email = "lucidrains@gmail.com" }
7
+ ]
8
+ readme = "README.md"
9
+ requires-python = ">= 3.9"
10
+ license = { file = "LICENSE" }
11
+ keywords = [
12
+ 'artificial intelligence',
13
+ 'deep learning',
14
+ 'transformers',
15
+ 'residuals'
16
+ ]
17
+
18
+ classifiers=[
19
+ 'Development Status :: 4 - Beta',
20
+ 'Intended Audience :: Developers',
21
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
22
+ 'License :: OSI Approved :: MIT License',
23
+ 'Programming Language :: Python :: 3.9',
24
+ ]
25
+
26
# deep_cross_attention/dca_gpt.py imports `torch.nn.RMSNorm`, which first shipped
# in PyTorch 2.4 - a 2.3 install cannot import the package.
dependencies = [
    "einops>=0.8.0",
    "rotary-embedding-torch",
    "torch>=2.4",
]
31
+
32
+ [project.urls]
33
+ Homepage = "https://pypi.org/project/deep-cross-attention/"
34
+ Repository = "https://github.com/lucidrains/deep-cross-attention"
35
+
36
+ [project.optional-dependencies]
37
+ examples = [
38
+ "tqdm"
39
+ ]
40
+ test = [
41
+ "pytest"
42
+ ]
43
+
44
+ [tool.pytest.ini_options]
45
+ pythonpath = [
46
+ "."
47
+ ]
48
+
49
+ [build-system]
50
+ requires = ["hatchling"]
51
+ build-backend = "hatchling.build"
52
+
53
+ [tool.rye]
54
+ managed = true
55
+ dev-dependencies = []
56
+
57
+ [tool.hatch.metadata]
58
+ allow-direct-references = true
59
+
60
+ [tool.hatch.build.targets.wheel]
61
+ packages = ["deep_cross_attention"]
@@ -0,0 +1,179 @@
1
+ import math
2
+ import gzip
3
+ import random
4
+ import tqdm
5
+ import numpy as np
6
+
7
+ import torch
8
+ from torch.optim import Adam
9
+ from torch import Tensor
10
+ from torch.utils.data import DataLoader, Dataset
11
+
12
+ from deep_cross_attention import (
13
+ GPT,
14
+ DCAGPT
15
+ )
16
+
17
+ # constants
18
+
19
NUM_BATCHES = 100_000       # iterations of the outer training loop
BATCH_SIZE = 4              # sequences per DataLoader batch
GRAD_ACCUM_EVERY = 4        # micro-batches accumulated before each optimizer step
LEARNING_RATE = 1e-4
VALIDATE_EVERY = 100        # print a validation loss every this many iterations
PRIME_LENGTH = 128          # tokens of validation text used as the sampling prompt
GENERATE_EVERY = 500        # sample text every this many iterations
GENERATE_LENGTH = 512       # passed to base_decoding as the total length, prompt included
SEQ_LEN = 512               # dataset sequence length (each sample carries one extra token)

USE_DCA = True              # True trains DCAGPT, False trains GPT
30
+
31
+ # helpers
32
+
33
def exists(v):
    """True for every value except None."""
    return not (v is None)
35
+
36
def cycle(loader):
    """Yield the items of `loader` forever, re-iterating it each time it is exhausted."""
    while True:
        yield from loader
40
+
41
def decode_token(token):
    """Map a byte value to its character, substituting a space for anything below 32."""
    codepoint = max(32, token)
    return str(chr(codepoint))
43
+
44
def decode_tokens(tokens):
    """Decode a sequence of byte values into a string, one character per token."""
    return "".join(decode_token(token) for token in tokens)
46
+
47
+ # sampling helpers
48
+
49
def log(t, eps = 1e-20):
    """Natural log of `t`, with values below `eps` raised to `eps` first so the result stays finite."""
    floored = torch.clamp(t, min = eps)
    return floored.log()
51
+
52
def gumbel_noise(t):
    """Draw noise shaped like `t`, computed as -log(-log(u)) from uniform samples u in [0, 1)."""
    uniform = torch.empty_like(t).uniform_(0, 1)
    inner = log(uniform)
    return -log(-inner)
55
+
56
def gumbel_sample(t, temperature = 1., dim = -1, keepdim = True):
    """Sample indices along `dim` by taking the argmax of temperature-scaled `t` plus gumbel noise.

    The temperature is floored at 1e-10 to avoid dividing by zero.
    """
    safe_temperature = max(temperature, 1e-10)
    perturbed = (t / safe_temperature) + gumbel_noise(t)
    return perturbed.argmax(dim = dim, keepdim = keepdim)
58
+
59
def top_k(logits, thres = 0.9):
    """Keep the highest-scoring fraction `1 - thres` of the logits and set the rest to -inf.

    logits - tensor whose last axis holds the per-token scores
    thres  - fraction of the vocabulary to filter out

    Returns a tensor shaped like `logits` with the kept values in their original
    positions and -inf everywhere else.
    """
    num_logits = logits.shape[-1]
    k = math.ceil((1 - thres) * num_logits)

    # thres >= 1 would give k = 0 and mask every logit, and thres < 0 would ask topk for
    # more entries than exist - keep k within [1, num_logits]

    k = min(max(k, 1), num_logits)

    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(-1, ind, val)
    return probs
65
+
66
@torch.no_grad() # sampling never needs gradients, and the caller below does not disable them itself
def base_decoding(
    net,
    prompt: Tensor,
    seq_len: int,
    temperature = 1.,
    filter_thres = 0.9,
):
    """Autoregressively extend `prompt` with sampled tokens and return only the new tokens.

    net          - callable mapping token ids to logits; the logits at the last position are used
    prompt       - token ids with the sequence on the last axis
    seq_len      - target total length, prompt included; nothing is sampled when the prompt is already that long
    temperature  - gumbel sampling temperature
    filter_thres - fraction of the vocabulary filtered out by `top_k` before sampling
    """
    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
    sample_num_times = max(0, seq_len - prompt_seq_len)

    for _ in range(sample_num_times):
        # the whole sequence so far is fed back in; only the final position's logits matter
        logits = net(out)[:, -1]
        logits = top_k(logits, thres = filter_thres)
        sample = gumbel_sample(logits, temperature = temperature, dim = -1)

        out = torch.cat((out, sample), dim = -1)

    return out[..., prompt_seq_len:]
85
+
86
+ # model
87
+
88
# hyperparameters shared by both model variants

model_kwargs = dict(
    num_tokens = 256,
    dim = 512,
    depth = 6
)

if USE_DCA:
    # past_layers_k is the `k` value in their paper, which refers to how many layers in the past it looks, a la Denseformer
    model = DCAGPT(**model_kwargs, past_layers_k = 2)
else:
    model = GPT(**model_kwargs)

model = model.cuda()
103
+
104
+ # prepare enwik8 data
105
+
106
# read the first 95e6 bytes of enwik8, the first 90e6 for training and the remainder for validation

with gzip.open('./data/enwik8.gz') as file:
    raw_bytes = file.read(int(95e6))

# frombuffer gives a read-only view, hence the copy before handing it to torch

data = np.frombuffer(raw_bytes, dtype = np.uint8).copy()
np_train, np_valid = np.split(data, [int(90e6)])

data_train = torch.from_numpy(np_train)
data_val = torch.from_numpy(np_valid)
110
+
111
class TextSamplerDataset(Dataset):
    """Dataset that serves randomly positioned windows of `seq_len + 1` tokens from a 1D token tensor."""

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __len__(self):
        # number of whole `seq_len` windows that fit in the data
        total_tokens = self.data.size(0)
        return total_tokens // self.seq_len

    def __getitem__(self, index):
        # `index` is ignored - every lookup draws a fresh random start position
        max_start = self.data.size(0) - self.seq_len
        start = torch.randint(0, max_start, (1,)).item()

        # one extra token so that inputs and shifted labels both have `seq_len` entries
        window = self.data[start : start + self.seq_len + 1]
        return window.long().cuda()
124
+
125
# datasets

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)

# endlessly cycling loaders, so `next(...)` never raises StopIteration

train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer

optim = Adam(model.parameters(), lr = LEARNING_RATE)
136
+
137
+ # training
138
+
139
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10.0, desc = "training"):
    model.train()

    # accumulate gradients over several micro-batches, each loss scaled so the sum is an average

    for _ in range(GRAD_ACCUM_EVERY):
        data = next(train_loader)
        loss = model(data, return_loss = True)
        (loss / GRAD_ACCUM_EVERY).backward()

    # this is the loss of the last micro-batch only

    print(f"training loss: {loss.item():.3f}")

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)

    optim.step()
    optim.zero_grad()

    # periodic validation on a single batch

    if i % VALIDATE_EVERY == 0:
        model.eval()

        with torch.no_grad():
            valid_data = next(val_loader)
            loss = model(valid_data, return_loss = True)
            print(f"validation loss: {loss.item():.3f}")

    # periodic sampling, primed with the start of a random validation window

    if i % GENERATE_EVERY != 0:
        continue

    model.eval()

    inp = random.choice(val_dataset)[:PRIME_LENGTH].cuda()

    prime = decode_tokens(inp)
    print(f"\n{prime}\n")

    # add a batch dimension of 1 for the model

    sampled = base_decoding(model, inp[None, ...], GENERATE_LENGTH)

    base_decode_output = decode_tokens(sampled[0])

    print(f"\n{base_decode_output}\n")