titans-pytorch 0.0.14__tar.gz
- titans_pytorch-0.0.14/.github/workflows/python-publish.yml +36 -0
- titans_pytorch-0.0.14/.gitignore +173 -0
- titans_pytorch-0.0.14/LICENSE +21 -0
- titans_pytorch-0.0.14/PKG-INFO +111 -0
- titans_pytorch-0.0.14/README.md +60 -0
- titans_pytorch-0.0.14/data/README.md +3 -0
- titans_pytorch-0.0.14/data/enwik8.gz +0 -0
- titans_pytorch-0.0.14/fig1.png +0 -0
- titans_pytorch-0.0.14/fig2.png +0 -0
- titans_pytorch-0.0.14/pyproject.toml +70 -0
- titans_pytorch-0.0.14/requirements.txt +1 -0
- titans_pytorch-0.0.14/titans_pytorch/__init__.py +3 -0
- titans_pytorch-0.0.14/titans_pytorch/associative_scan.py +90 -0
- titans_pytorch-0.0.14/titans_pytorch/titans.py +408 -0
- titans_pytorch-0.0.14/train.py +151 -0
@@ -0,0 +1,36 @@ .github/workflows/python-publish.yml
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
@@ -0,0 +1,173 @@ .gitignore
train_local.py

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# PyPI configuration file
.pypirc
@@ -0,0 +1,21 @@ LICENSE
MIT License

Copyright (c) 2025 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -0,0 +1,111 @@ PKG-INFO
Metadata-Version: 2.4
Name: titans-pytorch
Version: 0.0.14
Summary: Titans
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
Author-email: Phil Wang <lucidrains@gmail.com>
License: MIT License

        Copyright (c) 2025 Phil Wang

        Permission is hereby granted, free of charge, to any person obtaining a copy
        of this software and associated documentation files (the "Software"), to deal
        in the Software without restriction, including without limitation the rights
        to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
        copies of the Software, and to permit persons to whom the Software is
        furnished to do so, subject to the following conditions:

        The above copyright notice and this permission notice shall be included in all
        copies or substantial portions of the Software.

        THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
        IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
        FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
        AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
        LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
        OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
        SOFTWARE.
License-File: LICENSE
Keywords: artificial intelligence,deep learning,linear attention,neural memory module,test time training
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3.9
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.9
Requires-Dist: accelerated-scan>=0.2.0
Requires-Dist: einops>=0.8.0
Requires-Dist: einx>=0.3.0
Requires-Dist: ninja
Requires-Dist: tensordict
Requires-Dist: torch>=2.2
Provides-Extra: examples
Requires-Dist: local-attention>=1.10.1; extra == 'examples'
Requires-Dist: taylor-series-linear-attention; extra == 'examples'
Requires-Dist: tqdm; extra == 'examples'
Requires-Dist: wandb; extra == 'examples'
Provides-Extra: test
Requires-Dist: pytest; extra == 'test'
Description-Content-Type: text/markdown

<img src="./fig2.png" width="400px"></img>

<img src="./fig1.png" width="400px"></img>

## Titans - Pytorch (wip)

Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.

## Install

```bash
$ pip install titans-pytorch
```

## Usage

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    pre_rmsnorm = True
).cuda()

seq = torch.randn(2, 1024, 384).cuda()
retrieved = mem(seq)

assert seq.shape == retrieved.shape
```

## Experiments

```bash
$ pip install .[examples]
```

For the SOTA linear attention, you will also need to run

```bash
$ pip install -r requirements.txt
```

Then modify `train.py` and run it to query nature

```bash
$ python train.py
```

## Citations

```bibtex
@inproceedings{Behrouz2024TitansLT,
    title   = {Titans: Learning to Memorize at Test Time},
    author  = {Ali Behrouz and Peilin Zhong and Vahab S. Mirrokni},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:275212078}
}
```
@@ -0,0 +1,60 @@ README.md
<img src="./fig2.png" width="400px"></img>

<img src="./fig1.png" width="400px"></img>

## Titans - Pytorch (wip)

Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.

## Install

```bash
$ pip install titans-pytorch
```

## Usage

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    pre_rmsnorm = True
).cuda()

seq = torch.randn(2, 1024, 384).cuda()
retrieved = mem(seq)

assert seq.shape == retrieved.shape
```

## Experiments

```bash
$ pip install .[examples]
```

For the SOTA linear attention, you will also need to run

```bash
$ pip install -r requirements.txt
```

Then modify `train.py` and run it to query nature

```bash
$ python train.py
```

## Citations

```bibtex
@inproceedings{Behrouz2024TitansLT,
    title   = {Titans: Learning to Memorize at Test Time},
    author  = {Ali Behrouz and Peilin Zhong and Vahab S. Mirrokni},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:275212078}
}
```
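One behavior of `NeuralMemory` worth noting alongside the usage above: `forward` (see `titans_pytorch/titans.py` below) returns zeros whenever the input is shorter than `chunk_size`, since a full chunk must be stored before anything can be retrieved. A minimal sketch of both cases on CPU (the `.cuda()` calls in the README are optional):

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    pre_rmsnorm = True
)

# shorter than one chunk - nothing has been stored yet, so nothing to retrieve
short_seq = torch.randn(2, 32, 384)
assert torch.all(mem(short_seq) == 0.)

# one chunk or longer - retrieved memories match the input shape
seq = torch.randn(2, 1024, 384)
assert mem(seq).shape == seq.shape
```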
Binary file (data/enwik8.gz)
Binary file (fig1.png)
Binary file (fig2.png)
@@ -0,0 +1,70 @@ pyproject.toml
[project]
name = "titans-pytorch"
version = "0.0.14"
description = "Titans"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }
]
readme = "README.md"
requires-python = ">= 3.9"
license = { file = "LICENSE" }
keywords = [
    'artificial intelligence',
    'deep learning',
    'neural memory module',
    'test time training',
    'linear attention'
]

classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.9',
]

dependencies = [
    "accelerated-scan>=0.2.0",
    "einx>=0.3.0",
    "einops>=0.8.0",
    "Ninja",
    "tensordict",
    "torch>=2.2",
]

[project.urls]
Homepage = "https://pypi.org/project/titans-pytorch/"
Repository = "https://github.com/lucidrains/titans-pytorch"

[project.optional-dependencies]

examples = [
    "local-attention>=1.10.1",
    "taylor-series-linear-attention",
    "tqdm",
    "wandb"
]

test = [
    "pytest"
]

[tool.pytest.ini_options]
pythonpath = [
    "."
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.rye]
managed = true
dev-dependencies = []

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.wheel]
packages = ["titans_pytorch"]
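The wheel ships only the `titans_pytorch` package (per `[tool.hatch.build.targets.wheel]`), with `NeuralMemory` importable from the top level, as the README usage shows. A small sketch, using only the standard library, for checking an installed copy against this metadata:

```python
from importlib.metadata import version, requires

assert version('titans-pytorch') == '0.0.14'

# mirrors the Requires-Dist lines in PKG-INFO / the dependencies in pyproject.toml
for req in requires('titans-pytorch'):
    print(req)
```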
@@ -0,0 +1 @@ requirements.txt
pytorch-fast-transformers>=0.4.0
@@ -0,0 +1,90 @@ titans_pytorch/associative_scan.py
from __future__ import annotations
from typing import Callable

import torch
from torch import Tensor
import torch.nn.functional as F

# taken from S5-pytorch repository
# https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/jax_compat.py#L51-L134

# helper functions

def pad_at_dim(t, pad, dim = -1, value = 0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

# the operator that is needed

@torch.jit.script
def binary_operator(
    a: tuple[Tensor, Tensor],
    b: tuple[Tensor, Tensor]
):
    a_i, kv_i = a
    a_j, kv_j = b
    return a_j * a_i, torch.addcmul(kv_j, a_j, kv_i)

# Pytorch impl. of jax.lax.associative_scan
# made specifically for axis of 1 (sequence of tokens for autoregressive modeling)

def associative_scan(
    operator: Callable,
    elems: tuple[Tensor, Tensor]
):
    num_elems = int(elems[0].shape[1])

    if not all(int(elem.shape[1]) == num_elems for elem in elems[1:]):
        raise ValueError('Array inputs to associative_scan must have the same '
                         'first dimension. (saw: {})'
                         .format([elem.shape for elem in elems]))

    def _scan(elems):
        """Perform scan on `elems`."""
        num_elems = elems[0].shape[1]

        if num_elems < 2:
            return elems

        # Combine adjacent pairs of elements.

        reduced_elems = operator(
            [elem[:, :-1:2] for elem in elems],
            [elem[:, 1::2] for elem in elems])

        # Recursively compute scan for partially reduced tensors.

        odd_elems = _scan(reduced_elems)

        if num_elems % 2 == 0:
            even_elems = operator(
                [e[:, :-1] for e in odd_elems],
                [e[:, 2::2] for e in elems])
        else:
            even_elems = operator(
                odd_elems,
                [e[:, 2::2] for e in elems])

        # The first element of a scan is the same as the first element
        # of the original `elems`.

        even_elems = [
            torch.cat([elem[:, :1], result], dim=1)
            for (elem, result) in zip(elems, even_elems)]

        return list(map(_interleave, even_elems, odd_elems))

    return _scan(elems)

def _interleave(a, b):
    a_axis_len, b_axis_len = a.shape[1], b.shape[1]
    output_axis_len = a_axis_len + b_axis_len

    if (a_axis_len == (b_axis_len + 1)):
        b = pad_at_dim(b, (0, 1), dim = 1)

    stacked = torch.stack([a, b], dim=2)
    interleaved = torch.flatten(stacked, start_dim=1, end_dim=2)

    return interleaved[:, :output_axis_len]
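`binary_operator` above is the combine rule of a first-order linear recurrence: merging `(a_i, kv_i)` with `(a_j, kv_j)` gives `(a_j * a_i, a_j * kv_i + kv_j)`, so an inclusive scan over `(gates, inputs)` computes `h_t = gate_t * h_{t-1} + input_t` along dim 1. A small sketch checking this against a sequential loop, mirroring how `titans.py` below calls the scan:

```python
import torch
from titans_pytorch.associative_scan import associative_scan, binary_operator

gates  = torch.rand(1, 8, 4)    # decay / momentum gates in (0, 1)
inputs = torch.randn(1, 8, 4)

# inclusive scan along the sequence axis
_, scanned = associative_scan(binary_operator, (gates, inputs))

# reference: h_t = gate_t * h_{t-1} + input_t, starting from h_0 = 0
h, ref = torch.zeros(1, 4), []
for t in range(8):
    h = gates[:, t] * h + inputs[:, t]
    ref.append(h)

assert torch.allclose(scanned, torch.stack(ref, dim = 1), atol = 1e-6)
```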
@@ -0,0 +1,408 @@ titans_pytorch/titans.py
from __future__ import annotations
import math
from functools import partial

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Linear, Module
from torch.func import functional_call, vmap, grad_and_value

from tensordict import TensorDict

from titans_pytorch.associative_scan import (
    associative_scan,
    binary_operator,
    pad_at_dim
)

import einx
from einops import rearrange, pack, unpack
from einops.layers.torch import Rearrange, Reduce

"""
ein notation:
b - batch
n - sequence
d - feature dimension
c - intra-chunk
"""

# constants

LinearNoBias = partial(Linear, bias = False)

# functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def round_down_multiple(seq, mult):
    return seq // mult * mult

def round_up_multiple(seq, mult):
    return math.ceil(seq / mult) * mult

def pack_one_with_inverse(t, pattern):
    packed, packed_shape = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        inv_pattern = default(inv_pattern, pattern)
        return unpack(out, packed_shape, inv_pattern)[0]

    return packed, inverse

# classes

class MLP(Module):
    def __init__(
        self,
        dim,
        depth
    ):
        super().__init__()
        self.weights = nn.ParameterList([nn.Parameter(torch.randn(dim, dim)) for _ in range(depth)])

    def forward(
        self,
        x
    ):
        for ind, weight in enumerate(self.weights):
            is_first = ind == 0

            if not is_first:
                x = F.silu(x)

            x = x @ weight

        return x

# main neural memory

def default_loss_fn(pred, target):
    return (pred - target).pow(2).mean(dim = -1).sum()

class NeuralMemory(Module):
    def __init__(
        self,
        dim,
        chunk_size = 1,
        dim_head = None,
        heads = 1,
        model: Module | None = None,
        store_memory_loss_fn: Callable = default_loss_fn,
        pre_rmsnorm = True,
        post_rmsnorm = True,
        use_accelerated_scan = False,
        default_mlp_kwargs: dict = dict(
            depth = 4
        )
    ):
        super().__init__()

        # norms

        self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
        self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()

        self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()

        # maybe multi-headed

        dim_head = default(dim_head, dim)
        dim_inner = dim_head * heads

        self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
        self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
        self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()

        # memory mlp

        if not exists(model):
            model = MLP(dim_head, **default_mlp_kwargs)

        assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'

        # the memory is the weights of the model

        self.memory_model = model

        # the chunk size within the paper where adaptive step, momentum, weight decay are shared

        self.chunk_size = chunk_size

        # prepare function for per sample gradients from model above, using torch.func

        def forward_and_loss(params, inputs, target):
            pred = functional_call(self.memory_model, params, inputs)
            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
            return loss

        self.per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))

        # queries for retrieving from the model

        self.to_queries = LinearNoBias(dim, dim_inner)

        # keys and values for storing to the model

        self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
        self.store_memory_loss_fn = store_memory_loss_fn

        # learned adaptive learning rate and momentum
        # todo - explore mlp layerwise learned lr / momentum

        self.to_momentum = nn.Sequential(
            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
            LinearNoBias(dim, heads),
            Rearrange('b n h -> (b h) n 1')
        )

        self.to_adaptive_step = nn.Sequential(
            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
            LinearNoBias(dim, heads),
            Rearrange('b n h -> (b h) n')
        )

        # weight decay factor

        self.to_decay_factor = nn.Sequential(
            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
            LinearNoBias(dim, heads),
            Rearrange('b n h -> (b h) n 1')
        )

        # maybe use accelerated scan

        self.use_accelerated_scan = use_accelerated_scan

    def init_weights_and_momentum(self):
        params = TensorDict(dict(self.memory_model.named_parameters()))

        init_weights = params.clone().zero_()
        init_momentum = params.clone().zero_()

        return init_weights, init_momentum

    def store_memories(
        self,
        seq,
        past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
    ):

        seq = self.store_norm(seq)

        # curtail sequence by multiple of the chunk size
        # only a complete chunk of the sequence provides the memory for the next chunk

        seq_len, chunk_size = seq.shape[-2], self.chunk_size
        round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)

        seq = seq[:, :round_down_seq_len]

        # curr weights + past weights, in the case that the initial weights are learned

        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

        past_state = tuple(TensorDict(d) for d in past_state)
        past_weights, past_momentum = past_state

        curr_weights = curr_weights + past_weights

        # pack batch and sequence dimension

        adaptive_lr = (self.to_adaptive_step(seq).sigmoid() * -15).exp() # from 1. - 1e-7

        adaptive_momentum = self.to_momentum(seq).sigmoid()
        decay_factor = self.to_decay_factor(seq).sigmoid()

        # keys and values

        keys, values = self.to_keys_values(seq).chunk(2, dim = -1)

        # maybe multi head

        keys, values = map(self.split_heads, (keys, values))

        batch = keys.shape[0]

        # take care of chunking

        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))

        # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

        grads, aux_store_loss = self.per_sample_grad_and_value_fn(dict(curr_weights), keys, values)

        grads = TensorDict(grads)

        # restore batch and sequence dimension

        grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))

        # multiply gradients with learned adaptive step size

        surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))

        # determine scan function

        def default_associative_scan(gates, inputs):
            _, outputs = associative_scan(binary_operator, (gates, inputs))
            return outputs

        if self.use_accelerated_scan:
            from accelerated_scan.triton import scan as triton_scan
            from accelerated_scan.warp import scan as warp_scan

            scan = triton_scan if seq.is_cuda else warp_scan

            def accelerate_scan_fn(gates, inputs):
                gates = gates.expand_as(inputs)
                gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))

                seq_len = gates.shape[-1]
                next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))

                gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
                inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))

                outputs = scan(gates, inputs)

                outputs = outputs[..., :seq_len]
                outputs = rearrange(outputs, 'b d n -> b n d')
                return outputs

            scan_fn = accelerate_scan_fn
        else:
            scan_fn = default_associative_scan

        # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

        next_momentum = TensorDict()
        updates = TensorDict()

        for param_name, surprise in surprises.items():

            surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')

            # derive momentum with associative scan - eq (10)

            momentum = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper

            # use associative scan again for learned forgetting (weight decay) - eq (13)

            update = scan_fn(1. - decay_factor, momentum) # momentum is S / surprise in the paper

            updates[param_name] = inverse_pack(update)
            next_momentum[param_name] = inverse_pack(momentum)

        # compute the next weight per batch

        last_update = updates.apply(lambda t: t[:, -1])

        next_state = (curr_weights + last_update, next_momentum)

        return updates, next_state, aux_store_loss.mean() / chunk_size

    def retrieve_memories(
        self,
        seq,
        past_weights: dict[str, Tensor] | None = None,
    ):
        chunk_size = self.chunk_size
        seq_len = seq.shape[1]

        seq = self.retrieve_norm(seq)

        assert seq_len >= chunk_size

        seq = seq[:, (chunk_size - 1):]
        curtailed_seq_len = seq.shape[-2]

        next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)

        padding = next_seq_len - curtailed_seq_len

        seq = pad_at_dim(seq, (0, padding), dim = 1)

        # the parameters of the memory model stores the memories of the key / values
        # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper

        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

        if exists(past_weights):
            past_weights = TensorDict(past_weights)
            assert past_weights.keys() == curr_weights.keys()

            curr_weights = curr_weights + past_weights

        # sequence Float['b n d'] to queries

        queries = self.to_queries(seq)

        # maybe multihead

        queries = self.split_heads(queries)

        batch = queries.shape[0]

        # fetch values from memory model

        curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)

        # forward functional call

        values = functional_call(self.memory_model, dict(curr_weights), queries)

        # reconstitute batch dimension

        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)

        # maybe merge heads and combine

        values = self.merge_heads(values)

        values = self.combine_heads(values)

        # post norm, somehow could not stabilize this without it, not in paper

        values = self.post_rmsnorm(values)

        # restore

        values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
        values = values[:, :-padding]

        return values

    def forward(
        self,
        seq,
        past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
        return_next_memories = False
    ):
        batch, seq_len = seq.shape[:2]

        if seq_len < self.chunk_size:
            return torch.zeros_like(seq)

        if exists(past_state):
            past_state = tuple(TensorDict(d) for d in past_state)

        if not exists(past_state):
            past_state = self.init_weights_and_momentum()

        updates, next_memories, aux_kv_mse_loss = self.store_memories(seq, past_state)

        past_weights, _ = past_state

        retrieved = self.retrieve_memories(seq, past_weights + updates)

        if not return_next_memories:
            return retrieved

        return retrieved, next_memories, aux_kv_mse_loss
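Stepping back from the code above: per chunk `t`, `store_memories` implements the paper's recurrences with learned, data-dependent gates. The surprise is `u_t = -θ_t · ∇ℓ(M; k_t, v_t)`, with `ℓ` the MSE store loss of eq (12) and `adaptive_lr` playing `θ_t`; momentum follows `S_t = η_t · S_{t-1} + u_t` (eq 10, `adaptive_momentum` is `η_t`); and forgetting follows `W_t = (1 - α_t) · W_{t-1} + S_t` (eq 13, `decay_factor` is `α_t`). Both recurrences are first-order linear, which is why each reduces to a single associative scan. A short sketch of pulling out the next memory state and the auxiliary store loss via the `return_next_memories` flag of `forward`:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)

seq = torch.randn(2, 1024, 384)

retrieved, next_memories, aux_loss = mem(seq, return_next_memories = True)

# next_memories is the (weights, momentum) pair after the last chunk's update;
# aux_loss is the averaged store loss of eq (12), which can be added to the
# main objective to train the key / value projections
assert retrieved.shape == seq.shape
```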
@@ -0,0 +1,151 @@ train.py
import random
import tqdm
import gzip
import numpy as np

import torch
from torch import nn
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from local_attention import LocalTransformer

from taylor_series_linear_attention import TaylorSeriesLinearAttn

from titans_pytorch.titans import NeuralMemory

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
VALIDATE_EVERY = 100
GENERATE_EVERY = 500
GENERATE_LENGTH = 512
SHOULD_GENERATE = False
SEQ_LEN = 512

PROJECT_NAME = 'titans-neural-memory'
WANDB_ONLINE = False # turn this on to pipe experiment to cloud
GLOBAL_LAYERS = (4, 5)
USE_TITANS_MEMORY = True
NEURAL_MEMORY_DEPTH = 2
WINDOW_SIZE = 64
RUN_NAME = 'neural memory'

# wandb experiment tracker

import wandb
wandb.init(project = PROJECT_NAME, mode = 'disabled' if not WANDB_ONLINE else 'online')
wandb.run.name = RUN_NAME
wandb.run.save()

# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data

def decode_token(token):
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# instantiate GPT-like decoder model

titans_neural_memory = NeuralMemory(
    dim = 384,
    chunk_size = WINDOW_SIZE,
    pre_rmsnorm = True,
    post_rmsnorm = True,
    dim_head = 32,
    heads = 8,
    use_accelerated_scan = True,
    default_mlp_kwargs = dict(
        depth = NEURAL_MEMORY_DEPTH
    )
)

linear_attn = TaylorSeriesLinearAttn(
    dim = 384,
    dim_head = 16,
    heads = 16,
    causal = True,
    prenorm = True
)

model = LocalTransformer(
    num_tokens = 256,
    dim = 384,
    depth = 8,
    causal = True,
    local_attn_window_size = WINDOW_SIZE,
    max_seq_len = SEQ_LEN,
    global_attn_layer = linear_attn if not USE_TITANS_MEMORY else titans_neural_memory,
    layers_insert_global_attn = GLOBAL_LAYERS
).cuda()

# prepare enwik8 data

with gzip.open('./data/enwik8.gz') as file:
    data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
    data_train, data_val = np.split(data, [int(90e6)])
    data_train, data_val = map(torch.from_numpy, (data_train, data_val))

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer

optim = Adam(model.parameters(), lr=LEARNING_RATE)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        loss.backward()

    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()
    wandb.log(dict(loss = loss.item()))

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {loss.item()}')

    if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
        model.eval()
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        print(f'%s \n\n %s', (prime, '*' * 100))

        sample = model.generate(inp[None, ...], GENERATE_LENGTH, use_kv_cache = False)
        output_str = decode_tokens(sample[0])
        print(output_str)
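The training and validation losses printed by the script are cross entropy in nats per byte, since tokens here are raw enwik8 bytes. Results on enwik8 are conventionally reported in bits per byte; a small helper for the conversion:

```python
import math

def nats_to_bits_per_byte(loss: float) -> float:
    # cross entropy in nats -> the usual enwik8 bits-per-byte metric
    return loss / math.log(2)

print(nats_to_bits_per_byte(0.9))  # ~1.298 bpb
```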