titans-pytorch 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/neural_memory.py +57 -18
- {titans_pytorch-0.3.9.dist-info → titans_pytorch-0.3.10.dist-info}/METADATA +1 -1
- titans_pytorch-0.3.10.dist-info/RECORD +9 -0
- titans_pytorch-0.3.9.dist-info/RECORD +0 -9
- {titans_pytorch-0.3.9.dist-info → titans_pytorch-0.3.10.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.3.9.dist-info → titans_pytorch-0.3.10.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/neural_memory.py
CHANGED
@@ -3,6 +3,7 @@ from typing import Callable
|
|
3
3
|
|
4
4
|
import math
|
5
5
|
from functools import partial
|
6
|
+
from itertools import zip_longest
|
6
7
|
from collections import namedtuple
|
7
8
|
|
8
9
|
import torch
|
@@ -21,16 +22,19 @@ from titans_pytorch.memory_models import(
|
|
21
22
|
)
|
22
23
|
|
23
24
|
import einx
|
24
|
-
from einops import rearrange, repeat, reduce, pack, unpack
|
25
|
+
from einops import einsum, rearrange, repeat, reduce, pack, unpack
|
25
26
|
from einops.layers.torch import Rearrange, Reduce
|
26
27
|
|
27
28
|
"""
|
28
29
|
ein notation:
|
29
30
|
b - batch
|
31
|
+
h - heads
|
32
|
+
bh - batch and heads
|
30
33
|
n - sequence
|
31
34
|
d - feature dimension
|
32
35
|
c - intra-chunk
|
33
36
|
w - num memory network weight parameters
|
37
|
+
o - momentum orders
|
34
38
|
"""
|
35
39
|
|
36
40
|
LinearNoBias = partial(Linear, bias = False)
|
@@ -224,6 +228,8 @@ class NeuralMemory(Module):
|
|
224
228
|
per_head_learned_parameters = True,
|
225
229
|
attn_pool_chunks = False,
|
226
230
|
momentum = True,
|
231
|
+
momentum_order = 1,
|
232
|
+
learned_momentum_combine = False,
|
227
233
|
pre_rmsnorm = True,
|
228
234
|
post_rmsnorm = False,
|
229
235
|
qk_rmsnorm = False,
|
@@ -367,12 +373,7 @@ class NeuralMemory(Module):
|
|
367
373
|
else:
|
368
374
|
self.reduce_to_chunk_rep = AttentionPool(dim, chunk_size = chunk_size)
|
369
375
|
|
370
|
-
# learned adaptive learning rate
|
371
|
-
|
372
|
-
self.to_momentum = Sequential(
|
373
|
-
nn.Linear(dim, heads),
|
374
|
-
Rearrange('b n h -> (b h) n 1')
|
375
|
-
) if momentum else None
|
376
|
+
# learned adaptive learning rate
|
376
377
|
|
377
378
|
self.to_adaptive_step = Sequential(
|
378
379
|
nn.Linear(dim, heads),
|
@@ -384,6 +385,26 @@ class NeuralMemory(Module):
|
|
384
385
|
|
385
386
|
self.adaptive_step_transform = adaptive_step_transform
|
386
387
|
|
388
|
+
# momentum related
|
389
|
+
|
390
|
+
self.to_momentum = Sequential(
|
391
|
+
nn.Linear(dim, heads * momentum_order),
|
392
|
+
Rearrange('b n (h o) -> o (b h) n 1', o = momentum_order)
|
393
|
+
) if momentum else None
|
394
|
+
|
395
|
+
self.momentum_order = momentum_order
|
396
|
+
self.to_learned_momentum_combine = None
|
397
|
+
|
398
|
+
if learned_momentum_combine:
|
399
|
+
assert momentum
|
400
|
+
assert momentum_order > 1, 'only second order momentum allowed for now, but may allow learned combination of zeroth'
|
401
|
+
|
402
|
+
self.to_learned_momentum_combine = Sequential(
|
403
|
+
nn.Linear(dim, heads * momentum_order),
|
404
|
+
nn.Softmax(dim = -1),
|
405
|
+
Rearrange('b n (h o) -> o (b h) n', h = heads)
|
406
|
+
)
|
407
|
+
|
387
408
|
# per layer learning rate modulation
|
388
409
|
|
389
410
|
self.to_layer_modulation = Sequential(
|
@@ -463,9 +484,9 @@ class NeuralMemory(Module):
|
|
463
484
|
zeros = self.memory_model_parameter_dict.clone().zero_()
|
464
485
|
|
465
486
|
if self.per_head_learned_parameters:
|
466
|
-
zeros = repeat_dict_values(zeros, 'h ... -> (b h) ...', b = batch)
|
487
|
+
zeros = repeat_dict_values(zeros, 'h ... -> o (b h) ...', b = batch, o = self.momentum_order)
|
467
488
|
else:
|
468
|
-
zeros = repeat_dict_values(zeros, '... -> bh ...', bh = batch * self.heads)
|
489
|
+
zeros = repeat_dict_values(zeros, '... -> o bh ...', bh = batch * self.heads, o = self.momentum_order)
|
469
490
|
|
470
491
|
return zeros
|
471
492
|
|
@@ -518,6 +539,11 @@ class NeuralMemory(Module):
|
|
518
539
|
if has_momentum:
|
519
540
|
adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()
|
520
541
|
|
542
|
+
learned_combine = exists(self.to_learned_momentum_combine)
|
543
|
+
|
544
|
+
if learned_combine:
|
545
|
+
combine_momentums = self.to_learned_momentum_combine(chunked_seq)
|
546
|
+
|
521
547
|
if need_layer_lr_mod:
|
522
548
|
layer_lr_mod = self.to_layer_modulation(chunked_seq) * self.max_mem_layer_modulation
|
523
549
|
|
@@ -585,7 +611,7 @@ class NeuralMemory(Module):
|
|
585
611
|
|
586
612
|
# negative gradients, adaptive lr already applied as loss weight
|
587
613
|
|
588
|
-
surprises = grads.
|
614
|
+
surprises = grads.mul(-1)
|
589
615
|
|
590
616
|
# past states
|
591
617
|
|
@@ -611,7 +637,6 @@ class NeuralMemory(Module):
|
|
611
637
|
|
612
638
|
# momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
|
613
639
|
|
614
|
-
next_momentum = TensorDict() if has_momentum else None
|
615
640
|
updates = TensorDict()
|
616
641
|
|
617
642
|
next_last_update = TensorDict()
|
@@ -624,20 +649,34 @@ class NeuralMemory(Module):
|
|
624
649
|
# derive momentum with associative scan - eq (10)
|
625
650
|
|
626
651
|
if has_momentum:
|
652
|
+
momentum = surprise
|
653
|
+
|
654
|
+
momentums = [] # stores all momentum orders starting with first, to generalize to Nth order momentum
|
655
|
+
|
627
656
|
last_momentum = past_last_momentum[param_name]
|
628
|
-
|
629
|
-
momentum
|
630
|
-
|
657
|
+
|
658
|
+
# go from first order momentum all the way to the Nth
|
659
|
+
|
660
|
+
for one_adaptive_momentum, one_last_momentum in zip_longest(adaptive_momentum, last_momentum):
|
661
|
+
momentum = self.assoc_scan(one_adaptive_momentum, momentum, prev = one_last_momentum) # momentum is S / surprise in the paper
|
662
|
+
|
663
|
+
momentums.append(momentum)
|
664
|
+
|
665
|
+
momentums = torch.stack(momentums)
|
666
|
+
|
667
|
+
next_last_momentum[param_name] = momentums[:, :, -1] # momentums shape is Float['o bh n 1']
|
668
|
+
|
669
|
+
if not learned_combine:
|
670
|
+
update = momentums[-1]
|
671
|
+
else:
|
672
|
+
update = einsum(combine_momentums, momentums, 'o b n, o b n ... -> b n ...')
|
631
673
|
|
632
674
|
# use associative scan again for learned forgetting (weight decay) - eq (13)
|
633
675
|
|
634
676
|
update = self.assoc_scan(1. - decay_factor, update, prev = last_update, remove_prev = False)
|
635
|
-
next_last_update[param_name] = update[:, -1]
|
636
677
|
|
637
678
|
updates[param_name] = update
|
638
|
-
|
639
|
-
if has_momentum:
|
640
|
-
next_momentum[param_name] = momentum
|
679
|
+
next_last_update[param_name] = update[:, -1]
|
641
680
|
|
642
681
|
# determine next state for the storing of memories
|
643
682
|
|
@@ -0,0 +1,9 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
|
2
|
+
titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
|
4
|
+
titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
|
5
|
+
titans_pytorch/neural_memory.py,sha256=BeOnq41gjZeq-XJFjkHE44F9dLzsg9mm36EBYZ4wHMA,28814
|
6
|
+
titans_pytorch-0.3.10.dist-info/METADATA,sha256=sA_Dx_x5RMcpz5-vUPDHuz__tHYfKzs4W_BgY4CHPdk,6816
|
7
|
+
titans_pytorch-0.3.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.3.10.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.3.10.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
|
2
|
-
titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
|
4
|
-
titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
|
5
|
-
titans_pytorch/neural_memory.py,sha256=2Ffq5fob6_vJVUs1jIai3BJhfsfypKffJEb0QwRRdMk,27325
|
6
|
-
titans_pytorch-0.3.9.dist-info/METADATA,sha256=EUuuqHl8jPUPIG8m7xV5esN_2yDNaPdD-H8qFeDxWGo,6815
|
7
|
-
titans_pytorch-0.3.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
-
titans_pytorch-0.3.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
-
titans_pytorch-0.3.9.dist-info/RECORD,,
|
File without changes
|
File without changes
|