titans-pytorch 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/titans.py +72 -25
- {titans_pytorch-0.0.3.dist-info → titans_pytorch-0.0.5.dist-info}/METADATA +6 -3
- titans_pytorch-0.0.5.dist-info/RECORD +7 -0
- titans_pytorch-0.0.3.dist-info/RECORD +0 -7
- {titans_pytorch-0.0.3.dist-info → titans_pytorch-0.0.5.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.3.dist-info → titans_pytorch-0.0.5.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
```diff
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import math
 from functools import partial

 import torch
@@ -11,12 +12,13 @@ from tensordict import TensorDict

 from titans_pytorch.associative_scan import (
     associative_scan,
-    binary_operator
+    binary_operator,
+    pad_at_dim
 )

 import einx
 from einops import rearrange, pack, unpack
-from einops.layers.torch import Rearrange
+from einops.layers.torch import Rearrange, Reduce

 """
 ein notation:
@@ -41,6 +43,9 @@ def default(v, d):
 def round_down_multiple(seq, mult):
     return seq // mult * mult

+def round_up_multiple(seq, mult):
+    return math.ceil(seq / mult) * mult
+
 def pack_one_with_inverse(t, pattern):
     packed, packed_shape = pack([t], pattern)

@@ -84,6 +89,7 @@ class NeuralMemory(Module):
     def __init__(
         self,
         dim,
+        chunk_size = 1,
         model: Module | None = None,
         store_memory_loss_fn: Callable = default_loss_fn
     ):
@@ -98,11 +104,15 @@ class NeuralMemory(Module):

         self.memory_model = model

+        # the chunk size within the paper where adaptive step, momentum, weight decay are shared
+
+        self.chunk_size = chunk_size
+
         # prepare function for per sample gradients from model above, using torch.func

         def forward_and_loss(params, inputs, target):
             pred = functional_call(self.memory_model, params, inputs)
-            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k)
+            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
             return loss

         self.per_sample_grad_and_value_fn = vmap(grad_and_value(forward_and_loss), in_dims = (None, 0, 0))
@@ -119,9 +129,23 @@ class NeuralMemory(Module):
         # learned adaptive learning rate and momentum
         # todo - explore mlp layerwise learned lr / momentum

-        self.to_momentum =
-
-
+        self.to_momentum = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, 1)
+        )
+
+        self.to_adaptive_step = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, 1),
+            Rearrange('... 1 -> ...')
+        )
+
+        # weight decay factor
+
+        self.to_decay_factor = nn.Sequential(
+            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            LinearNoBias(dim, 1)
+        )

     def init_weights_and_momentum(self):
         params = TensorDict(dict(self.memory_model.named_parameters()))
@@ -137,6 +161,16 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
     ):

+        # curtail sequence by multiple of the chunk size
+        # only a complete chunk of the sequence provides the memory for the next chunk
+
+        seq_len = seq.shape[-2]
+        round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
+
+        seq = seq[:, :round_down_seq_len]
+
+        # curr weights + past weights, in the case that the initial weights are learned
+
         curr_weights = TensorDict(dict(self.memory_model.named_parameters()))

         past_state = tuple(TensorDict(d) for d in past_state)
@@ -148,16 +182,19 @@ class NeuralMemory(Module):

         batch = seq.shape[0]

-        adaptive_lr = self.to_adaptive_step(seq).tanh() * 0.5 + 0.5
+        adaptive_lr = self.to_adaptive_step(seq).tanh() * 0.5 + 0.5

         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()

         # keys and values

-        seq = rearrange(seq, 'b n d -> (b n) d')
         keys, values = self.to_keys_values(seq).chunk(2, dim = -1)

+        # take care of chunking
+
+        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
+
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

         grads, aux_store_loss = self.per_sample_grad_and_value_fn(dict(curr_weights), keys, values)
@@ -172,31 +209,24 @@ class NeuralMemory(Module):

         surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))

-        #
+        # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

         next_momentum = TensorDict()
+        updates = TensorDict()

         for param_name, surprise in surprises.items():
             surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')

-
-
-            momentum = inverse_pack(momentum)
-
-            next_momentum[param_name] = momentum
+            # derive momentum with associative scan - eq (10)

-
-
-            updates = TensorDict()
+            _, momentum = associative_scan(binary_operator, (adaptive_momentum, surprise)) # momentum is S / surprise in the paper

-
-            momentum, inverse_pack = pack_one_with_inverse(momentum, 'b n *')
+            # use associative scan again for learned forgetting (weight decay) - eq (13)

             _, update = associative_scan(binary_operator, (1. - decay_factor, momentum)) # momentum is S / surprise in the paper

-
-
-            updates[param_name] = update
+            updates[param_name] = inverse_pack(update)
+            next_momentum[param_name] = inverse_pack(momentum)

         # compute the next weight per batch

@@ -211,7 +241,19 @@ class NeuralMemory(Module):
         seq,
         past_weights: dict[str, Tensor] | None = None,
     ):
-
+        chunk_size = self.chunk_size
+        batch, seq_len = seq.shape[:2]
+
+        assert seq_len >= chunk_size
+
+        seq = seq[:, (chunk_size - 1):]
+        curtailed_seq_len = seq.shape[-2]
+
+        next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)
+
+        padding = next_seq_len - curtailed_seq_len
+
+        seq = pad_at_dim(seq, (0, padding), dim = 1)

         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
@@ -231,7 +273,7 @@ class NeuralMemory(Module):
         # fetch values from memory model

         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
-        queries = rearrange(queries, 'b n d -> (b n) d')
+        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)

         # forward functional call

@@ -239,7 +281,12 @@ class NeuralMemory(Module):

         # reconstitute batch dimension

-        values = rearrange(values, '(b n) d -> b n d', b = batch)
+        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
+
+        # restore
+
+        values = pad_at_dim(values, (chunk_size - 1, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
+        values = values[:, :-padding]

         return values

```
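The substantive change in 0.0.5 is chunked processing: the adaptive step, momentum, and weight-decay signals are now shared within chunks of `chunk_size` tokens (mean-pooled per chunk by the new `Reduce` layers), the store path keeps only complete chunks, and the retrieve path shifts by `chunk_size - 1` and pads the sequence up to a chunk multiple so each chunk reads memory written by earlier chunks. The snippet below is a minimal sketch of that shape bookkeeping with toy sizes; it mirrors the helpers and einops patterns in the diff but is illustrative, not code taken from the package.

```python
import math

import torch
import torch.nn.functional as F
from einops import rearrange, reduce

def round_down_multiple(n, mult):
    return n // mult * mult

def round_up_multiple(n, mult):
    return math.ceil(n / mult) * mult

batch, seq_len, dim, chunk_size = 2, 65, 32, 4
seq = torch.randn(batch, seq_len, dim)

# store path: only complete chunks are written into the neural memory
stored_len = round_down_multiple(seq_len, chunk_size)            # 64
stored = seq[:, :stored_len]

# what Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size) computes:
# one adaptive step / momentum / decay value per chunk instead of per token
pooled = reduce(stored, 'b (n c) d -> b n d', 'mean', c = chunk_size)
print(pooled.shape)   # torch.Size([2, 16, 32])

# keys / values are regrouped so every chunk becomes its own "sample"
# for the per-sample gradient call
chunked = rearrange(stored, 'b (n c) d -> (b n) c d', c = chunk_size)
print(chunked.shape)  # torch.Size([32, 4, 32])

# retrieve path: shift by chunk_size - 1 so a chunk queries memory formed by
# preceding tokens, then right-pad the sequence up to a chunk multiple
queries = seq[:, (chunk_size - 1):]
padding = round_up_multiple(queries.shape[1], chunk_size) - queries.shape[1]
queries = F.pad(queries, (0, 0, 0, padding))  # pad along the sequence dimension
print(queries.shape)  # torch.Size([2, 64, 32])
```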
{titans_pytorch-0.0.3.dist-info → titans_pytorch-0.0.5.dist-info}/METADATA
CHANGED

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.3
+Version: 0.0.5
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -49,7 +49,7 @@ Description-Content-Type: text/markdown

 ## Titans - Pytorch (wip)

-Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module.
+Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.

 ## Install

@@ -65,7 +65,10 @@ from titans_pytorch import NeuralMemory

 x = torch.randn(2, 64, 32)

-mem = NeuralMemory(
+mem = NeuralMemory(
+    dim = 32,
+    chunk_size = 2
+)

 out = mem(x)

```
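For reference, a slightly expanded version of the README usage above, using the shapes from the example in the diff. The shape comment reflects what the retrieval path in `titans.py` produces for these particular sizes (pad, chunk, then trim back to the input length); it is an observation from the diff, not a documented guarantee of the API.

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 32,
    chunk_size = 2   # new in 0.0.5: tokens within a chunk share step / momentum / decay
)

x = torch.randn(2, 64, 32)
out = mem(x)

print(out.shape)  # torch.Size([2, 64, 32]) for these sizes
```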
titans_pytorch-0.0.5.dist-info/RECORD
ADDED

```diff
@@ -0,0 +1,7 @@
+titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/titans.py,sha256=3Mewuysj0g7iAlfjdqMlJhn9-pKJuOerB1frQmQYXuc,9428
+titans_pytorch-0.0.5.dist-info/METADATA,sha256=f1DgCKZz9nqNfZOrqbOpyn-yEx2v5M5zgGIW0Zeu84I,3032
+titans_pytorch-0.0.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.5.dist-info/RECORD,,
```
titans_pytorch-0.0.3.dist-info/RECORD
REMOVED

```diff
@@ -1,7 +0,0 @@
-titans_pytorch/__init__.py,sha256=QKuJPCOJCdgtaPeKoHEkYkiQe65_LV9_8-cIMbBPU30,55
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/titans.py,sha256=0Mh9LJv5hLVbB2MvRJX5QanAeTtU9LAuj6YOQUwsyUQ,7813
-titans_pytorch-0.0.3.dist-info/METADATA,sha256=AXfDl_MTIu24VRagi_rgiH8rHXFBU5euwSD6DMwLgsg,2968
-titans_pytorch-0.0.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.3.dist-info/RECORD,,
```
{titans_pytorch-0.0.3.dist-info → titans_pytorch-0.0.5.dist-info}/WHEEL
File without changes

{titans_pytorch-0.0.3.dist-info → titans_pytorch-0.0.5.dist-info}/licenses/LICENSE
File without changes