titans-pytorch 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- titans_pytorch/titans.py +23 -10
- {titans_pytorch-0.0.21.dist-info → titans_pytorch-0.0.23.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.23.dist-info/RECORD +8 -0
- titans_pytorch-0.0.21.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.21.dist-info → titans_pytorch-0.0.23.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.21.dist-info → titans_pytorch-0.0.23.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
@@ -100,11 +100,11 @@ class MemoryMLP(Module):
|
|
100
100
|
|
101
101
|
# main neural memory
|
102
102
|
|
103
|
-
def default_adaptive_step_transform(adaptive_step):
|
104
|
-
return
|
103
|
+
def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
|
104
|
+
return adaptive_step.sigmoid() * max_lr
|
105
105
|
|
106
106
|
def default_loss_fn(pred, target):
|
107
|
-
return (pred - target).pow(2).mean(dim = -1)
|
107
|
+
return (pred - target).pow(2).mean(dim = -1)
|
108
108
|
|
109
109
|
class NeuralMemory(Module):
|
110
110
|
def __init__(
|
@@ -142,6 +142,12 @@ class NeuralMemory(Module):
|
|
142
142
|
self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
|
143
143
|
self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
|
144
144
|
|
145
|
+
self.retrieve_gate = nn.Sequential(
|
146
|
+
LinearNoBias(dim, heads),
|
147
|
+
Rearrange('b n h -> (b h) n 1'),
|
148
|
+
nn.Sigmoid()
|
149
|
+
) if heads > 1 else None
|
150
|
+
|
145
151
|
# memory mlp
|
146
152
|
|
147
153
|
if not exists(model):
|
@@ -159,12 +165,13 @@ class NeuralMemory(Module):
|
|
159
165
|
|
160
166
|
# prepare function for per sample gradients from model above, using torch.func
|
161
167
|
|
162
|
-
def forward_and_loss(params, inputs, target):
|
168
|
+
def forward_and_loss(params, inputs, loss_weights, target):
|
163
169
|
pred = functional_call(self.memory_model, params, inputs)
|
164
170
|
loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
|
165
|
-
|
171
|
+
loss = loss * loss_weights
|
172
|
+
return loss.sum()
|
166
173
|
|
167
|
-
self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
|
174
|
+
self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))
|
168
175
|
|
169
176
|
# queries for retrieving from the model
|
170
177
|
|
@@ -190,7 +197,6 @@ class NeuralMemory(Module):
|
|
190
197
|
)
|
191
198
|
|
192
199
|
self.to_adaptive_step = nn.Sequential(
|
193
|
-
Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
|
194
200
|
LinearNoBias(dim, heads),
|
195
201
|
Rearrange('b n h -> (b h) n')
|
196
202
|
)
|
@@ -271,9 +277,11 @@ class NeuralMemory(Module):
|
|
271
277
|
|
272
278
|
keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
|
273
279
|
|
280
|
+
adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = self.chunk_size)
|
281
|
+
|
274
282
|
# get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
|
275
283
|
|
276
|
-
grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
|
284
|
+
grads = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
|
277
285
|
|
278
286
|
grads = TensorDict(grads)
|
279
287
|
|
@@ -286,9 +294,9 @@ class NeuralMemory(Module):
|
|
286
294
|
|
287
295
|
grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
|
288
296
|
|
289
|
-
#
|
297
|
+
# negative gradients, adaptive lr already applied as loss weight
|
290
298
|
|
291
|
-
surprises = grads.apply(lambda t:
|
299
|
+
surprises = grads.apply(lambda t: -t)
|
292
300
|
|
293
301
|
# determine scan function
|
294
302
|
|
@@ -405,6 +413,11 @@ class NeuralMemory(Module):
|
|
405
413
|
|
406
414
|
values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
|
407
415
|
|
416
|
+
# maybe gate
|
417
|
+
|
418
|
+
if exists(self.retrieve_gate):
|
419
|
+
values = values * self.retrieve_gate(seq)
|
420
|
+
|
408
421
|
# maybe merge heads and combine
|
409
422
|
|
410
423
|
values = self.merge_heads(values)
|
@@ -0,0 +1,8 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/titans.py,sha256=AU2mf3RkClSRIP0IUwnnqsA5O1udNYGbTRb0lVBLA78,14024
|
4
|
+
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
|
5
|
+
titans_pytorch-0.0.23.dist-info/METADATA,sha256=U80_8U_mwaQqwWKPWlu76-O3-CTCfrZ7t_HdE3Zl_qE,3811
|
6
|
+
titans_pytorch-0.0.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
+
titans_pytorch-0.0.23.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
+
titans_pytorch-0.0.23.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/titans.py,sha256=rsgpGs-nnqvsvF6Mu2GfmfNlRgYfHAV9MCEmd1ohUUI,13687
|
4
|
-
titans_pytorch/titans_attn_memory.py,sha256=Rwx_-riGeISBefZg5Kjic8jzmmRRys-u93D2Kgb7Mos,12691
|
5
|
-
titans_pytorch-0.0.21.dist-info/METADATA,sha256=3z2LGviHUicQq4_ThlRasOUCDXIe0xX2aIhxtKLWI0Q,3811
|
6
|
-
titans_pytorch-0.0.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
-
titans_pytorch-0.0.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
-
titans_pytorch-0.0.21.dist-info/RECORD,,
|
File without changes
|
File without changes
|