titans-pytorch 0.0.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
@@ -0,0 +1,36 @@
+ # This workflow will upload a Python Package using Twine when a release is created
+ # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
+
+ # This workflow uses actions that are not certified by GitHub.
+ # They are provided by a third-party and are governed by
+ # separate terms of service, privacy policy, and support
+ # documentation.
+
+ name: Upload Python Package
+
+ on:
+   release:
+     types: [published]
+
+ jobs:
+   deploy:
+
+     runs-on: ubuntu-latest
+
+     steps:
+     - uses: actions/checkout@v2
+     - name: Set up Python
+       uses: actions/setup-python@v2
+       with:
+         python-version: '3.x'
+     - name: Install dependencies
+       run: |
+         python -m pip install --upgrade pip
+         pip install build
+     - name: Build package
+       run: python -m build
+     - name: Publish package
+       uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
+       with:
+         user: __token__
+         password: ${{ secrets.PYPI_API_TOKEN }}
@@ -0,0 +1,173 @@
+ train_local.py
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # UV
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ #uv.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # PyPI configuration file
+ .pypirc
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Phil Wang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,111 @@
+ Metadata-Version: 2.4
+ Name: titans-pytorch
+ Version: 0.0.14
+ Summary: Titans
+ Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
+ Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
+ Author-email: Phil Wang <lucidrains@gmail.com>
+ License: MIT License
+
+         Copyright (c) 2025 Phil Wang
+
+         Permission is hereby granted, free of charge, to any person obtaining a copy
+         of this software and associated documentation files (the "Software"), to deal
+         in the Software without restriction, including without limitation the rights
+         to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+         copies of the Software, and to permit persons to whom the Software is
+         furnished to do so, subject to the following conditions:
+
+         The above copyright notice and this permission notice shall be included in all
+         copies or substantial portions of the Software.
+
+         THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+         IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+         FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+         AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+         LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+         OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+         SOFTWARE.
+ License-File: LICENSE
+ Keywords: artificial intelligence,deep learning,linear attention,neural memory module,test time training
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Developers
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.9
+ Requires-Dist: accelerated-scan>=0.2.0
+ Requires-Dist: einops>=0.8.0
+ Requires-Dist: einx>=0.3.0
+ Requires-Dist: ninja
+ Requires-Dist: tensordict
+ Requires-Dist: torch>=2.2
+ Provides-Extra: examples
+ Requires-Dist: local-attention>=1.10.1; extra == 'examples'
+ Requires-Dist: taylor-series-linear-attention; extra == 'examples'
+ Requires-Dist: tqdm; extra == 'examples'
+ Requires-Dist: wandb; extra == 'examples'
+ Provides-Extra: test
+ Requires-Dist: pytest; extra == 'test'
+ Description-Content-Type: text/markdown
+
+ <img src="./fig2.png" width="400px"></img>
+
+ <img src="./fig1.png" width="400px"></img>
+
+ ## Titans - Pytorch (wip)
+
+ Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. This repository will also explore architectures beyond the simple 1-4 layer MLP the authors use for the neural memory module, should any of them work well.
+
+ ## Install
+
+ ```bash
+ $ pip install titans-pytorch
+ ```
+
+ ## Usage
+
+ ```python
+ import torch
+ from titans_pytorch import NeuralMemory
+
+ mem = NeuralMemory(
+     dim = 384,
+     chunk_size = 64,
+     pre_rmsnorm = True
+ ).cuda()
+
+ seq = torch.randn(2, 1024, 384).cuda()
+ retrieved = mem(seq)
+
+ assert seq.shape == retrieved.shape
+ ```
+
+ ## Experiments
+
+ ```bash
+ $ pip install .[examples]
+ ```
+
+ For the SOTA linear attention baseline, you will also need to run
+
+ ```bash
+ $ pip install -r requirements.txt
+ ```
+
+ Then modify `train.py` and run it to query nature
+
+ ```bash
+ $ python train.py
+ ```
+
+ ## Citations
+
+ ```bibtex
+ @inproceedings{Behrouz2024TitansLT,
+     title = {Titans: Learning to Memorize at Test Time},
+     author = {Ali Behrouz and Peilin Zhong and Vahab S. Mirrokni},
+     year = {2024},
+     url = {https://api.semanticscholar.org/CorpusID:275212078}
+ }
+ ```
@@ -0,0 +1,60 @@
+ <img src="./fig2.png" width="400px"></img>
+
+ <img src="./fig1.png" width="400px"></img>
+
+ ## Titans - Pytorch (wip)
+
+ Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. This repository will also explore architectures beyond the simple 1-4 layer MLP the authors use for the neural memory module, should any of them work well.
+
+ ## Install
+
+ ```bash
+ $ pip install titans-pytorch
+ ```
+
+ ## Usage
+
+ ```python
+ import torch
+ from titans_pytorch import NeuralMemory
+
+ mem = NeuralMemory(
+     dim = 384,
+     chunk_size = 64,
+     pre_rmsnorm = True
+ ).cuda()
+
+ seq = torch.randn(2, 1024, 384).cuda()
+ retrieved = mem(seq)
+
+ assert seq.shape == retrieved.shape
+ ```
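+
+ The forward pass can also hand back the updated memory state (the per-batch memory weights and momentum) along with the auxiliary store loss. A minimal sketch, using the `return_next_memories` keyword of `NeuralMemory.forward` in this version; treat the surrounding workflow as illustrative:
+
+ ```python
+ retrieved, next_memories, aux_store_loss = mem(
+     seq,
+     return_next_memories = True
+ )
+
+ # next_memories is a (weights, momentum) pair; `forward` also exposes a
+ # `past_state` keyword for resuming from earlier memories
+
+ assert seq.shape == retrieved.shape
+ ```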
+
+ ## Experiments
+
+ ```bash
+ $ pip install .[examples]
+ ```
+
+ For the SOTA linear attention baseline, you will also need to run
+
+ ```bash
+ $ pip install -r requirements.txt
+ ```
+
+ Then modify `train.py` and run it to query nature
+
+ ```bash
+ $ python train.py
+ ```
+
+ ## Citations
+
+ ```bibtex
+ @inproceedings{Behrouz2024TitansLT,
+     title = {Titans: Learning to Memorize at Test Time},
+     author = {Ali Behrouz and Peilin Zhong and Vahab S. Mirrokni},
+     year = {2024},
+     url = {https://api.semanticscholar.org/CorpusID:275212078}
+ }
+ ```
@@ -0,0 +1,3 @@
+ # Data source
+
+ The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
Binary file
Binary file
Binary file
@@ -0,0 +1,70 @@
+ [project]
+ name = "titans-pytorch"
+ version = "0.0.14"
+ description = "Titans"
+ authors = [
+     { name = "Phil Wang", email = "lucidrains@gmail.com" }
+ ]
+ readme = "README.md"
+ requires-python = ">= 3.9"
+ license = { file = "LICENSE" }
+ keywords = [
+     'artificial intelligence',
+     'deep learning',
+     'neural memory module',
+     'test time training',
+     'linear attention'
+ ]
+
+ classifiers = [
+     'Development Status :: 4 - Beta',
+     'Intended Audience :: Developers',
+     'Topic :: Scientific/Engineering :: Artificial Intelligence',
+     'License :: OSI Approved :: MIT License',
+     'Programming Language :: Python :: 3.9',
+ ]
+
+ dependencies = [
+     "accelerated-scan>=0.2.0",
+     "einx>=0.3.0",
+     "einops>=0.8.0",
+     "Ninja",
+     "tensordict",
+     "torch>=2.2",
+ ]
+
+ [project.urls]
+ Homepage = "https://pypi.org/project/titans-pytorch/"
+ Repository = "https://github.com/lucidrains/titans-pytorch"
+
+ [project.optional-dependencies]
+
+ examples = [
+     "local-attention>=1.10.1",
+     "taylor-series-linear-attention",
+     "tqdm",
+     "wandb"
+ ]
+
+ test = [
+     "pytest"
+ ]
+
+ [tool.pytest.ini_options]
+ pythonpath = [
+     "."
+ ]
+
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
+
+ [tool.rye]
+ managed = true
+ dev-dependencies = []
+
+ [tool.hatch.metadata]
+ allow-direct-references = true
+
+ [tool.hatch.build.targets.wheel]
+ packages = ["titans_pytorch"]
@@ -0,0 +1 @@
+ pytorch-fast-transformers>=0.4.0
@@ -0,0 +1,3 @@
+ from titans_pytorch.titans import (
+     NeuralMemory
+ )
@@ -0,0 +1,90 @@
+ from __future__ import annotations
+ from typing import Callable
+
+ import torch
+ from torch import Tensor
+ import torch.nn.functional as F
+
+ # taken from S5-pytorch repository
+ # https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/jax_compat.py#L51-L134
+
+ # helper functions
+
+ def pad_at_dim(t, pad, dim = -1, value = 0.):
+     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+     zeros = ((0, 0) * dims_from_right)
+     return F.pad(t, (*zeros, *pad), value = value)
+
+ # the operator that is needed
+
+ @torch.jit.script
+ def binary_operator(
+     a: tuple[Tensor, Tensor],
+     b: tuple[Tensor, Tensor]
+ ):
+     a_i, kv_i = a
+     a_j, kv_j = b
+     return a_j * a_i, torch.addcmul(kv_j, a_j, kv_i)
+
+ # Pytorch impl. of jax.lax.associative_scan
+ # made specifically for axis of 1 (sequence of tokens for autoregressive modeling)
+
+ def associative_scan(
+     operator: Callable,
+     elems: tuple[Tensor, Tensor]
+ ):
+     num_elems = int(elems[0].shape[1])
+
+     if not all(int(elem.shape[1]) == num_elems for elem in elems[1:]):
+         raise ValueError('Tensor inputs to associative_scan must have the same '
+                          'sequence length along dim 1. (saw: {})'
+                          .format([elem.shape for elem in elems]))
+
+     def _scan(elems):
+         """Perform scan on `elems`."""
+         num_elems = elems[0].shape[1]
+
+         if num_elems < 2:
+             return elems
+
+         # Combine adjacent pairs of elements.
+
+         reduced_elems = operator(
+             [elem[:, :-1:2] for elem in elems],
+             [elem[:, 1::2] for elem in elems])
+
+         # Recursively compute scan for partially reduced tensors.
+
+         odd_elems = _scan(reduced_elems)
+
+         if num_elems % 2 == 0:
+             even_elems = operator(
+                 [e[:, :-1] for e in odd_elems],
+                 [e[:, 2::2] for e in elems])
+         else:
+             even_elems = operator(
+                 odd_elems,
+                 [e[:, 2::2] for e in elems])
+
+         # The first element of a scan is the same as the first element
+         # of the original `elems`.
+
+         even_elems = [
+             torch.cat([elem[:, :1], result], dim=1)
+             for (elem, result) in zip(elems, even_elems)]
+
+         return list(map(_interleave, even_elems, odd_elems))
+
+     return _scan(elems)
+
+ def _interleave(a, b):
+     a_axis_len, b_axis_len = a.shape[1], b.shape[1]
+     output_axis_len = a_axis_len + b_axis_len
+
+     if (a_axis_len == (b_axis_len + 1)):
+         b = pad_at_dim(b, (0, 1), dim = 1)
+
+     stacked = torch.stack([a, b], dim=2)
+     interleaved = torch.flatten(stacked, start_dim=1, end_dim=2)
+
+     return interleaved[:, :output_axis_len]
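+
+ # sanity check - the scan above should agree with a direct loop over the
+ # first-order recurrence h_t = a_t * h_{t-1} + x_t that `binary_operator`
+ # encodes (illustrative check, not exercised by the package itself)
+
+ if __name__ == '__main__':
+     gates  = torch.rand(1, 8, 4)
+     inputs = torch.randn(1, 8, 4)
+
+     _, scanned = associative_scan(binary_operator, (gates, inputs))
+
+     h = torch.zeros(1, 4)
+     expected = []
+
+     for t in range(8):
+         h = gates[:, t] * h + inputs[:, t]
+         expected.append(h)
+
+     assert torch.allclose(scanned, torch.stack(expected, dim = 1), atol = 1e-5)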
@@ -0,0 +1,408 @@
+ from __future__ import annotations
+ import math
+ from functools import partial
+ from typing import Callable
+
+ import torch
+ from torch import nn, Tensor
+ import torch.nn.functional as F
+ from torch.nn import Linear, Module
+ from torch.func import functional_call, vmap, grad_and_value
+
+ from tensordict import TensorDict
+
+ from titans_pytorch.associative_scan import (
+     associative_scan,
+     binary_operator,
+     pad_at_dim
+ )
+
+ import einx
+ from einops import rearrange, pack, unpack
+ from einops.layers.torch import Rearrange, Reduce
+
+ """
+ ein notation:
+ b - batch
+ n - sequence
+ d - feature dimension
+ c - intra-chunk
+ """
+
+ # constants
+
+ LinearNoBias = partial(Linear, bias = False)
+
+ # functions
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def round_down_multiple(seq, mult):
+     return seq // mult * mult
+
+ def round_up_multiple(seq, mult):
+     return math.ceil(seq / mult) * mult
+
+ def pack_one_with_inverse(t, pattern):
+     packed, packed_shape = pack([t], pattern)
+
+     def inverse(out, inv_pattern = None):
+         inv_pattern = default(inv_pattern, pattern)
+         return unpack(out, packed_shape, inv_pattern)[0]
+
+     return packed, inverse
+
+ # classes
+
+ class MLP(Module):
+     def __init__(
+         self,
+         dim,
+         depth
+     ):
+         super().__init__()
+         self.weights = nn.ParameterList([nn.Parameter(torch.randn(dim, dim)) for _ in range(depth)])
+
+     def forward(
+         self,
+         x
+     ):
+         for ind, weight in enumerate(self.weights):
+             is_first = ind == 0
+
+             if not is_first:
+                 x = F.silu(x)
+
+             x = x @ weight
+
+         return x
+
+ # main neural memory
+
+ def default_loss_fn(pred, target):
+     return (pred - target).pow(2).mean(dim = -1).sum()
+
+ class NeuralMemory(Module):
+     def __init__(
+         self,
+         dim,
+         chunk_size = 1,
+         dim_head = None,
+         heads = 1,
+         model: Module | None = None,
+         store_memory_loss_fn: Callable = default_loss_fn,
+         pre_rmsnorm = True,
+         post_rmsnorm = True,
+         use_accelerated_scan = False,
+         default_mlp_kwargs: dict = dict(
+             depth = 4
+         )
+     ):
+         super().__init__()
+
+         # norms
+
+         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
+         self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
+
+         self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()
+
+         # maybe multi-headed
+
+         dim_head = default(dim_head, dim)
+         dim_inner = dim_head * heads
+
+         self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
+         self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
+         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
+
+         # memory mlp
+
+         if not exists(model):
+             model = MLP(dim_head, **default_mlp_kwargs)
+
+         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
+
+         # the memory is the weights of the model
+
+         self.memory_model = model
+
+         # the chunk size within the paper where adaptive step, momentum, weight decay are shared
+
+         self.chunk_size = chunk_size
+
+         # prepare function for per sample gradients from model above, using torch.func
+
+         def forward_and_loss(params, inputs, target):
+             pred = functional_call(self.memory_model, params, inputs)
+             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
+             return loss
+
+         self.per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))
+
+         # queries for retrieving from the model
+
+         self.to_queries = LinearNoBias(dim, dim_inner)
+
+         # keys and values for storing to the model
+
+         self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
+         self.store_memory_loss_fn = store_memory_loss_fn
+
+         # learned adaptive learning rate and momentum
+         # todo - explore mlp layerwise learned lr / momentum
+
+         self.to_momentum = nn.Sequential(
+             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+             LinearNoBias(dim, heads),
+             Rearrange('b n h -> (b h) n 1')
+         )
+
+         self.to_adaptive_step = nn.Sequential(
+             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+             LinearNoBias(dim, heads),
+             Rearrange('b n h -> (b h) n')
+         )
+
+         # weight decay factor
+
+         self.to_decay_factor = nn.Sequential(
+             Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+             LinearNoBias(dim, heads),
+             Rearrange('b n h -> (b h) n 1')
+         )
+
+         # maybe use accelerated scan
+
+         self.use_accelerated_scan = use_accelerated_scan
+
+     def init_weights_and_momentum(self):
+         params = TensorDict(dict(self.memory_model.named_parameters()))
+
+         init_weights = params.clone().zero_()
+         init_momentum = params.clone().zero_()
+
+         return init_weights, init_momentum
+
+     def store_memories(
+         self,
+         seq,
+         past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
+     ):
+
+         seq = self.store_norm(seq)
+
+         # curtail sequence by multiple of the chunk size
+         # only a complete chunk of the sequence provides the memory for the next chunk
+
+         seq_len, chunk_size = seq.shape[-2], self.chunk_size
+         round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
+
+         seq = seq[:, :round_down_seq_len]
+
+         # curr weights + past weights, in the case that the initial weights are learned
+
+         curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
+
+         past_state = tuple(TensorDict(d) for d in past_state)
+         past_weights, past_momentum = past_state
+
+         curr_weights = curr_weights + past_weights
+
+         # pack batch and sequence dimension
+
+         adaptive_lr = (self.to_adaptive_step(seq).sigmoid() * -15).exp() # adaptive lr in the range (exp(-15), 1.) ≈ (3e-7, 1.)
+
+         adaptive_momentum = self.to_momentum(seq).sigmoid()
+         decay_factor = self.to_decay_factor(seq).sigmoid()
+
+         # keys and values
+
+         keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
+
+         # maybe multi head
+
+         keys, values = map(self.split_heads, (keys, values))
+
+         batch = keys.shape[0]
+
+         # take care of chunking
+
+         keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
+
+         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
+
+         grads, aux_store_loss = self.per_sample_grad_and_value_fn(dict(curr_weights), keys, values)
+
+         grads = TensorDict(grads)
+
+         # restore batch and sequence dimension
+
+         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
+
+         # multiply gradients with learned adaptive step size
+
+         surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))
+
+         # determine scan function
+
+         def default_associative_scan(gates, inputs):
+             _, outputs = associative_scan(binary_operator, (gates, inputs))
+             return outputs
+
+         if self.use_accelerated_scan:
+             from accelerated_scan.triton import scan as triton_scan
+             from accelerated_scan.warp import scan as warp_scan
+
+             scan = triton_scan if seq.is_cuda else warp_scan
+
+             def accelerate_scan_fn(gates, inputs):
+                 gates = gates.expand_as(inputs)
+                 gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
+
+                 seq_len = gates.shape[-1]
+                 next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
+
+                 gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
+                 inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
+
+                 outputs = scan(gates, inputs)
+
+                 outputs = outputs[..., :seq_len]
+                 outputs = rearrange(outputs, 'b d n -> b n d')
+                 return outputs
+
+             scan_fn = accelerate_scan_fn
+         else:
+             scan_fn = default_associative_scan
+
+         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
+
+         next_momentum = TensorDict()
+         updates = TensorDict()
+
+         for param_name, surprise in surprises.items():
+
+             surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
+
+             # derive momentum with associative scan - eq (10)
+
+             momentum = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
+
+             # use associative scan again for learned forgetting (weight decay) - eq (13)
+
+             update = scan_fn(1. - decay_factor, momentum) # (1. - decay_factor) is the keep rate, (1 - alpha) in the paper
+
+             updates[param_name] = inverse_pack(update)
+             next_momentum[param_name] = inverse_pack(momentum)
+
+         # compute the next weight per batch
+
+         last_update = updates.apply(lambda t: t[:, -1])
+
+         next_state = (curr_weights + last_update, next_momentum)
+
+         return updates, next_state, aux_store_loss.mean() / chunk_size
+
+     def retrieve_memories(
+         self,
+         seq,
+         past_weights: dict[str, Tensor] | None = None,
+     ):
+         chunk_size = self.chunk_size
+         seq_len = seq.shape[1]
+
+         seq = self.retrieve_norm(seq)
+
+         assert seq_len >= chunk_size
+
+         seq = seq[:, (chunk_size - 1):]
+         curtailed_seq_len = seq.shape[-2]
+
+         next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)
+
+         padding = next_seq_len - curtailed_seq_len
+
+         seq = pad_at_dim(seq, (0, padding), dim = 1)
+
+         # the parameters of the memory model store the memories of the key / values
+         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
+
+         curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
+
+         if exists(past_weights):
+             past_weights = TensorDict(past_weights)
+             assert past_weights.keys() == curr_weights.keys()
+
+             curr_weights = curr_weights + past_weights
+
+         # sequence Float['b n d'] to queries
+
+         queries = self.to_queries(seq)
+
+         # maybe multihead
+
+         queries = self.split_heads(queries)
+
+         batch = queries.shape[0]
+
+         # fetch values from memory model
+
+         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+         queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
+
+         # forward functional call
+
+         values = functional_call(self.memory_model, dict(curr_weights), queries)
+
+         # reconstitute batch dimension
+
+         values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+
+         # maybe merge heads and combine
+
+         values = self.merge_heads(values)
+
+         values = self.combine_heads(values)
+
+         # post norm, somehow could not stabilize this without it, not in paper
+
+         values = self.post_rmsnorm(values)
+
+         # restore - left pad for the (chunk_size - 1) shift, then slice back to the
+         # original sequence length (also correct when the right padding above is 0)
+
+         values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo - use a learned null memory embedding instead of 0s for retrieving from empty neural memory
+         values = values[:, :seq_len]
+
+         return values
+
+     def forward(
+         self,
+         seq,
+         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
+         return_next_memories = False
+     ):
+         batch, seq_len = seq.shape[:2]
+
+         if seq_len < self.chunk_size:
+             return torch.zeros_like(seq)
+
+         if exists(past_state):
+             past_state = tuple(TensorDict(d) for d in past_state)
+
+         if not exists(past_state):
+             past_state = self.init_weights_and_momentum()
+
+         updates, next_memories, aux_kv_mse_loss = self.store_memories(seq, past_state)
+
+         past_weights, _ = past_state
+
+         retrieved = self.retrieve_memories(seq, past_weights + updates)
+
+         if not return_next_memories:
+             return retrieved
+
+         return retrieved, next_memories, aux_kv_mse_loss
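+
+ # minimal smoke test - memory retrieval must preserve the input shape
+ # (illustrative only; hyperparameters here are arbitrary)
+
+ if __name__ == '__main__':
+     mem = NeuralMemory(dim = 32, chunk_size = 4)
+
+     seq = torch.randn(1, 16, 32)
+     retrieved, next_memories, aux_loss = mem(seq, return_next_memories = True)
+
+     assert retrieved.shape == seq.shape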
@@ -0,0 +1,151 @@
+ import random
+ import tqdm
+ import gzip
+ import numpy as np
+
+ import torch
+ from torch import nn
+ from torch.optim import Adam
+ from torch.nn import functional as F
+ from torch.utils.data import DataLoader, Dataset
+
+ from local_attention import LocalTransformer
+
+ from taylor_series_linear_attention import TaylorSeriesLinearAttn
+
+ from titans_pytorch.titans import NeuralMemory
+
+ # constants
+
+ NUM_BATCHES = int(1e5)
+ BATCH_SIZE = 4
+ GRADIENT_ACCUMULATE_EVERY = 4
+ LEARNING_RATE = 2e-4
+ VALIDATE_EVERY = 100
+ GENERATE_EVERY = 500
+ GENERATE_LENGTH = 512
+ SHOULD_GENERATE = False
+ SEQ_LEN = 512
+
+ PROJECT_NAME = 'titans-neural-memory'
+ WANDB_ONLINE = False # turn this on to pipe experiment to cloud
+ GLOBAL_LAYERS = (4, 5)
+ USE_TITANS_MEMORY = True
+ NEURAL_MEMORY_DEPTH = 2
+ WINDOW_SIZE = 64
+ RUN_NAME = 'neural memory'
+
+ # wandb experiment tracker
+
+ import wandb
+ wandb.init(project = PROJECT_NAME, mode = 'disabled' if not WANDB_ONLINE else 'online')
+ wandb.run.name = RUN_NAME
+ wandb.run.save()
+
+ # helpers
+
+ def cycle(loader):
+     while True:
+         for data in loader:
+             yield data
+
+ def decode_token(token):
+     return str(chr(max(32, token)))
+
+ def decode_tokens(tokens):
+     return ''.join(list(map(decode_token, tokens)))
+
+ # instantiate GPT-like decoder model
+
+ titans_neural_memory = NeuralMemory(
+     dim = 384,
+     chunk_size = WINDOW_SIZE,
+     pre_rmsnorm = True,
+     post_rmsnorm = True,
+     dim_head = 32,
+     heads = 8,
+     use_accelerated_scan = True,
+     default_mlp_kwargs = dict(
+         depth = NEURAL_MEMORY_DEPTH
+     )
+ )
+
+ linear_attn = TaylorSeriesLinearAttn(
+     dim = 384,
+     dim_head = 16,
+     heads = 16,
+     causal = True,
+     prenorm = True
+ )
+
+ model = LocalTransformer(
+     num_tokens = 256,
+     dim = 384,
+     depth = 8,
+     causal = True,
+     local_attn_window_size = WINDOW_SIZE,
+     max_seq_len = SEQ_LEN,
+     global_attn_layer = linear_attn if not USE_TITANS_MEMORY else titans_neural_memory,
+     layers_insert_global_attn = GLOBAL_LAYERS
+ ).cuda()
+
+ # prepare enwik8 data
+
+ with gzip.open('./data/enwik8.gz') as file:
+     data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
+     data_train, data_val = np.split(data, [int(90e6)])
+     data_train, data_val = map(torch.from_numpy, (data_train, data_val))
+
+ class TextSamplerDataset(Dataset):
+     def __init__(self, data, seq_len):
+         super().__init__()
+         self.data = data
+         self.seq_len = seq_len
+
+     def __getitem__(self, index):
+         rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
+         full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
+         return full_seq.cuda()
+
+     def __len__(self):
+         return self.data.size(0) // self.seq_len
+
+ train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
+ val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
+ train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
+ val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
+
+ # optimizer
+
+ optim = Adam(model.parameters(), lr = LEARNING_RATE)
+
+ # training
+
+ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
+     model.train()
+
+     for __ in range(GRADIENT_ACCUMULATE_EVERY):
+         loss = model(next(train_loader), return_loss = True)
+         (loss / GRADIENT_ACCUMULATE_EVERY).backward()
+
+     print(f'training loss: {loss.item()}')
+     torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+     optim.step()
+     optim.zero_grad()
+     wandb.log(dict(loss = loss.item()))
+
+     if i % VALIDATE_EVERY == 0:
+         model.eval()
+         with torch.no_grad():
+             loss = model(next(val_loader), return_loss = True)
+             print(f'validation loss: {loss.item()}')
+
+     if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
+         model.eval()
+         inp = random.choice(val_dataset)[:-1]
+         prime = decode_tokens(inp)
+         print(f'{prime} \n\n {"*" * 100}')
+
+         sample = model.generate(inp[None, ...], GENERATE_LENGTH, use_kv_cache = False)
+         output_str = decode_tokens(sample[0])
+         print(output_str)