titans-pytorch 0.0.14 (titans_pytorch-0.0.14.tar.gz)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch-0.0.14/.github/workflows/python-publish.yml +36 -0
- titans_pytorch-0.0.14/.gitignore +173 -0
- titans_pytorch-0.0.14/LICENSE +21 -0
- titans_pytorch-0.0.14/PKG-INFO +111 -0
- titans_pytorch-0.0.14/README.md +60 -0
- titans_pytorch-0.0.14/data/README.md +3 -0
- titans_pytorch-0.0.14/data/enwik8.gz +0 -0
- titans_pytorch-0.0.14/fig1.png +0 -0
- titans_pytorch-0.0.14/fig2.png +0 -0
- titans_pytorch-0.0.14/pyproject.toml +70 -0
- titans_pytorch-0.0.14/requirements.txt +1 -0
- titans_pytorch-0.0.14/titans_pytorch/__init__.py +3 -0
- titans_pytorch-0.0.14/titans_pytorch/associative_scan.py +90 -0
- titans_pytorch-0.0.14/titans_pytorch/titans.py +408 -0
- titans_pytorch-0.0.14/train.py +151 -0
@@ -0,0 +1,36 @@ titans_pytorch-0.0.14/.github/workflows/python-publish.yml

# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
@@ -0,0 +1,173 @@ titans_pytorch-0.0.14/.gitignore

train_local.py

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# PyPI configuration file
.pypirc
@@ -0,0 +1,21 @@ titans_pytorch-0.0.14/LICENSE

MIT License

Copyright (c) 2025 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -0,0 +1,111 @@ titans_pytorch-0.0.14/PKG-INFO

Metadata-Version: 2.4
Name: titans-pytorch
Version: 0.0.14
Summary: Titans
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
Author-email: Phil Wang <lucidrains@gmail.com>
License: MIT License

        Copyright (c) 2025 Phil Wang

        Permission is hereby granted, free of charge, to any person obtaining a copy
        of this software and associated documentation files (the "Software"), to deal
        in the Software without restriction, including without limitation the rights
        to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
        copies of the Software, and to permit persons to whom the Software is
        furnished to do so, subject to the following conditions:

        The above copyright notice and this permission notice shall be included in all
        copies or substantial portions of the Software.

        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
        SOFTWARE.
License-File: LICENSE
Keywords: artificial intelligence,deep learning,linear attention,neural memory module,test time training
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3.9
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.9
Requires-Dist: accelerated-scan>=0.2.0
Requires-Dist: einops>=0.8.0
Requires-Dist: einx>=0.3.0
Requires-Dist: ninja
Requires-Dist: tensordict
Requires-Dist: torch>=2.2
Provides-Extra: examples
Requires-Dist: local-attention>=1.10.1; extra == 'examples'
Requires-Dist: taylor-series-linear-attention; extra == 'examples'
Requires-Dist: tqdm; extra == 'examples'
Requires-Dist: wandb; extra == 'examples'
Provides-Extra: test
Requires-Dist: pytest; extra == 'test'
Description-Content-Type: text/markdown

<img src="./fig2.png" width="400px"></img>

<img src="./fig1.png" width="400px"></img>

## Titans - Pytorch (wip)

Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.

## Install

```bash
$ pip install titans-pytorch
```

## Usage

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    pre_rmsnorm = True
).cuda()

seq = torch.randn(2, 1024, 384).cuda()
retrieved = mem(seq)

assert seq.shape == retrieved.shape
```

## Experiments

```bash
$ pip install .[examples]
```

For the SOTA linear attention, you will also need to run

```bash
$ pip install -r requirements.txt
```

Then modify `train.py` and run it to query nature

```bash
$ python train.py
```

## Citations

```bibtex
@inproceedings{Behrouz2024TitansLT,
    title   = {Titans: Learning to Memorize at Test Time},
    author  = {Ali Behrouz and Peilin Zhong and Vahab S. Mirrokni},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:275212078}
}
```
@@ -0,0 +1,60 @@ titans_pytorch-0.0.14/README.md

<img src="./fig2.png" width="400px"></img>

<img src="./fig1.png" width="400px"></img>

## Titans - Pytorch (wip)

Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.

## Install

```bash
$ pip install titans-pytorch
```

## Usage

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    pre_rmsnorm = True
).cuda()

seq = torch.randn(2, 1024, 384).cuda()
retrieved = mem(seq)

assert seq.shape == retrieved.shape
```

## Experiments

```bash
$ pip install .[examples]
```

For the SOTA linear attention, you will also need to run

```bash
$ pip install -r requirements.txt
```

Then modify `train.py` and run it to query nature

```bash
$ python train.py
```

## Citations

```bibtex
@inproceedings{Behrouz2024TitansLT,
    title   = {Titans: Learning to Memorize at Test Time},
    author  = {Ali Behrouz and Peilin Zhong and Vahab S. Mirrokni},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:275212078}
}
```
Binary file: titans_pytorch-0.0.14/data/enwik8.gz

Binary file: titans_pytorch-0.0.14/fig1.png

Binary file: titans_pytorch-0.0.14/fig2.png
@@ -0,0 +1,70 @@ titans_pytorch-0.0.14/pyproject.toml

[project]
name = "titans-pytorch"
version = "0.0.14"
description = "Titans"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }
]
readme = "README.md"
requires-python = ">= 3.9"
license = { file = "LICENSE" }
keywords = [
    'artificial intelligence',
    'deep learning',
    'neural memory module',
    'test time training',
    'linear attention'
]

classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.9',
]

dependencies = [
    "accelerated-scan>=0.2.0",
    "einx>=0.3.0",
    "einops>=0.8.0",
    "Ninja",
    "tensordict",
    "torch>=2.2",
]

[project.urls]
Homepage = "https://pypi.org/project/titans-pytorch/"
Repository = "https://github.com/lucidrains/titans-pytorch"

[project.optional-dependencies]

examples = [
    "local-attention>=1.10.1",
    "taylor-series-linear-attention",
    "tqdm",
    "wandb"
]

test = [
    "pytest"
]

[tool.pytest.ini_options]
pythonpath = [
    "."
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.rye]
managed = true
dev-dependencies = []

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.wheel]
packages = ["titans_pytorch"]
@@ -0,0 +1 @@ titans_pytorch-0.0.14/requirements.txt

pytorch-fast-transformers>=0.4.0
@@ -0,0 +1,90 @@ titans_pytorch-0.0.14/titans_pytorch/associative_scan.py

from __future__ import annotations
from typing import Callable

import torch
from torch import Tensor
import torch.nn.functional as F

# taken from S5-pytorch repository
# https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/jax_compat.py#L51-L134

# helper functions

def pad_at_dim(t, pad, dim = -1, value = 0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

# the operator that is needed

@torch.jit.script
def binary_operator(
    a: tuple[Tensor, Tensor],
    b: tuple[Tensor, Tensor]
):
    a_i, kv_i = a
    a_j, kv_j = b
    return a_j * a_i, torch.addcmul(kv_j, a_j, kv_i)

# Pytorch impl. of jax.lax.associative_scan
# made specifically for axis of 1 (sequence of tokens for autoregressive modeling)

def associative_scan(
    operator: Callable,
    elems: tuple[Tensor, Tensor]
):
    num_elems = int(elems[0].shape[1])

    if not all(int(elem.shape[1]) == num_elems for elem in elems[1:]):
        raise ValueError('Array inputs to associative_scan must have the same '
                         'length along dimension 1 (saw: {})'
                         .format([elem.shape for elem in elems]))

    def _scan(elems):
        """Perform scan on `elems`."""
        num_elems = elems[0].shape[1]

        if num_elems < 2:
            return elems

        # combine adjacent pairs of elements

        reduced_elems = operator(
            [elem[:, :-1:2] for elem in elems],
            [elem[:, 1::2] for elem in elems])

        # recursively compute scan for partially reduced tensors

        odd_elems = _scan(reduced_elems)

        if num_elems % 2 == 0:
            even_elems = operator(
                [e[:, :-1] for e in odd_elems],
                [e[:, 2::2] for e in elems])
        else:
            even_elems = operator(
                odd_elems,
                [e[:, 2::2] for e in elems])

        # the first element of a scan is the same as the first element
        # of the original `elems`

        even_elems = [
            torch.cat([elem[:, :1], result], dim = 1)
            for (elem, result) in zip(elems, even_elems)]

        return list(map(_interleave, even_elems, odd_elems))

    return _scan(elems)

def _interleave(a, b):
    a_axis_len, b_axis_len = a.shape[1], b.shape[1]
    output_axis_len = a_axis_len + b_axis_len

    if (a_axis_len == (b_axis_len + 1)):
        b = pad_at_dim(b, (0, 1), dim = 1)

    stacked = torch.stack([a, b], dim = 2)
    interleaved = torch.flatten(stacked, start_dim = 1, end_dim = 2)

    return interleaved[:, :output_axis_len]
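The `binary_operator` above composes gated pairs, `((a_i, kv_i), (a_j, kv_j)) -> (a_j * a_i, kv_j + a_j * kv_i)`, so an inclusive scan over it evaluates the linear recurrence `h_t = a_t * h_{t-1} + x_t` for all positions in parallel. Below is a minimal sketch, not part of the package, that checks the parallel scan against a naive sequential loop; the shapes and tolerance are illustrative only.

```python
import torch
from titans_pytorch.associative_scan import associative_scan, binary_operator

batch, seq_len, dim = 1, 8, 4

gates = torch.rand(batch, seq_len, dim)    # a_t in (0, 1) - plays the role of the momentum / forget gates
inputs = torch.randn(batch, seq_len, dim)  # x_t - plays the role of the surprise signal

# parallel inclusive scan along dimension 1

_, scanned = associative_scan(binary_operator, (gates, inputs))

# sequential reference: h_t = a_t * h_{t-1} + x_t, with h_0 = 0

h = torch.zeros(batch, dim)
reference = []

for t in range(seq_len):
    h = gates[:, t] * h + inputs[:, t]
    reference.append(h)

reference = torch.stack(reference, dim = 1)

assert torch.allclose(scanned, reference, atol = 1e-5)
```

This is the same call pattern `NeuralMemory.store_memories` uses twice in `titans.py` below, once for momentum and once for weight decay.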
@@ -0,0 +1,408 @@ titans_pytorch-0.0.14/titans_pytorch/titans.py

from __future__ import annotations
import math
from functools import partial
from typing import Callable

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Linear, Module
from torch.func import functional_call, vmap, grad_and_value

from tensordict import TensorDict

from titans_pytorch.associative_scan import (
    associative_scan,
    binary_operator,
    pad_at_dim
)

import einx
from einops import rearrange, pack, unpack
from einops.layers.torch import Rearrange, Reduce

"""
ein notation:
b - batch
n - sequence
d - feature dimension
c - intra-chunk
"""

# constants

LinearNoBias = partial(Linear, bias = False)

# functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def round_down_multiple(seq, mult):
    return seq // mult * mult

def round_up_multiple(seq, mult):
    return math.ceil(seq / mult) * mult

def pack_one_with_inverse(t, pattern):
    packed, packed_shape = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        inv_pattern = default(inv_pattern, pattern)
        return unpack(out, packed_shape, inv_pattern)[0]

    return packed, inverse

# classes

class MLP(Module):
    def __init__(
        self,
        dim,
        depth
    ):
        super().__init__()
        self.weights = nn.ParameterList([nn.Parameter(torch.randn(dim, dim)) for _ in range(depth)])

    def forward(
        self,
        x
    ):
        for ind, weight in enumerate(self.weights):
            is_first = ind == 0

            if not is_first:
                x = F.silu(x)

            x = x @ weight

        return x

# main neural memory

def default_loss_fn(pred, target):
    return (pred - target).pow(2).mean(dim = -1).sum()

class NeuralMemory(Module):
    def __init__(
        self,
        dim,
        chunk_size = 1,
        dim_head = None,
        heads = 1,
        model: Module | None = None,
        store_memory_loss_fn: Callable = default_loss_fn,
        pre_rmsnorm = True,
        post_rmsnorm = True,
        use_accelerated_scan = False,
        default_mlp_kwargs: dict = dict(
            depth = 4
        )
    ):
        super().__init__()

        # norms

        self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
        self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()

        self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()

        # maybe multi-headed

        dim_head = default(dim_head, dim)
        dim_inner = dim_head * heads

        self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
        self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
        self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()

        # memory mlp

        if not exists(model):
            model = MLP(dim_head, **default_mlp_kwargs)

        assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'

        # the memory is the weights of the model

        self.memory_model = model

        # the chunk size within the paper where adaptive step, momentum, weight decay are shared

        self.chunk_size = chunk_size

        # prepare function for per sample gradients from model above, using torch.func

        def forward_and_loss(params, inputs, target):
            pred = functional_call(self.memory_model, params, inputs)
            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
            return loss

        self.per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))

        # queries for retrieving from the model

        self.to_queries = LinearNoBias(dim, dim_inner)

        # keys and values for storing to the model

        self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
        self.store_memory_loss_fn = store_memory_loss_fn

        # learned adaptive learning rate and momentum
        # todo - explore mlp layerwise learned lr / momentum

        self.to_momentum = nn.Sequential(
            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
            LinearNoBias(dim, heads),
            Rearrange('b n h -> (b h) n 1')
        )

        self.to_adaptive_step = nn.Sequential(
            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
            LinearNoBias(dim, heads),
            Rearrange('b n h -> (b h) n')
        )

        # weight decay factor

        self.to_decay_factor = nn.Sequential(
            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
            LinearNoBias(dim, heads),
            Rearrange('b n h -> (b h) n 1')
        )

        # maybe use accelerated scan

        self.use_accelerated_scan = use_accelerated_scan

    def init_weights_and_momentum(self):
        params = TensorDict(dict(self.memory_model.named_parameters()))

        init_weights = params.clone().zero_()
        init_momentum = params.clone().zero_()

        return init_weights, init_momentum

    def store_memories(
        self,
        seq,
        past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
    ):

        seq = self.store_norm(seq)

        # curtail sequence by multiple of the chunk size
        # only a complete chunk of the sequence provides the memory for the next chunk

        seq_len, chunk_size = seq.shape[-2], self.chunk_size
        round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)

        seq = seq[:, :round_down_seq_len]

        # curr weights + past weights, in the case that the initial weights are learned

        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

        past_state = tuple(TensorDict(d) for d in past_state)
        past_weights, past_momentum = past_state

        curr_weights = curr_weights + past_weights

        # learned adaptive step size, momentum, and weight decay, one per chunk (heads folded into batch)

        adaptive_lr = (self.to_adaptive_step(seq).sigmoid() * -15).exp() # squashed into the range (exp(-15), 1.)

        adaptive_momentum = self.to_momentum(seq).sigmoid()
        decay_factor = self.to_decay_factor(seq).sigmoid()

        # keys and values

        keys, values = self.to_keys_values(seq).chunk(2, dim = -1)

        # maybe multi head

        keys, values = map(self.split_heads, (keys, values))

        batch = keys.shape[0]

        # take care of chunking

        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))

        # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

        grads, aux_store_loss = self.per_sample_grad_and_value_fn(dict(curr_weights), keys, values)

        grads = TensorDict(grads)

        # restore batch and sequence dimension

        grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))

        # multiply gradients with learned adaptive step size

        surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))

        # determine scan function

        def default_associative_scan(gates, inputs):
            _, outputs = associative_scan(binary_operator, (gates, inputs))
            return outputs

        if self.use_accelerated_scan:
            from accelerated_scan.triton import scan as triton_scan
            from accelerated_scan.warp import scan as warp_scan

            scan = triton_scan if seq.is_cuda else warp_scan

            def accelerate_scan_fn(gates, inputs):
                gates = gates.expand_as(inputs)
                gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))

                seq_len = gates.shape[-1]
                next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))

                gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
                inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))

                outputs = scan(gates, inputs)

                outputs = outputs[..., :seq_len]
                outputs = rearrange(outputs, 'b d n -> b n d')
                return outputs

            scan_fn = accelerate_scan_fn
        else:
            scan_fn = default_associative_scan

        # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

        next_momentum = TensorDict()
        updates = TensorDict()

        for param_name, surprise in surprises.items():

            surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')

            # derive momentum with associative scan - eq (10)

            momentum = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper

            # use associative scan again for learned forgetting (weight decay) - eq (13)

            update = scan_fn(1. - decay_factor, momentum) # (1. - decay_factor) is the (1 - alpha) forget gate in the paper

            updates[param_name] = inverse_pack(update)
            next_momentum[param_name] = inverse_pack(momentum)

        # compute the next weight per batch

        last_update = updates.apply(lambda t: t[:, -1])

        next_state = (curr_weights + last_update, next_momentum)

        return updates, next_state, aux_store_loss.mean() / chunk_size

    def retrieve_memories(
        self,
        seq,
        past_weights: dict[str, Tensor] | None = None,
    ):
        chunk_size = self.chunk_size
        seq_len = seq.shape[1]

        seq = self.retrieve_norm(seq)

        assert seq_len >= chunk_size

        seq = seq[:, (chunk_size - 1):]
        curtailed_seq_len = seq.shape[-2]

        next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)

        padding = next_seq_len - curtailed_seq_len

        seq = pad_at_dim(seq, (0, padding), dim = 1)

        # the parameters of the memory model store the memories of the key / values
        # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper

        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

        if exists(past_weights):
            past_weights = TensorDict(past_weights)
            assert past_weights.keys() == curr_weights.keys()

            curr_weights = curr_weights + past_weights

        # sequence Float['b n d'] to queries

        queries = self.to_queries(seq)

        # maybe multihead

        queries = self.split_heads(queries)

        batch = queries.shape[0]

        # fetch values from memory model

        curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)

        # forward functional call

        values = functional_call(self.memory_model, dict(curr_weights), queries)

        # reconstitute batch dimension

        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)

        # maybe merge heads and combine

        values = self.merge_heads(values)

        values = self.combine_heads(values)

        # post norm, somehow could not stabilize this without it, not in paper

        values = self.post_rmsnorm(values)

        # restore - prepend zeros so each position only sees memories formed from fully completed past chunks
        # todo - use a learned null memory embedding instead of 0s for retrieving from empty neural memory

        values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.)

        if padding > 0:
            values = values[:, :-padding]

        return values

    def forward(
        self,
        seq,
        past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
        return_next_memories = False
    ):
        batch, seq_len = seq.shape[:2]

        if seq_len < self.chunk_size:
            return torch.zeros_like(seq)

        if exists(past_state):
            past_state = tuple(TensorDict(d) for d in past_state)
        else:
            past_state = self.init_weights_and_momentum()

        updates, next_memories, aux_kv_mse_loss = self.store_memories(seq, past_state)

        past_weights, _ = past_state

        retrieved = self.retrieve_memories(seq, past_weights + updates)

        if not return_next_memories:
            return retrieved

        return retrieved, next_memories, aux_kv_mse_loss
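As the comments in `store_memories` indicate, the two chained scans realize the paper's recurrences: the momentum / surprise update `S_t = eta_t * S_{t-1} - theta_t * grad(loss(M_{t-1}; x_t))` (eq 10) and the forgetting update `M_t = (1 - alpha_t) * M_{t-1} + S_t` (eq 13), with step size `theta`, momentum `eta`, and decay `alpha` all predicted per chunk. A short sketch of the forward interface follows; the three return values come from the signature above, everything else (shapes, how the auxiliary loss would be used) is illustrative assumption.

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64
)

seq = torch.randn(2, 1024, 384)

# default call - returns only the retrieved values, same shape as the input

retrieved = mem(seq)
assert retrieved.shape == seq.shape

# also surface the updated memory state and the auxiliary store loss
# next_state is a (weights, momentum) pair of TensorDicts, matching the past_state parameter
# aux_kv_mse_loss is presumably meant to be added to the main objective so gradients reach the key / value projections

retrieved, next_state, aux_kv_mse_loss = mem(seq, return_next_memories = True)

next_weights, next_momentum = next_state
```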
@@ -0,0 +1,151 @@ titans_pytorch-0.0.14/train.py

import random
import tqdm
import gzip
import numpy as np

import torch
from torch import nn
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from local_attention import LocalTransformer

from taylor_series_linear_attention import TaylorSeriesLinearAttn

from titans_pytorch.titans import NeuralMemory

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SHOULD_GENERATE = False
SEQ_LEN = 512

PROJECT_NAME = 'titans-neural-memory'
WANDB_ONLINE = False # turn this on to pipe experiment to cloud
GLOBAL_LAYERS = (4, 5)
USE_TITANS_MEMORY = True
NEURAL_MEMORY_DEPTH = 2
WINDOW_SIZE = 64
RUN_NAME = 'neural memory'

# wandb experiment tracker

import wandb
wandb.init(project = PROJECT_NAME, mode = 'disabled' if not WANDB_ONLINE else 'online')
wandb.run.name = RUN_NAME
wandb.run.save()

# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data

def decode_token(token):
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# instantiate GPT-like decoder model

titans_neural_memory = NeuralMemory(
    dim = 384,
    chunk_size = WINDOW_SIZE,
    pre_rmsnorm = True,
    post_rmsnorm = True,
    dim_head = 32,
    heads = 8,
    use_accelerated_scan = True,
    default_mlp_kwargs = dict(
        depth = NEURAL_MEMORY_DEPTH
    )
)

linear_attn = TaylorSeriesLinearAttn(
    dim = 384,
    dim_head = 16,
    heads = 16,
    causal = True,
    prenorm = True
)

model = LocalTransformer(
    num_tokens = 256,
    dim = 384,
    depth = 8,
    causal = True,
    local_attn_window_size = WINDOW_SIZE,
    max_seq_len = SEQ_LEN,
    global_attn_layer = linear_attn if not USE_TITANS_MEMORY else titans_neural_memory,
    layers_insert_global_attn = GLOBAL_LAYERS
).cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
    data_train, data_val = np.split(data, [int(90e6)])
    data_train, data_val = map(torch.from_numpy, (data_train, data_val))

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer

optim = Adam(model.parameters(), lr=LEARNING_RATE)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()
    wandb.log(dict(loss = loss.item()))

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')

    if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f'{prime} \n\n {"*" * 100}')

        sample = model.generate(inp[None, ...], GENERATE_LENGTH, use_kv_cache = False)
        output_str = decode_tokens(sample[0])
        print(output_str)