titans-pytorch 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,36 @@
1
+ # This workflow will upload a Python Package using Twine when a release is created
2
+ # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3
+
4
+ # This workflow uses actions that are not certified by GitHub.
5
+ # They are provided by a third-party and are governed by
6
+ # separate terms of service, privacy policy, and support
7
+ # documentation.
8
+
9
+ name: Upload Python Package
10
+
11
+ on:
12
+ release:
13
+ types: [published]
14
+
15
+ jobs:
16
+ deploy:
17
+
18
+ runs-on: ubuntu-latest
19
+
20
+ steps:
21
+ - uses: actions/checkout@v2
22
+ - name: Set up Python
23
+ uses: actions/setup-python@v2
24
+ with:
25
+ python-version: '3.x'
26
+ - name: Install dependencies
27
+ run: |
28
+ python -m pip install --upgrade pip
29
+ pip install build
30
+ - name: Build package
31
+ run: python -m build
32
+ - name: Publish package
33
+ uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34
+ with:
35
+ user: __token__
36
+ password: ${{ secrets.PYPI_API_TOKEN }}
@@ -0,0 +1,171 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # PyPI configuration file
171
+ .pypirc
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Phil Wang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,84 @@
1
+ Metadata-Version: 2.4
2
+ Name: titans-pytorch
3
+ Version: 0.0.1
4
+ Summary: Titans
5
+ Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
+ Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
7
+ Author-email: Phil Wang <lucidrains@gmail.com>
8
+ License: MIT License
9
+
10
+ Copyright (c) 2025 Phil Wang
11
+
12
+ Permission is hereby granted, free of charge, to any person obtaining a copy
13
+ of this software and associated documentation files (the "Software"), to deal
14
+ in the Software without restriction, including without limitation the rights
15
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16
+ copies of the Software, and to permit persons to whom the Software is
17
+ furnished to do so, subject to the following conditions:
18
+
19
+ The above copyright notice and this permission notice shall be included in all
20
+ copies or substantial portions of the Software.
21
+
22
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28
+ SOFTWARE.
29
+ License-File: LICENSE
30
+ Keywords: artificial intelligence,deep learning,linear attention,neural memory module,test time training
31
+ Classifier: Development Status :: 4 - Beta
32
+ Classifier: Intended Audience :: Developers
33
+ Classifier: License :: OSI Approved :: MIT License
34
+ Classifier: Programming Language :: Python :: 3.9
35
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
+ Requires-Python: >=3.9
37
+ Requires-Dist: einops>=0.8.0
38
+ Requires-Dist: einx>=0.3.0
39
+ Requires-Dist: tensordict>=0.6.2
40
+ Requires-Dist: torch>=2.3
41
+ Provides-Extra: examples
42
+ Provides-Extra: test
43
+ Requires-Dist: pytest; extra == 'test'
44
+ Description-Content-Type: text/markdown
45
+
46
+ <img src="./fig2.png" width="400px"></img>
47
+
48
+ <img src="./fig1.png" width="400px"></img>
49
+
50
+ ## Titans - Pytorch (wip)
51
+
52
+ Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. This repository will also contain some explorations into architectures beyond the paper's simple 1-4 layer MLP for the neural memory module.
53
+
54
+ ## Install
55
+
56
+ ```bash
57
+ $ pip install titans-pytorch
58
+ ```
59
+
60
+ ## Usage
61
+
62
+ ```python
63
+ import torch
64
+ from titans_pytorch import NeuralMemory
65
+
66
+ x = torch.randn(2, 64, 32)
67
+
68
+ mem = NeuralMemory(32)
69
+
70
+ out = mem(x)
71
+
72
+ assert x.shape == out.shape
73
+ ```
74
+
75
+ ## Citations
76
+
77
+ ```bibtex
78
+ @inproceedings{Behrouz2024TitansLT,
79
+ title = {Titans: Learning to Memorize at Test Time},
80
+ author = {Ali Behrouz and Peilin Zhong and Vahab S. Mirrokni},
81
+ year = {2024},
82
+ url = {https://api.semanticscholar.org/CorpusID:275212078}
83
+ }
84
+ ```
@@ -0,0 +1,39 @@
1
+ <img src="./fig2.png" width="400px"></img>
2
+
3
+ <img src="./fig1.png" width="400px"></img>
4
+
5
+ ## Titans - Pytorch (wip)
6
+
7
+ Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. This repository will also contain some explorations into architectures beyond the paper's simple 1-4 layer MLP for the neural memory module.
8
+
9
+ ## Install
10
+
11
+ ```bash
12
+ $ pip install titans-pytorch
13
+ ```
14
+
15
+ ## Usage
16
+
17
+ ```python
18
+ import torch
19
+ from titans_pytorch import NeuralMemory
20
+
21
+ x = torch.randn(2, 64, 32)
22
+
23
+ mem = NeuralMemory(32)
24
+
25
+ out = mem(x)
26
+
27
+ assert x.shape == out.shape
28
+ ```
29
+
30
+ ## Citations
31
+
32
+ ```bibtex
33
+ @inproceedings{Behrouz2024TitansLT,
34
+ title = {Titans: Learning to Memorize at Test Time},
35
+ author = {Ali Behrouz and Peilin Zhong and Vahab S. Mirrokni},
36
+ year = {2024},
37
+ url = {https://api.semanticscholar.org/CorpusID:275212078}
38
+ }
39
+ ```
Binary file
Binary file
@@ -0,0 +1,61 @@
1
+ [project]
2
+ name = "titans-pytorch"
3
+ version = "0.0.1"
4
+ description = "Titans"
5
+ authors = [
6
+ { name = "Phil Wang", email = "lucidrains@gmail.com" }
7
+ ]
8
+ readme = "README.md"
9
+ requires-python = ">= 3.9"
10
+ license = { file = "LICENSE" }
11
+ keywords = [
12
+ 'artificial intelligence',
13
+ 'deep learning',
14
+ 'neural memory module',
15
+ 'test time training',
16
+ 'linear attention'
17
+ ]
18
+
19
+ classifiers=[
20
+ 'Development Status :: 4 - Beta',
21
+ 'Intended Audience :: Developers',
22
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
23
+ 'License :: OSI Approved :: MIT License',
24
+ 'Programming Language :: Python :: 3.9',
25
+ ]
26
+
27
+ dependencies = [
28
+ "einx>=0.3.0",
29
+ "einops>=0.8.0",
30
+ "tensordict>=0.6.2",
31
+ "torch>=2.3",
32
+ ]
33
+
34
+ [project.urls]
35
+ Homepage = "https://pypi.org/project/titans-pytorch/"
36
+ Repository = "https://github.com/lucidrains/titans-pytorch"
37
+
38
+ [project.optional-dependencies]
39
+ examples = []
40
+ test = [
41
+ "pytest"
42
+ ]
43
+
44
+ [tool.pytest.ini_options]
45
+ pythonpath = [
46
+ "."
47
+ ]
48
+
49
+ [build-system]
50
+ requires = ["hatchling"]
51
+ build-backend = "hatchling.build"
52
+
53
+ [tool.rye]
54
+ managed = true
55
+ dev-dependencies = []
56
+
57
+ [tool.hatch.metadata]
58
+ allow-direct-references = true
59
+
60
+ [tool.hatch.build.targets.wheel]
61
+ packages = ["titans_pytorch"]
@@ -0,0 +1,3 @@
1
+ from titans_pytorch.titans import (
2
+ NeuralMemory
3
+ )
@@ -0,0 +1,90 @@
1
+ from __future__ import annotations
2
+ from typing import Callable
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ import torch.nn.functional as F
7
+
8
+ # taken from S5-pytorch repository
9
+ # https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/jax_compat.py#L51-L134
10
+
11
+ # helper functions
12
+
13
def pad_at_dim(t, pad, dim = -1, value = 0.):
    """Pad tensor `t` along a single dimension.

    `pad` is a `(left, right)` pair applied to dimension `dim` (negative
    indices count from the end); all other dimensions are left untouched.
    Padded positions are filled with `value`.
    """
    # F.pad lists paddings starting from the last dimension, so every
    # dimension to the right of `dim` needs an explicit (0, 0) entry first
    if dim >= 0:
        trailing_dims = t.ndim - dim - 1
    else:
        trailing_dims = -dim - 1

    padding = (*((0, 0) * trailing_dims), *pad)
    return F.pad(t, padding, value = value)
17
+
18
+ # the operator that is needed
19
+
20
@torch.jit.script
def binary_operator(
    a: tuple[Tensor, Tensor],
    b: tuple[Tensor, Tensor]
):
    """Combine two `(gate, value)` pairs for the associative scan.

    For `a = (a_i, kv_i)` and `b = (a_j, kv_j)` the result is
    `(a_j * a_i, kv_j + a_j * kv_i)`.
    """
    gate_left, value_left = a
    gate_right, value_right = b

    combined_gate = gate_right * gate_left
    combined_value = torch.addcmul(value_right, gate_right, value_left)

    return combined_gate, combined_value
28
+
29
+ # Pytorch impl. of jax.lax.associative_scan
30
+ # made specifically for axis of 1 (sequence of tokens for autoregressive modeling)
31
+
32
def associative_scan(
    operator: Callable,
    elems: tuple[Tensor, Tensor]
):
    """Scan `elems` along dim 1 with an associative binary `operator`.

    Pytorch port of `jax.lax.associative_scan`, specialised to axis 1
    (the sequence of tokens).

    Args:
        operator: callable taking two sequences of tensors (left and right
            operands, one tensor per entry of `elems`) and returning the
            combined sequence of tensors.
        elems: tensors that all share the same size along dim 1.

    Returns:
        A list with one scanned tensor per entry of `elems`. When the
        sequence length is smaller than 2 there is nothing to combine and
        `elems` is returned unchanged.

    Raises:
        ValueError: if the tensors disagree on their size along dim 1.
    """
    num_elems = int(elems[0].shape[1])

    if not all(int(elem.shape[1]) == num_elems for elem in elems[1:]):
        # the scan runs over dim 1, so that is the dimension that must agree
        raise ValueError('Array inputs to associative_scan must have the same '
                         'size along dim 1 (the sequence dimension). (saw: {})'
                         .format([elem.shape for elem in elems]))

    def _scan(elems):
        """Perform scan on `elems`."""
        num_elems = elems[0].shape[1]

        if num_elems < 2:
            return elems

        # Combine adjacent pairs of elements.

        reduced_elems = operator(
            [elem[:, :-1:2] for elem in elems],
            [elem[:, 1::2] for elem in elems])

        # Recursively compute scan for partially reduced tensors.

        odd_elems = _scan(reduced_elems)

        # Odd positions are done; even positions (from index 2 on) combine
        # the preceding odd result with the original even element.

        if num_elems % 2 == 0:
            even_elems = operator(
                [e[:, :-1] for e in odd_elems],
                [e[:, 2::2] for e in elems])
        else:
            even_elems = operator(
                odd_elems,
                [e[:, 2::2] for e in elems])

        # The first element of a scan is the same as the first element
        # of the original `elems`.

        even_elems = [
            torch.cat([elem[:, :1], result], dim=1)
            for (elem, result) in zip(elems, even_elems)]

        return list(map(_interleave, even_elems, odd_elems))

    return _scan(elems)
79
+
80
+ def _interleave(a, b):
81
+ a_axis_len, b_axis_len = a.shape[1], b.shape[1]
82
+ output_axis_len = a_axis_len + b_axis_len
83
+
84
+ if (a_axis_len == (b_axis_len + 1)):
85
+ b = pad_at_dim(b, (0, 1), dim = 1)
86
+
87
+ stacked = torch.stack([a, b], dim=2)
88
+ interleaved = torch.flatten(stacked, start_dim=1, end_dim=2)
89
+
90
+ return interleaved[:, :output_axis_len]
@@ -0,0 +1,269 @@
1
from __future__ import annotations

from functools import partial
from typing import Callable

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Linear, Module
from torch.func import functional_call, vmap, grad_and_value

from tensordict import TensorDict

import einx
from einops import rearrange, pack, unpack
from einops.layers.torch import Rearrange

from titans_pytorch.associative_scan import (
    associative_scan,
    binary_operator
)
20
+
21
+ """
22
+ ein notation:
23
+ b - batch
24
+ n - sequence
25
+ d - feature dimension
26
+ c - intra-chunk
27
+ """
28
+
29
+ # constants
30
+
31
+ LinearNoBias = partial(Linear, bias = False)
32
+
33
+ # functions
34
+
35
def exists(v):
    """Return True unless `v` is None (falsy values such as 0 or '' count as existing)."""
    return not (v is None)
37
+
38
def default(v, d):
    """Return `v`, falling back to `d` only when `v` is None."""
    if v is None:
        return d
    return v
40
+
41
def round_down_multiple(seq, mult):
    """Round `seq` down to the nearest multiple of `mult`."""
    whole_multiples = seq // mult
    return whole_multiples * mult
43
+
44
def pack_one_with_inverse(t, pattern):
    """Pack a single tensor with einops `pack` and return it with its undo function.

    Returns `(packed, inverse)`. `inverse(out, inv_pattern = None)` unpacks
    `out` using the shape recorded while packing `t`; it uses `inv_pattern`
    when given and the original `pattern` otherwise.
    """
    packed, recorded_shape = pack([t], pattern)

    def inverse(out, inv_pattern = None):
        unpack_pattern = pattern if inv_pattern is None else inv_pattern
        unpacked = unpack(out, recorded_shape, unpack_pattern)
        # exactly one tensor was packed, so exactly one comes back
        return unpacked[0]

    return packed, inverse
52
+
53
+ # classes
54
+
55
class MLP(Module):
    """Bias-free MLP made of `depth` square `(dim, dim)` weight matrices.

    A SiLU activation is applied between consecutive matrices, i.e. before
    every matrix multiplication except the first.
    """

    def __init__(
        self,
        dim,
        depth
    ):
        super().__init__()

        layer_weights = []

        for _ in range(depth):
            layer_weights.append(nn.Parameter(torch.randn(dim, dim)))

        self.weights = nn.ParameterList(layer_weights)

    def forward(
        self,
        x
    ):
        out = x

        for layer_index, layer_weight in enumerate(self.weights):
            # activation goes in between layers, never before the first one
            if layer_index > 0:
                out = F.silu(out)

            out = torch.matmul(out, layer_weight)

        return out
77
+
78
+ # main neural memory
79
+
80
def default_loss_fn(pred, target):
    """Squared error averaged over the last dimension, then summed over the rest."""
    squared_error = (pred - target) ** 2
    per_item = squared_error.mean(dim = -1)
    return per_item.sum()
82
+
83
class NeuralMemory(Module):
    """Neural memory module whose memories live in the weights of an inner model.

    Keys / values projected from the input sequence are "stored" by taking
    per-token gradients of `store_memory_loss_fn(model(key), value)` with
    respect to the inner model's parameters, and turning them into weight
    updates with learned step size, momentum and decay. Queries projected
    from the same sequence are then "retrieved" by running them through the
    inner model with the updated weights.

    Args:
        dim: feature dimension of the input sequence.
        model: inner memory model; defaults to `MLP(dim, depth = 4)`.
            It must not have any buffers.
        store_memory_loss_fn: loss `(pred, target) -> scalar` used when
            storing memories.
    """

    def __init__(
        self,
        dim,
        model: Module | None = None,
        store_memory_loss_fn: Callable = default_loss_fn
    ):
        super().__init__()

        if not exists(model):
            model = MLP(dim, depth = 4)

        assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'

        # the memory is the weights of the model

        self.memory_model = model

        # prepare function for per sample gradients from model above, using torch.func

        def forward_and_loss(params, inputs, target):
            pred = functional_call(self.memory_model, params, inputs)
            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
            return loss

        # parameters are shared (None), keys and values are mapped over their first dimension

        self.per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))

        # queries for retrieving from the model

        self.to_queries = LinearNoBias(dim, dim)

        # keys and values for storing to the model

        self.to_keys_values = LinearNoBias(dim, dim * 2)
        self.store_memory_loss_fn = store_memory_loss_fn

        # learned adaptive learning rate and momentum
        # todo - explore mlp layerwise learned lr / momentum

        self.to_momentum = LinearNoBias(dim, 1)
        self.to_adaptive_step = nn.Sequential(LinearNoBias(dim, 1), Rearrange('... 1 -> ...'))
        self.to_decay_factor = nn.Sequential(LinearNoBias(dim, 1), nn.Sigmoid()) # weight decay factor

    def init_weights_and_momentum(self):
        """Return zero-filled `(weights, momentum)` TensorDicts shaped like the inner model's parameters."""
        params = TensorDict(dict(self.memory_model.named_parameters()))

        init_weights = params.clone().zero_()
        init_momentum = params.clone().zero_()

        return init_weights, init_momentum

    def store_memories(
        self,
        seq,
        past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
    ):
        """Compute per-token weight updates that store `seq` into the memory model.

        Args:
            seq: input of shape (batch, seq, dim).
            past_state: `(past_weights, past_momentum)` keyed by parameter name.
                Only the weights are read here; the past momentum is not
                consumed yet.

        Returns:
            `(updates, next_state, aux_loss)` where `updates` holds one weight
            update per token and parameter, `next_state` is
            `(next_weights, next_momentum)` and `aux_loss` is the mean of the
            per-token store losses.
        """
        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

        past_state = tuple(TensorDict(d) for d in past_state)
        past_weights, _ = past_state # past momentum is not used for now

        curr_weights = curr_weights + past_weights

        # remember the batch size, to restore it after batch and sequence are packed together

        batch = seq.shape[0]

        adaptive_lr = self.to_adaptive_step(seq)
        adaptive_momentum = self.to_momentum(seq)

        decay_factor = self.to_decay_factor(seq)

        # keys and values

        seq = rearrange(seq, 'b n d -> (b n) d')
        keys, values = self.to_keys_values(seq).chunk(2, dim = -1)

        # get grads and extra auxiliary loss (for backpropagating through the key / value projection in the base neural memory module)

        grads, aux_store_loss = self.per_sample_grad_and_value_fn(dict(curr_weights), keys, values)

        grads = TensorDict(grads)

        # restore batch and sequence dimension

        grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))

        # multiply gradients with learned adaptive step size

        surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))

        # derive momentum with associative scan - eq (10)

        next_momentum = TensorDict()

        for param_name, surprise in surprises.items():
            surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')

            _, momentum = associative_scan(binary_operator, (adaptive_momentum, surprise)) # momentum is S / surprise in the paper

            momentum = inverse_pack(momentum)

            next_momentum[param_name] = momentum

        # use associative scan again for learned forgetting (weight decay) - eq (13)

        updates = TensorDict()

        for param_name, momentum in next_momentum.items():
            momentum, inverse_pack = pack_one_with_inverse(momentum, 'b n *')

            _, update = associative_scan(binary_operator, (1. - decay_factor, momentum)) # update is the momentum accumulated under the learned decay

            update = inverse_pack(update)

            updates[param_name] = update

        # compute the next weight per batch

        last_update = updates.apply(lambda t: t[:, -1])

        # NOTE(review): next weights are built from `curr_weights`, which already
        # contains the model's own parameters, and they gain a batch dimension from
        # `last_update` - feeding this state back in as `past_state` does not look
        # supported yet by the code above; confirm intended usage

        next_state = (curr_weights + last_update, next_momentum)

        return updates, next_state, aux_store_loss.mean()

    def retrieve_memories(
        self,
        seq,
        past_weights: dict[str, Tensor] | None = None,
    ):
        """Retrieve values from the memory model for every token of `seq`.

        Args:
            seq: input of shape (batch, seq, dim).
            past_weights: optional per-token weight offsets keyed by parameter
                name, added to the memory model's parameters. After the
                addition each weight must carry leading (batch, seq)
                dimensions. When omitted, the memory model's own parameters
                are used for every token.

        Returns:
            Retrieved values of shape (batch, seq, dim).
        """
        batch = seq.shape[0]

        # sequence Float['b n d'] to queries

        queries = self.to_queries(seq)

        # without past weights there are no per-token weights to index, the parameters
        # have no (batch, seq) dimensions to fold - a plain forward is all that is needed

        if not exists(past_weights):
            return self.memory_model(queries)

        # the parameters of the memory model stores the memories of the key / values
        # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper

        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

        past_weights = TensorDict(past_weights)
        assert past_weights.keys() == curr_weights.keys()

        curr_weights = curr_weights + past_weights

        # fetch values from memory model
        # every token gets its own set of weights, so fold batch and sequence together

        curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
        queries = rearrange(queries, 'b n d -> (b n) 1 d')

        # forward functional call

        values = functional_call(self.memory_model, dict(curr_weights), queries)

        # reconstitute batch dimension

        values = rearrange(values, '(b n) 1 d -> b n d', b = batch)

        return values

    def forward(
        self,
        seq,
        past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
        return_next_memories = False
    ):
        """Store `seq` into the memory, then retrieve from it with the updated weights.

        Args:
            seq: input of shape (batch, seq, dim).
            past_state: optional `(weights, momentum)` state; zero-initialized
                when omitted.
            return_next_memories: when True, also return the next state and
                the auxiliary store loss.

        Returns:
            The retrieved values, or `(retrieved, next_memories, aux_kv_mse_loss)`
            when `return_next_memories` is True.
        """
        if exists(past_state):
            past_state = tuple(TensorDict(d) for d in past_state)
        else:
            past_state = self.init_weights_and_momentum()

        updates, next_memories, aux_kv_mse_loss = self.store_memories(seq, past_state)

        past_weights, _ = past_state

        retrieved = self.retrieve_memories(seq, past_weights + updates)

        if not return_next_memories:
            return retrieved

        return retrieved, next_memories, aux_kv_mse_loss