titans_pytorch-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- /dev/null
+++ titans_pytorch/__init__.py
@@ -0,0 +1,3 @@
+ from titans_pytorch.titans import (
+     NeuralMemory
+ )
--- /dev/null
+++ titans_pytorch/associative_scan.py
@@ -0,0 +1,90 @@
+ from __future__ import annotations
+ from typing import Callable
+
+ import torch
+ from torch import Tensor
+ import torch.nn.functional as F
+
+ # taken from S5-pytorch repository
+ # https://github.com/i404788/s5-pytorch/blob/74e2fdae00b915a62c914bf3615c0b8a4279eb84/s5/jax_compat.py#L51-L134
+
+ # helper functions
+
+ def pad_at_dim(t, pad, dim = -1, value = 0.):
+     dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+     zeros = ((0, 0) * dims_from_right)
+     return F.pad(t, (*zeros, *pad), value = value)
+
+ # the operator that is needed
+
+ @torch.jit.script
+ def binary_operator(
+     a: tuple[Tensor, Tensor],
+     b: tuple[Tensor, Tensor]
+ ):
+     a_i, kv_i = a
+     a_j, kv_j = b
+     return a_j * a_i, torch.addcmul(kv_j, a_j, kv_i)
+
+ # Pytorch impl. of jax.lax.associative_scan
+ # made specifically for axis of 1 (sequence of tokens for autoregressive modeling)
+
+ def associative_scan(
+     operator: Callable,
+     elems: tuple[Tensor, Tensor]
+ ):
+     num_elems = int(elems[0].shape[1])
+
+     if not all(int(elem.shape[1]) == num_elems for elem in elems[1:]):
+         raise ValueError('Array inputs to associative_scan must have the same '
+                          'sequence length along dimension 1. (saw: {})'
+                          .format([elem.shape for elem in elems]))
+
+     def _scan(elems):
+         """Perform scan on `elems`."""
+         num_elems = elems[0].shape[1]
+
+         if num_elems < 2:
+             return elems
+
+         # Combine adjacent pairs of elements.
+
+         reduced_elems = operator(
+             [elem[:, :-1:2] for elem in elems],
+             [elem[:, 1::2] for elem in elems])
+
+         # Recursively compute scan for partially reduced tensors.
+
+         odd_elems = _scan(reduced_elems)
+
+         if num_elems % 2 == 0:
+             even_elems = operator(
+                 [e[:, :-1] for e in odd_elems],
+                 [e[:, 2::2] for e in elems])
+         else:
+             even_elems = operator(
+                 odd_elems,
+                 [e[:, 2::2] for e in elems])
+
+         # The first element of a scan is the same as the first element
+         # of the original `elems`.
+
+         even_elems = [
+             torch.cat([elem[:, :1], result], dim=1)
+             for (elem, result) in zip(elems, even_elems)]
+
+         return list(map(_interleave, even_elems, odd_elems))
+
+     return _scan(elems)
+
+ def _interleave(a, b):
+     a_axis_len, b_axis_len = a.shape[1], b.shape[1]
+     output_axis_len = a_axis_len + b_axis_len
+
+     if a_axis_len == (b_axis_len + 1):
+         b = pad_at_dim(b, (0, 1), dim = 1)
+
+     stacked = torch.stack([a, b], dim=2)
+     interleaved = torch.flatten(stacked, start_dim=1, end_dim=2)
+
+     return interleaved[:, :output_axis_len]
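The pair carried through the scan encodes the linear recurrence h_t = a_t * h_{t-1} + kv_t, which `binary_operator` composes associatively. Below is a hedged sanity-check sketch (not part of the package) that compares the parallel scan against a naive sequential loop; the `(batch, seq, dim)` sizes are arbitrary.

```python
# illustrative check that the scan matches h_t = a_t * h_{t-1} + kv_t, with h_0 = 0
import torch
from titans_pytorch.associative_scan import associative_scan, binary_operator

b, n, d = 2, 8, 4
decays = torch.rand(b, n, 1)     # a_t, broadcast over the feature dimension
inputs = torch.randn(b, n, d)    # kv_t

# parallel prefix scan over the sequence axis (dim = 1)
_, parallel = associative_scan(binary_operator, (decays, inputs))

# reference: sequential evaluation of the same recurrence
h = torch.zeros(b, d)
sequential = []
for t in range(n):
    h = decays[:, t] * h + inputs[:, t]
    sequential.append(h)
sequential = torch.stack(sequential, dim = 1)

assert torch.allclose(parallel, sequential, atol = 1e-5)
```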
--- /dev/null
+++ titans_pytorch/titans.py
@@ -0,0 +1,270 @@
+ from __future__ import annotations
+ from typing import Callable
+ from functools import partial
+
+ import torch
+ from torch import nn, Tensor
+ import torch.nn.functional as F
+ from torch.nn import Linear, Module
+ from torch.func import functional_call, vmap, grad_and_value
+
+ from tensordict import TensorDict
+
+ from titans_pytorch.associative_scan import (
+     associative_scan,
+     binary_operator
+ )
+
+ import einx
+ from einops import rearrange, pack, unpack
+ from einops.layers.torch import Rearrange
+
+ """
+ ein notation:
+ b - batch
+ n - sequence
+ d - feature dimension
+ c - intra-chunk
+ """
+
+ # constants
+
+ LinearNoBias = partial(Linear, bias = False)
+
+ # functions
+
+ def exists(v):
+     return v is not None
+
+ def default(v, d):
+     return v if exists(v) else d
+
+ def round_down_multiple(seq, mult):
+     return seq // mult * mult
+
+ def pack_one_with_inverse(t, pattern):
+     packed, packed_shape = pack([t], pattern)
+
+     def inverse(out, inv_pattern = None):
+         inv_pattern = default(inv_pattern, pattern)
+         return unpack(out, packed_shape, inv_pattern)[0]
+
+     return packed, inverse
+
+ # classes
+
+ class MLP(Module):
+     def __init__(
+         self,
+         dim,
+         depth
+     ):
+         super().__init__()
+         self.weights = nn.ParameterList([nn.Parameter(torch.randn(dim, dim)) for _ in range(depth)])
+
+     def forward(
+         self,
+         x
+     ):
+         for ind, weight in enumerate(self.weights):
+             is_first = ind == 0
+
+             if not is_first:
+                 x = F.silu(x)
+
+             x = x @ weight
+
+         return x
+
+ # main neural memory
+
+ def default_loss_fn(pred, target):
+     return (pred - target).pow(2).mean(dim = -1).sum()
+
+ class NeuralMemory(Module):
+     def __init__(
+         self,
+         dim,
+         model: Module | None = None,
+         store_memory_loss_fn: Callable = default_loss_fn
+     ):
+         super().__init__()
+
+         if not exists(model):
+             model = MLP(dim, depth = 4)
+
+         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
+
+         # the memory is the weights of the model
+
+         self.memory_model = model
+
+         # prepare function for per sample gradients from model above, using torch.func
+
+         def forward_and_loss(params, inputs, target):
+             pred = functional_call(self.memory_model, params, inputs)
+             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
+             return loss
+
+         self.per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))
+
+         # queries for retrieving from the model
+
+         self.to_queries = LinearNoBias(dim, dim)
+
+         # keys and values for storing to the model
+
+         self.to_keys_values = LinearNoBias(dim, dim * 2)
+         self.store_memory_loss_fn = store_memory_loss_fn
+
+         # learned adaptive learning rate and momentum
+         # todo - explore mlp layerwise learned lr / momentum
+
+         self.to_momentum = LinearNoBias(dim, 1)
+         self.to_adaptive_step = nn.Sequential(LinearNoBias(dim, 1), Rearrange('... 1 -> ...'))
+         self.to_decay_factor = nn.Sequential(LinearNoBias(dim, 1), nn.Sigmoid()) # weight decay factor
+
+     def init_weights_and_momentum(self):
+         params = TensorDict(dict(self.memory_model.named_parameters()))
+
+         init_weights = params.clone().zero_()
+         init_momentum = params.clone().zero_()
+
+         return init_weights, init_momentum
+
+     def store_memories(
+         self,
+         seq,
+         past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
+     ):
+
+         curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
+
+         past_state = tuple(TensorDict(d) for d in past_state)
+         past_weights, past_momentum = past_state
+
+         curr_weights = curr_weights + past_weights
+
+         # pack batch and sequence dimension
+
+         batch = seq.shape[0]
+
+         adaptive_lr = self.to_adaptive_step(seq)
+         adaptive_momentum = self.to_momentum(seq)
+
+         decay_factor = self.to_decay_factor(seq)
+
+         # keys and values
+
+         seq = rearrange(seq, 'b n d -> (b n) d')
+         keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
+
+         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
+
+         grads, aux_store_loss = self.per_sample_grad_and_value_fn(dict(curr_weights), keys, values)
+
+         grads = TensorDict(grads)
+
+         # restore batch and sequence dimension
+
+         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
+
+         # multiply gradients with learned adaptive step size
+
+         surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))
+
+         # derive momentum with associative scan - eq (10)
+
+         next_momentum = TensorDict()
+
+         for param_name, surprise in surprises.items():
+             surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
+
+             _, momentum = associative_scan(binary_operator, (adaptive_momentum, surprise)) # momentum is S / surprise in the paper
+
+             momentum = inverse_pack(momentum)
+
+             next_momentum[param_name] = momentum
+
+         # use associative scan again for learned forgetting (weight decay) - eq (13)
+
+         updates = TensorDict()
+
+         for param_name, momentum in next_momentum.items():
+             momentum, inverse_pack = pack_one_with_inverse(momentum, 'b n *')
+
+             _, update = associative_scan(binary_operator, (1. - decay_factor, momentum)) # (1 - alpha) is the learned forgetting / weight decay gate in the paper
+
+             update = inverse_pack(update)
+
+             updates[param_name] = update
+
+         # compute the next weight per batch
+
+         last_update = updates.apply(lambda t: t[:, -1])
+
+         next_state = (curr_weights + last_update, next_momentum)
+
+         return updates, next_state, aux_store_loss.mean()
+
+     def retrieve_memories(
+         self,
+         seq,
+         past_weights: dict[str, Tensor] | None = None,
+     ):
+         batch = seq.shape[0]
+
+         # the parameters of the memory model store the memories of the key / values
+         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
+
+         curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
+
+         if exists(past_weights):
+             past_weights = TensorDict(past_weights)
+             assert past_weights.keys() == curr_weights.keys()
+
+             curr_weights = curr_weights + past_weights
+
+         # sequence Float['b n d'] to queries
+
+         queries = self.to_queries(seq)
+
+         # fetch values from memory model
+
+         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+         queries = rearrange(queries, 'b n d -> (b n) 1 d')
+
+         # forward functional call
+
+         values = functional_call(self.memory_model, dict(curr_weights), queries)
+
+         # reconstitute batch dimension
+
+         values = rearrange(values, '(b n) 1 d -> b n d', b = batch)
+
+         return values
+
+     def forward(
+         self,
+         seq,
+         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
+         return_next_memories = False
+     ):
+         batch = seq.shape[0]
+
+         if exists(past_state):
+             past_state = tuple(TensorDict(d) for d in past_state)
+
+         if not exists(past_state):
+             past_state = self.init_weights_and_momentum()
+
+         updates, next_memories, aux_kv_mse_loss = self.store_memories(seq, past_state)
+
+         past_weights, _ = past_state
+
+         retrieved = self.retrieve_memories(seq, past_weights + updates)
+
+         if not return_next_memories:
+             return retrieved
+
+         return retrieved, next_memories, aux_kv_mse_loss
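The `vmap(grad_and_value(...), in_dims = (None, 0, 0))` construction in `store_memories` above is the core trick: `functional_call` runs the memory model with an explicit parameter dict, `grad_and_value` differentiates the store loss with respect to those parameters, and `vmap` maps that over every token's key/value pair, yielding one "surprise" gradient per token. Below is a standalone sketch of the same `torch.func` pattern on a toy linear module; the names and sizes are illustrative and not taken from the package.

```python
# illustrative sketch of the per-sample gradient pattern used in store_memories
import torch
from torch import nn
from torch.func import functional_call, grad_and_value, vmap

toy = nn.Linear(4, 4, bias = False)
params = {name: p.detach() for name, p in toy.named_parameters()}

def loss_fn(params, key, value):
    pred = functional_call(toy, params, key)
    return (pred - value).pow(2).mean(dim = -1).sum()

per_sample_grad_and_value = vmap(grad_and_value(loss_fn), in_dims = (None, 0, 0))

keys = torch.randn(8, 4)     # one (key, value) pair per token
values = torch.randn(8, 4)

grads, losses = per_sample_grad_and_value(params, keys, values)

# one gradient (and one loss) per token, each gradient shaped like the parameter
assert grads['weight'].shape == (8, 4, 4)
assert losses.shape == (8,)
```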
--- /dev/null
+++ titans_pytorch-0.0.1.dist-info/METADATA
@@ -0,0 +1,84 @@
+ Metadata-Version: 2.4
+ Name: titans-pytorch
+ Version: 0.0.1
+ Summary: Titans
+ Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
+ Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
+ Author-email: Phil Wang <lucidrains@gmail.com>
+ License: MIT License
+
+ Copyright (c) 2025 Phil Wang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ License-File: LICENSE
+ Keywords: artificial intelligence,deep learning,linear attention,neural memory module,test time training
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Developers
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.9
+ Requires-Dist: einops>=0.8.0
+ Requires-Dist: einx>=0.3.0
+ Requires-Dist: tensordict>=0.6.2
+ Requires-Dist: torch>=2.3
+ Provides-Extra: examples
+ Provides-Extra: test
+ Requires-Dist: pytest; extra == 'test'
+ Description-Content-Type: text/markdown
+
+ <img src="./fig2.png" width="400px"></img>
+
+ <img src="./fig1.png" width="400px"></img>
+
+ ## Titans - Pytorch (wip)
+
+ Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module.
+
+ ## Install
+
+ ```bash
+ $ pip install titans-pytorch
+ ```
+
+ ## Usage
+
+ ```python
+ import torch
+ from titans_pytorch import NeuralMemory
+
+ x = torch.randn(2, 64, 32)
+
+ mem = NeuralMemory(32)
+
+ out = mem(x)
+
+ assert x.shape == out.shape
+ ```
+
+ ## Citations
+
+ ```bibtex
+ @inproceedings{Behrouz2024TitansLT,
+     title = {Titans: Learning to Memorize at Test Time},
+     author = {Ali Behrouz and Peilin Zhong and Vahab S. Mirrokni},
+     year = {2024},
+     url = {https://api.semanticscholar.org/CorpusID:275212078}
+ }
+ ```
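Beyond the README usage above, `NeuralMemory.__init__` in `titans.py` also accepts a custom memory model through its `model` keyword (any buffer-free `nn.Module`; the default is a depth-4 `MLP`). A hedged sketch of that option follows, importing the bundled `MLP` from the submodule since only `NeuralMemory` is re-exported:

```python
# sketch: using a shallower bundled MLP as the neural memory model
import torch
from titans_pytorch import NeuralMemory
from titans_pytorch.titans import MLP   # not re-exported by __init__.py

mem = NeuralMemory(
    dim = 32,
    model = MLP(dim = 32, depth = 2)    # default is depth = 4
)

x = torch.randn(2, 64, 32)
out = mem(x)

assert out.shape == x.shape
```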
--- /dev/null
+++ titans_pytorch-0.0.1.dist-info/RECORD
@@ -0,0 +1,7 @@
+ titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+ titans_pytorch/titans.py,sha256=xty74Q3xQ174uycscfOnh-zgxGMH882lrIA_KGvxTUU,7802
+ titans_pytorch-0.0.1.dist-info/METADATA,sha256=HqR3VxpV5e-dPLLEbuOekC161-2r2WwKBCvK7E2MhAs,2968
+ titans_pytorch-0.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ titans_pytorch-0.0.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ titans_pytorch-0.0.1.dist-info/RECORD,,
--- /dev/null
+++ titans_pytorch-0.0.1.dist-info/WHEEL
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: hatchling 1.27.0
+ Root-Is-Purelib: true
+ Tag: py3-none-any
--- /dev/null
+++ titans_pytorch-0.0.1.dist-info/licenses/LICENSE
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Phil Wang
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.