titans-pytorch 0.0.14__tar.gz

@@ -0,0 +1,36 @@
1
+ # This workflow will upload a Python Package using Twine when a release is created
2
+ # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3
+
4
+ # This workflow uses actions that are not certified by GitHub.
5
+ # They are provided by a third-party and are governed by
6
+ # separate terms of service, privacy policy, and support
7
+ # documentation.
8
+
9
+ name: Upload Python Package
10
+
11
+ on:
12
+ release:
13
+ types: [published]
14
+
15
+ jobs:
16
+ deploy:
17
+
18
+ runs-on: ubuntu-latest
19
+
20
+ steps:
21
+ - uses: actions/checkout@v2
22
+ - name: Set up Python
23
+ uses: actions/setup-python@v2
24
+ with:
25
+ python-version: '3.x'
26
+ - name: Install dependencies
27
+ run: |
28
+ python -m pip install --upgrade pip
29
+ pip install build
30
+ - name: Build package
31
+ run: python -m build
32
+ - name: Publish package
33
+ uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34
+ with:
35
+ user: __token__
36
+ password: ${{ secrets.PYPI_API_TOKEN }}
@@ -0,0 +1,173 @@
1
+ train_local.py
2
+
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # Distribution / packaging
12
+ .Python
13
+ build/
14
+ develop-eggs/
15
+ dist/
16
+ downloads/
17
+ eggs/
18
+ .eggs/
19
+ lib/
20
+ lib64/
21
+ parts/
22
+ sdist/
23
+ var/
24
+ wheels/
25
+ share/python-wheels/
26
+ *.egg-info/
27
+ .installed.cfg
28
+ *.egg
29
+ MANIFEST
30
+
31
+ # PyInstaller
32
+ # Usually these files are written by a python script from a template
33
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
34
+ *.manifest
35
+ *.spec
36
+
37
+ # Installer logs
38
+ pip-log.txt
39
+ pip-delete-this-directory.txt
40
+
41
+ # Unit test / coverage reports
42
+ htmlcov/
43
+ .tox/
44
+ .nox/
45
+ .coverage
46
+ .coverage.*
47
+ .cache
48
+ nosetests.xml
49
+ coverage.xml
50
+ *.cover
51
+ *.py,cover
52
+ .hypothesis/
53
+ .pytest_cache/
54
+ cover/
55
+
56
+ # Translations
57
+ *.mo
58
+ *.pot
59
+
60
+ # Django stuff:
61
+ *.log
62
+ local_settings.py
63
+ db.sqlite3
64
+ db.sqlite3-journal
65
+
66
+ # Flask stuff:
67
+ instance/
68
+ .webassets-cache
69
+
70
+ # Scrapy stuff:
71
+ .scrapy
72
+
73
+ # Sphinx documentation
74
+ docs/_build/
75
+
76
+ # PyBuilder
77
+ .pybuilder/
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ # For a library or package, you might want to ignore these files since the code is
89
+ # intended to run in multiple environments; otherwise, check them in:
90
+ # .python-version
91
+
92
+ # pipenv
93
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
95
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
96
+ # install all needed dependencies.
97
+ #Pipfile.lock
98
+
99
+ # UV
100
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
101
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
102
+ # commonly ignored for libraries.
103
+ #uv.lock
104
+
105
+ # poetry
106
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
107
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
108
+ # commonly ignored for libraries.
109
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
110
+ #poetry.lock
111
+
112
+ # pdm
113
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
114
+ #pdm.lock
115
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
116
+ # in version control.
117
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
118
+ .pdm.toml
119
+ .pdm-python
120
+ .pdm-build/
121
+
122
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
123
+ __pypackages__/
124
+
125
+ # Celery stuff
126
+ celerybeat-schedule
127
+ celerybeat.pid
128
+
129
+ # SageMath parsed files
130
+ *.sage.py
131
+
132
+ # Environments
133
+ .env
134
+ .venv
135
+ env/
136
+ venv/
137
+ ENV/
138
+ env.bak/
139
+ venv.bak/
140
+
141
+ # Spyder project settings
142
+ .spyderproject
143
+ .spyproject
144
+
145
+ # Rope project settings
146
+ .ropeproject
147
+
148
+ # mkdocs documentation
149
+ /site
150
+
151
+ # mypy
152
+ .mypy_cache/
153
+ .dmypy.json
154
+ dmypy.json
155
+
156
+ # Pyre type checker
157
+ .pyre/
158
+
159
+ # pytype static type analyzer
160
+ .pytype/
161
+
162
+ # Cython debug symbols
163
+ cython_debug/
164
+
165
+ # PyCharm
166
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
167
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
168
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
169
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
170
+ #.idea/
171
+
172
+ # PyPI configuration file
173
+ .pypirc
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Phil Wang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,111 @@
1
+ Metadata-Version: 2.4
2
+ Name: titans-pytorch
3
+ Version: 0.0.14
4
+ Summary: Titans
5
+ Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
+ Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
7
+ Author-email: Phil Wang <lucidrains@gmail.com>
8
+ License: MIT License
9
+
10
+ Copyright (c) 2025 Phil Wang
11
+
12
+ Permission is hereby granted, free of charge, to any person obtaining a copy
13
+ of this software and associated documentation files (the "Software"), to deal
14
+ in the Software without restriction, including without limitation the rights
15
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16
+ copies of the Software, and to permit persons to whom the Software is
17
+ furnished to do so, subject to the following conditions:
18
+
19
+ The above copyright notice and this permission notice shall be included in all
20
+ copies or substantial portions of the Software.
21
+
22
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28
+ SOFTWARE.
29
+ License-File: LICENSE
30
+ Keywords: artificial intelligence,deep learning,linear attention,neural memory module,test time training
31
+ Classifier: Development Status :: 4 - Beta
32
+ Classifier: Intended Audience :: Developers
33
+ Classifier: License :: OSI Approved :: MIT License
34
+ Classifier: Programming Language :: Python :: 3.9
35
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
+ Requires-Python: >=3.9
37
+ Requires-Dist: accelerated-scan>=0.2.0
38
+ Requires-Dist: einops>=0.8.0
39
+ Requires-Dist: einx>=0.3.0
40
+ Requires-Dist: ninja
41
+ Requires-Dist: tensordict
42
+ Requires-Dist: torch>=2.2
43
+ Provides-Extra: examples
44
+ Requires-Dist: local-attention>=1.10.1; extra == 'examples'
45
+ Requires-Dist: taylor-series-linear-attention; extra == 'examples'
46
+ Requires-Dist: tqdm; extra == 'examples'
47
+ Requires-Dist: wandb; extra == 'examples'
48
+ Provides-Extra: test
49
+ Requires-Dist: pytest; extra == 'test'
50
+ Description-Content-Type: text/markdown
51
+
52
+ <img src="./fig2.png" width="400px"></img>
53
+
54
+ <img src="./fig1.png" width="400px"></img>
55
+
56
+ ## Titans - Pytorch (wip)
57
+
58
+ Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. This repository will also contain some explorations into architectures beyond the paper's simple 1-4 layer MLP for the neural memory module, should they work well to any degree.
59
+
60
+ ## Install
61
+
62
+ ```bash
63
+ $ pip install titans-pytorch
64
+ ```
65
+
66
+ ## Usage
67
+
68
+ ```python
69
+ import torch
70
+ from titans_pytorch import NeuralMemory
71
+
72
+ mem = NeuralMemory(
73
+ dim = 384,
74
+ chunk_size = 64,
75
+ pre_rmsnorm = True
76
+ ).cuda()
77
+
78
+ seq = torch.randn(2, 1024, 384).cuda()
79
+ retrieved = mem(seq)
80
+
81
+ assert seq.shape == retrieved.shape
82
+ ```
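+
+ The memory defaults to a small MLP kept as plain weight matrices. Below is a minimal sketch (not from the repo) of swapping in a custom memory model, based on the `NeuralMemory` constructor in `titans_pytorch/titans.py`: the module must be buffer-free, map `dim_head -> dim_head`, and, like the built-in `MLP`, apply its weights with plain `x @ w` matmuls so retrieval can run it with per-chunk batched weights. The `WideMemoryMLP` name and dimensions below are hypothetical:
+
+ ```python
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from titans_pytorch import NeuralMemory
+
+ class WideMemoryMLP(nn.Module):
+     # hypothetical 2-layer memory model with a wider hidden dimension
+     def __init__(self, dim, dim_hidden):
+         super().__init__()
+         self.w1 = nn.Parameter(torch.randn(dim, dim_hidden) * 0.02)
+         self.w2 = nn.Parameter(torch.randn(dim_hidden, dim) * 0.02)
+
+     def forward(self, x):
+         # plain matmuls so batched weights broadcast during retrieval
+         return F.silu(x @ self.w1) @ self.w2
+
+ mem = NeuralMemory(
+     dim = 384,
+     dim_head = 64,
+     heads = 6,
+     chunk_size = 64,
+     model = WideMemoryMLP(dim = 64, dim_hidden = 128)  # overrides the default MLP memory
+ )
+
+ seq = torch.randn(2, 1024, 384)
+ retrieved = mem(seq)
+
+ assert seq.shape == retrieved.shape
+ ```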
83
+
84
+ ## Experiments
85
+
86
+ ```bash
87
+ $ pip install .[examples]
88
+ ```
89
+
90
+ For the SOTA linear attention, you will also need to run
91
+
92
+ ```bash
93
+ $ pip install -r requirements.txt
94
+ ```
95
+
96
+ Then modify `train.py` and run it to query nature
97
+
98
+ ```bash
99
+ $ python train.py
100
+ ```
101
+
102
+ ## Citations
103
+
104
+ ```bibtex
105
+ @inproceedings{Behrouz2024TitansLT,
106
+ title = {Titans: Learning to Memorize at Test Time},
107
+ author = {Ali Behrouz and Peilin Zhong and Vahab S. Mirrokni},
108
+ year = {2024},
109
+ url = {https://api.semanticscholar.org/CorpusID:275212078}
110
+ }
111
+ ```
@@ -0,0 +1,60 @@
1
+ <img src="./fig2.png" width="400px"></img>
2
+
3
+ <img src="./fig1.png" width="400px"></img>
4
+
5
+ ## Titans - Pytorch (wip)
6
+
7
+ Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. This repository will also contain some explorations into architectures beyond the paper's simple 1-4 layer MLP for the neural memory module, should they work well to any degree.
8
+
9
+ ## Install
10
+
11
+ ```bash
12
+ $ pip install titans-pytorch
13
+ ```
14
+
15
+ ## Usage
16
+
17
+ ```python
18
+ import torch
19
+ from titans_pytorch import NeuralMemory
20
+
21
+ mem = NeuralMemory(
22
+ dim = 384,
23
+ chunk_size = 64,
24
+ pre_rmsnorm = True
25
+ ).cuda()
26
+
27
+ seq = torch.randn(2, 1024, 384).cuda()
28
+ retrieved = mem(seq)
29
+
30
+ assert seq.shape == retrieved.shape
31
+ ```
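+
+ The memory defaults to a small MLP kept as plain weight matrices. Below is a minimal sketch (not from the repo) of swapping in a custom memory model, based on the `NeuralMemory` constructor in `titans_pytorch/titans.py`: the module must be buffer-free, map `dim_head -> dim_head`, and, like the built-in `MLP`, apply its weights with plain `x @ w` matmuls so retrieval can run it with per-chunk batched weights. The `WideMemoryMLP` name and dimensions below are hypothetical:
+
+ ```python
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from titans_pytorch import NeuralMemory
+
+ class WideMemoryMLP(nn.Module):
+     # hypothetical 2-layer memory model with a wider hidden dimension
+     def __init__(self, dim, dim_hidden):
+         super().__init__()
+         self.w1 = nn.Parameter(torch.randn(dim, dim_hidden) * 0.02)
+         self.w2 = nn.Parameter(torch.randn(dim_hidden, dim) * 0.02)
+
+     def forward(self, x):
+         # plain matmuls so batched weights broadcast during retrieval
+         return F.silu(x @ self.w1) @ self.w2
+
+ mem = NeuralMemory(
+     dim = 384,
+     dim_head = 64,
+     heads = 6,
+     chunk_size = 64,
+     model = WideMemoryMLP(dim = 64, dim_hidden = 128)  # overrides the default MLP memory
+ )
+
+ seq = torch.randn(2, 1024, 384)
+ retrieved = mem(seq)
+
+ assert seq.shape == retrieved.shape
+ ```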
32
+
33
+ ## Experiments
34
+
35
+ ```bash
36
+ $ pip install .[examples]
37
+ ```
38
+
39
+ For the SOTA linear attention, you will also need to run
40
+
41
+ ```bash
42
+ $ pip install -r requirements.txt
43
+ ```
44
+
45
+ Then modify `train.py` and run it to query nature
46
+
47
+ ```bash
48
+ $ python train.py
49
+ ```
50
+
51
+ ## Citations
52
+
53
+ ```bibtex
54
+ @inproceedings{Behrouz2024TitansLT,
55
+ title = {Titans: Learning to Memorize at Test Time},
56
+ author = {Ali Behrouz and Peilin Zhong and Vahab S. Mirrokni},
57
+ year = {2024},
58
+ url = {https://api.semanticscholar.org/CorpusID:275212078}
59
+ }
60
+ ```
@@ -0,0 +1,3 @@
1
+ # Data source
2
+
3
+ The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
Binary file
Binary file
Binary file
@@ -0,0 +1,70 @@
1
+ [project]
2
+ name = "titans-pytorch"
3
+ version = "0.0.14"
4
+ description = "Titans"
5
+ authors = [
6
+ { name = "Phil Wang", email = "lucidrains@gmail.com" }
7
+ ]
8
+ readme = "README.md"
9
+ requires-python = ">= 3.9"
10
+ license = { file = "LICENSE" }
11
+ keywords = [
12
+ 'artificial intelligence',
13
+ 'deep learning',
14
+ 'neural memory module',
15
+ 'test time training',
16
+ 'linear attention'
17
+ ]
18
+
19
+ classifiers=[
20
+ 'Development Status :: 4 - Beta',
21
+ 'Intended Audience :: Developers',
22
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
23
+ 'License :: OSI Approved :: MIT License',
24
+ 'Programming Language :: Python :: 3.9',
25
+ ]
26
+
27
+ dependencies = [
28
+ "accelerated-scan>=0.2.0",
29
+ "einx>=0.3.0",
30
+ "einops>=0.8.0",
31
+ "Ninja",
32
+ "tensordict",
33
+ "torch>=2.2",
34
+ ]
35
+
36
+ [project.urls]
37
+ Homepage = "https://pypi.org/project/titans-pytorch/"
38
+ Repository = "https://github.com/lucidrains/titans-pytorch"
39
+
40
+ [project.optional-dependencies]
41
+
42
+ examples = [
43
+ "local-attention>=1.10.1",
44
+ "taylor-series-linear-attention",
45
+ "tqdm",
46
+ "wandb"
47
+ ]
48
+
49
+ test = [
50
+ "pytest"
51
+ ]
52
+
53
+ [tool.pytest.ini_options]
54
+ pythonpath = [
55
+ "."
56
+ ]
57
+
58
+ [build-system]
59
+ requires = ["hatchling"]
60
+ build-backend = "hatchling.build"
61
+
62
+ [tool.rye]
63
+ managed = true
64
+ dev-dependencies = []
65
+
66
+ [tool.hatch.metadata]
67
+ allow-direct-references = true
68
+
69
+ [tool.hatch.build.targets.wheel]
70
+ packages = ["titans_pytorch"]
@@ -0,0 +1 @@
1
+ pytorch-fast-transformers>=0.4.0
@@ -0,0 +1,3 @@
1
+ from titans_pytorch.titans import (
2
+ NeuralMemory
3
+ )
@@ -0,0 +1,90 @@
1
+ from __future__ import annotations
2
+ from typing import Callable
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ import torch.nn.functional as F
7
+
8
+ # taken from S5-pytorch repository
9
+ # https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/jax_compat.py#L51-L134
10
+
11
+ # helper functions
12
+
13
+ def pad_at_dim(t, pad, dim = -1, value = 0.):
14
+ dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
15
+ zeros = ((0, 0) * dims_from_right)
16
+ return F.pad(t, (*zeros, *pad), value = value)
17
+
18
+ # the operator that is needed
19
+
20
+ @torch.jit.script
21
+ def binary_operator(
22
+ a: tuple[Tensor, Tensor],
23
+ b: tuple[Tensor, Tensor]
24
+ ):
25
+ a_i, kv_i = a
26
+ a_j, kv_j = b
27
+ return a_j * a_i, torch.addcmul(kv_j, a_j, kv_i)
28
+
29
+ # Pytorch impl. of jax.lax.associative_scan
30
+ # made specifically for axis of 1 (sequence of tokens for autoregressive modeling)
31
+
32
+ def associative_scan(
33
+ operator: Callable,
34
+ elems: tuple[Tensor, Tensor]
35
+ ):
36
+ num_elems = int(elems[0].shape[1])
37
+
38
+ if not all(int(elem.shape[1]) == num_elems for elem in elems[1:]):
39
+ raise ValueError('Array inputs to associative_scan must have the same '
40
+ 'first dimension. (saw: {})'
41
+ .format([elem.shape for elem in elems]))
42
+
43
+ def _scan(elems):
44
+ """Perform scan on `elems`."""
45
+ num_elems = elems[0].shape[1]
46
+
47
+ if num_elems < 2:
48
+ return elems
49
+
50
+ # Combine adjacent pairs of elements.
51
+
52
+ reduced_elems = operator(
53
+ [elem[:, :-1:2] for elem in elems],
54
+ [elem[:, 1::2] for elem in elems])
55
+
56
+ # Recursively compute scan for partially reduced tensors.
57
+
58
+ odd_elems = _scan(reduced_elems)
59
+
60
+ if num_elems % 2 == 0:
61
+ even_elems = operator(
62
+ [e[:, :-1] for e in odd_elems],
63
+ [e[:, 2::2] for e in elems])
64
+ else:
65
+ even_elems = operator(
66
+ odd_elems,
67
+ [e[:, 2::2] for e in elems])
68
+
69
+ # The first element of a scan is the same as the first element
70
+ # of the original `elems`.
71
+
72
+ even_elems = [
73
+ torch.cat([elem[:, :1], result], dim=1)
74
+ for (elem, result) in zip(elems, even_elems)]
75
+
76
+ return list(map(_interleave, even_elems, odd_elems))
77
+
78
+ return _scan(elems)
79
+
80
+ def _interleave(a, b):
81
+ a_axis_len, b_axis_len = a.shape[1], b.shape[1]
82
+ output_axis_len = a_axis_len + b_axis_len
83
+
84
+ if (a_axis_len == (b_axis_len + 1)):
85
+ b = pad_at_dim(b, (0, 1), dim = 1)
86
+
87
+ stacked = torch.stack([a, b], dim=2)
88
+ interleaved = torch.flatten(stacked, start_dim=1, end_dim=2)
89
+
90
+ return interleaved[:, :output_axis_len]
@@ -0,0 +1,408 @@
1
+ from __future__ import annotations
2
+ import math
+ from typing import Callable
3
+ from functools import partial
4
+
5
+ import torch
6
+ from torch import nn, Tensor
7
+ import torch.nn.functional as F
8
+ from torch.nn import Linear, Module
9
+ from torch.func import functional_call, vmap, grad_and_value
10
+
11
+ from tensordict import TensorDict
12
+
13
+ from titans_pytorch.associative_scan import (
14
+ associative_scan,
15
+ binary_operator,
16
+ pad_at_dim
17
+ )
18
+
19
+ import einx
20
+ from einops import rearrange, pack, unpack
21
+ from einops.layers.torch import Rearrange, Reduce
22
+
23
+ """
24
+ ein notation:
25
+ b - batch
26
+ n - sequence
27
+ d - feature dimension
28
+ c - intra-chunk
29
+ """
30
+
31
+ # constants
32
+
33
+ LinearNoBias = partial(Linear, bias = False)
34
+
35
+ # functions
36
+
37
+ def exists(v):
38
+ return v is not None
39
+
40
+ def default(v, d):
41
+ return v if exists(v) else d
42
+
43
+ def round_down_multiple(seq, mult):
44
+ return seq // mult * mult
45
+
46
+ def round_up_multiple(seq, mult):
47
+ return math.ceil(seq / mult) * mult
48
+
49
+ def pack_one_with_inverse(t, pattern):
50
+ packed, packed_shape = pack([t], pattern)
51
+
52
+ def inverse(out, inv_pattern = None):
53
+ inv_pattern = default(inv_pattern, pattern)
54
+ return unpack(out, packed_shape, inv_pattern)[0]
55
+
56
+ return packed, inverse
57
+
58
+ # classes
59
+
60
+ class MLP(Module):
61
+ def __init__(
62
+ self,
63
+ dim,
64
+ depth
65
+ ):
66
+ super().__init__()
67
+ self.weights = nn.ParameterList([nn.Parameter(torch.randn(dim, dim)) for _ in range(depth)])
68
+
69
+ def forward(
70
+ self,
71
+ x
72
+ ):
73
+ for ind, weight in enumerate(self.weights):
74
+ is_first = ind == 0
75
+
76
+ if not is_first:
77
+ x = F.silu(x)
78
+
79
+ x = x @ weight
80
+
81
+ return x
82
+
83
+ # main neural memory
84
+
85
+ def default_loss_fn(pred, target):
86
+ return (pred - target).pow(2).mean(dim = -1).sum()
87
+
88
+ class NeuralMemory(Module):
89
+ def __init__(
90
+ self,
91
+ dim,
92
+ chunk_size = 1,
93
+ dim_head = None,
94
+ heads = 1,
95
+ model: Module | None = None,
96
+ store_memory_loss_fn: Callable = default_loss_fn,
97
+ pre_rmsnorm = True,
98
+ post_rmsnorm = True,
99
+ use_accelerated_scan = False,
100
+ default_mlp_kwargs: dict = dict(
101
+ depth = 4
102
+ )
103
+ ):
104
+ super().__init__()
105
+
106
+ # norms
107
+
108
+ self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
109
+ self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
110
+
111
+ self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()
112
+
113
+ # maybe multi-headed
114
+
115
+ dim_head = default(dim_head, dim)
116
+ dim_inner = dim_head * heads
117
+
118
+ self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
119
+ self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
120
+ self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
121
+
122
+ # memory mlp
123
+
124
+ if not exists(model):
125
+ model = MLP(dim_head, **default_mlp_kwargs)
126
+
127
+ assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
128
+
129
+ # the memory is the weights of the model
130
+
131
+ self.memory_model = model
132
+
133
+ # the chunk size within the paper where adaptive step, momentum, weight decay are shared
134
+
135
+ self.chunk_size = chunk_size
136
+
137
+ # prepare function for per sample gradients from model above, using torch.func
138
+
139
+ def forward_and_loss(params, inputs, target):
140
+ pred = functional_call(self.memory_model, params, inputs)
141
+ loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
142
+ return loss
143
+
144
+ self.per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))
145
+
146
+ # queries for retrieving from the model
147
+
148
+ self.to_queries = LinearNoBias(dim, dim_inner)
149
+
150
+ # keys and values for storing to the model
151
+
152
+ self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
153
+ self.store_memory_loss_fn = store_memory_loss_fn
154
+
155
+ # learned adaptive learning rate and momentum
156
+ # todo - explore mlp layerwise learned lr / momentum
157
+
158
+ self.to_momentum = nn.Sequential(
159
+ Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
160
+ LinearNoBias(dim, heads),
161
+ Rearrange('b n h -> (b h) n 1')
162
+ )
163
+
164
+ self.to_adaptive_step = nn.Sequential(
165
+ Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
166
+ LinearNoBias(dim, heads),
167
+ Rearrange('b n h -> (b h) n')
168
+ )
169
+
170
+ # weight decay factor
171
+
172
+ self.to_decay_factor = nn.Sequential(
173
+ Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
174
+ LinearNoBias(dim, heads),
175
+ Rearrange('b n h -> (b h) n 1')
176
+ )
177
+
178
+ # maybe use accelerated scan
179
+
180
+ self.use_accelerated_scan = use_accelerated_scan
181
+
182
+ def init_weights_and_momentum(self):
183
+ params = TensorDict(dict(self.memory_model.named_parameters()))
184
+
185
+ init_weights = params.clone().zero_()
186
+ init_momentum = params.clone().zero_()
187
+
188
+ return init_weights, init_momentum
189
+
190
+ def store_memories(
191
+ self,
192
+ seq,
193
+ past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
194
+ ):
195
+
196
+ seq = self.store_norm(seq)
197
+
198
+ # curtail sequence by multiple of the chunk size
199
+ # only a complete chunk of the sequence provides the memory for the next chunk
200
+
201
+ seq_len, chunk_size = seq.shape[-2], self.chunk_size
202
+ round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
203
+
204
+ seq = seq[:, :round_down_seq_len]
205
+
206
+ # curr weights + past weights, in the case that the initial weights are learned
207
+
208
+ curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
209
+
210
+ past_state = tuple(TensorDict(d) for d in past_state)
211
+ past_weights, past_momentum = past_state
212
+
213
+ curr_weights = curr_weights + past_weights
214
+
215
+ # pack batch and sequence dimension
216
+
217
+ adaptive_lr = (self.to_adaptive_step(seq).sigmoid() * -15).exp() # squashes the learned step size into roughly (3e-7, 1.)
218
+
219
+ adaptive_momentum = self.to_momentum(seq).sigmoid()
220
+ decay_factor = self.to_decay_factor(seq).sigmoid()
221
+
222
+ # keys and values
223
+
224
+ keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
225
+
226
+ # maybe multi head
227
+
228
+ keys, values = map(self.split_heads, (keys, values))
229
+
230
+ batch = keys.shape[0]
231
+
232
+ # take care of chunking
233
+
234
+ keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
235
+
236
+ # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
237
+
238
+ grads, aux_store_loss = self.per_sample_grad_and_value_fn(dict(curr_weights), keys, values)
239
+
240
+ grads = TensorDict(grads)
241
+
242
+ # restore batch and sequence dimension
243
+
244
+ grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
245
+
246
+ # multiply gradients with learned adaptive step size
247
+
248
+ surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))
249
+
250
+ # determine scan function
251
+
252
+ def default_associative_scan(gates, inputs):
253
+ _, outputs = associative_scan(binary_operator, (gates, inputs))
254
+ return outputs
255
+
256
+ if self.use_accelerated_scan:
257
+ from accelerated_scan.triton import scan as triton_scan
258
+ from accelerated_scan.warp import scan as warp_scan
259
+
260
+ scan = triton_scan if seq.is_cuda else warp_scan
261
+
262
+ def accelerate_scan_fn(gates, inputs):
263
+ gates = gates.expand_as(inputs)
264
+ gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
265
+
266
+ seq_len = gates.shape[-1]
267
+ next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
268
+
269
+ gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
270
+ inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
271
+
272
+ outputs = scan(gates, inputs)
273
+
274
+ outputs = outputs[..., :seq_len]
275
+ outputs = rearrange(outputs, 'b d n -> b n d')
276
+ return outputs
277
+
278
+ scan_fn = accelerate_scan_fn
279
+ else:
280
+ scan_fn = default_associative_scan
281
+
282
+ # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
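+ #
+ # written out (the surprise below is already the gradient scaled by the negative adaptive step size):
+ #   momentum_t = eta_t * momentum_{t-1} + surprise_t          - eq (10)
+ #   update_t   = (1 - alpha_t) * update_{t-1} + momentum_t    - eq (13)
+ # both are first order linear recurrences, hence one associative scan each per parameter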
283
+
284
+ next_momentum = TensorDict()
285
+ updates = TensorDict()
286
+
287
+ for param_name, surprise in surprises.items():
288
+
289
+ surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
290
+
291
+ # derive momentum with associative scan - eq (10)
292
+
293
+ momentum = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
294
+
295
+ # use associative scan again for learned forgetting (weight decay) - eq (13)
296
+
297
+ update = scan_fn(1. - decay_factor, momentum) # gates here are (1 - alpha), the learned forgetting / weight decay
298
+
299
+ updates[param_name] = inverse_pack(update)
300
+ next_momentum[param_name] = inverse_pack(momentum)
301
+
302
+ # compute the next weight per batch
303
+
304
+ last_update = updates.apply(lambda t: t[:, -1])
305
+
306
+ next_state = (curr_weights + last_update, next_momentum)
307
+
308
+ return updates, next_state, aux_store_loss.mean() / chunk_size
309
+
310
+ def retrieve_memories(
311
+ self,
312
+ seq,
313
+ past_weights: dict[str, Tensor] | None = None,
314
+ ):
315
+ chunk_size = self.chunk_size
316
+ seq_len = seq.shape[1]
317
+
318
+ seq = self.retrieve_norm(seq)
319
+
320
+ assert seq_len >= chunk_size
321
+
322
+ seq = seq[:, (chunk_size - 1):]
323
+ curtailed_seq_len = seq.shape[-2]
324
+
325
+ next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)
326
+
327
+ padding = next_seq_len - curtailed_seq_len
328
+
329
+ seq = pad_at_dim(seq, (0, padding), dim = 1)
330
+
331
+ # the parameters of the memory model stores the memories of the key / values
332
+ # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
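+ # i.e. with a single accumulated weight matrix W, the functional call below reduces to q @ W per chunk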
333
+
334
+ curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
335
+
336
+ if exists(past_weights):
337
+ past_weights = TensorDict(past_weights)
338
+ assert past_weights.keys() == curr_weights.keys()
339
+
340
+ curr_weights = curr_weights + past_weights
341
+
342
+ # sequence Float['b n d'] to queries
343
+
344
+ queries = self.to_queries(seq)
345
+
346
+ # maybe multihead
347
+
348
+ queries = self.split_heads(queries)
349
+
350
+ batch = queries.shape[0]
351
+
352
+ # fetch values from memory model
353
+
354
+ curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
355
+ queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
356
+
357
+ # forward functional call
358
+
359
+ values = functional_call(self.memory_model, dict(curr_weights), queries)
360
+
361
+ # reconstitute batch dimension
362
+
363
+ values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
364
+
365
+ # maybe merge heads and combine
366
+
367
+ values = self.merge_heads(values)
368
+
369
+ values = self.combine_heads(values)
370
+
371
+ # post norm, somehow could not stabilize this without it, not in paper
372
+
373
+ values = self.post_rmsnorm(values)
374
+
375
+ # restore
376
+
377
+ values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo - use a learned null memory embedding instead of 0s for retrieving from empty neural memory
378
+ values = values[:, :seq_len] # trim back to the original sequence length (avoids an empty slice when padding == 0)
379
+
380
+ return values
381
+
382
+ def forward(
383
+ self,
384
+ seq,
385
+ past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
386
+ return_next_memories = False
387
+ ):
388
+ batch, seq_len = seq.shape[:2]
389
+
390
+ if seq_len < self.chunk_size:
391
+ return torch.zeros_like(seq)
392
+
393
+ if exists(past_state):
394
+ past_state = tuple(TensorDict(d) for d in past_state)
395
+
396
+ if not exists(past_state):
397
+ past_state = self.init_weights_and_momentum()
398
+
399
+ updates, next_memories, aux_kv_mse_loss = self.store_memories(seq, past_state)
400
+
401
+ past_weights, _ = past_state
402
+
403
+ retrieved = self.retrieve_memories(seq, past_weights + updates)
404
+
405
+ if not return_next_memories:
406
+ return retrieved
407
+
408
+ return retrieved, next_memories, aux_kv_mse_loss
@@ -0,0 +1,151 @@
1
+ import random
2
+ import tqdm
3
+ import gzip
4
+ import numpy as np
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.optim import Adam
9
+ from torch.nn import functional as F
10
+ from torch.utils.data import DataLoader, Dataset
11
+
12
+ from local_attention import LocalTransformer
13
+
14
+ from taylor_series_linear_attention import TaylorSeriesLinearAttn
15
+
16
+ from titans_pytorch.titans import NeuralMemory
17
+
18
+ # constants
19
+
20
+ NUM_BATCHES = int(1e5)
21
+ BATCH_SIZE = 4
22
+ GRADIENT_ACCUMULATE_EVERY = 4
23
+ LEARNING_RATE = 2e-4
24
+ VALIDATE_EVERY = 100
25
+ GENERATE_EVERY = 500
26
+ GENERATE_LENGTH = 512
27
+ SHOULD_GENERATE = False
28
+ SEQ_LEN = 512
29
+
30
+ PROJECT_NAME = 'titans-neural-memory'
31
+ WANDB_ONLINE = False # turn this on to pipe experiment to cloud
32
+ GLOBAL_LAYERS = (4, 5)
33
+ USE_TITANS_MEMORY = True
34
+ NEURAL_MEMORY_DEPTH = 2
35
+ WINDOW_SIZE = 64
36
+ RUN_NAME = 'neural memory'
37
+
38
+ # wandb experiment tracker
39
+
40
+ import wandb
41
+ wandb.init(project = PROJECT_NAME, mode = 'disabled' if not WANDB_ONLINE else 'online')
42
+ wandb.run.name = RUN_NAME
43
+ wandb.run.save()
44
+
45
+ # helpers
46
+
47
+ def cycle(loader):
48
+ while True:
49
+ for data in loader:
50
+ yield data
51
+
52
+ def decode_token(token):
53
+ return str(chr(max(32, token)))
54
+
55
+ def decode_tokens(tokens):
56
+ return ''.join(list(map(decode_token, tokens)))
57
+
58
+ # instantiate GPT-like decoder model
59
+
60
+ titans_neural_memory = NeuralMemory(
61
+ dim = 384,
62
+ chunk_size = WINDOW_SIZE,
63
+ pre_rmsnorm = True,
64
+ post_rmsnorm = True,
65
+ dim_head = 32,
66
+ heads = 8,
67
+ use_accelerated_scan = True,
68
+ default_mlp_kwargs = dict(
69
+ depth = NEURAL_MEMORY_DEPTH
70
+ )
71
+ )
72
+
73
+ linear_attn = TaylorSeriesLinearAttn(
74
+ dim = 384,
75
+ dim_head = 16,
76
+ heads = 16,
77
+ causal = True,
78
+ prenorm = True
79
+ )
80
+
81
+ model = LocalTransformer(
82
+ num_tokens = 256,
83
+ dim = 384,
84
+ depth = 8,
85
+ causal = True,
86
+ local_attn_window_size = WINDOW_SIZE,
87
+ max_seq_len = SEQ_LEN,
88
+ global_attn_layer = linear_attn if not USE_TITANS_MEMORY else titans_neural_memory,
89
+ layers_insert_global_attn = GLOBAL_LAYERS
90
+ ).cuda()
91
+
92
+ # prepare enwik8 data
93
+
94
+ with gzip.open('./data/enwik8.gz') as file:
95
+ data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
96
+ data_train, data_val = np.split(data, [int(90e6)])
97
+ data_train, data_val = map(torch.from_numpy, (data_train, data_val))
98
+
99
+ class TextSamplerDataset(Dataset):
100
+ def __init__(self, data, seq_len):
101
+ super().__init__()
102
+ self.data = data
103
+ self.seq_len = seq_len
104
+
105
+ def __getitem__(self, index):
106
+ rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
107
+ full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
108
+ return full_seq.cuda()
109
+
110
+ def __len__(self):
111
+ return self.data.size(0) // self.seq_len
112
+
113
+ train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
114
+ val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
115
+ train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
116
+ val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
117
+
118
+ # optimizer
119
+
120
+ optim = Adam(model.parameters(), lr=LEARNING_RATE)
121
+
122
+ # training
123
+
124
+ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
125
+ model.train()
126
+
127
+ for __ in range(GRADIENT_ACCUMULATE_EVERY):
128
+ loss = model(next(train_loader), return_loss = True)
129
+ loss.backward()
130
+
131
+ print(f'training loss: {loss.item()}')
132
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
133
+ optim.step()
134
+ optim.zero_grad()
135
+ wandb.log(dict(loss = loss.item()))
136
+
137
+ if i % VALIDATE_EVERY == 0:
138
+ model.eval()
139
+ with torch.no_grad():
140
+ loss = model(next(val_loader), return_loss = True)
141
+ print(f'validation loss: {loss.item()}')
142
+
143
+ if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
144
+ model.eval()
145
+ inp = random.choice(val_dataset)[:-1]
146
+ prime = decode_tokens(inp)
147
+ print(f'%s \n\n %s', (prime, '*' * 100))
148
+
149
+ sample = model.generate(inp[None, ...], GENERATE_LENGTH, use_kv_cache = False)
150
+ output_str = decode_tokens(sample[0])
151
+ print(output_str)