titans-pytorch 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +5 -4
- titans_pytorch/neural_memory.py +63 -26
- {titans_pytorch-0.3.9.dist-info → titans_pytorch-0.3.11.dist-info}/METADATA +2 -2
- titans_pytorch-0.3.11.dist-info/RECORD +9 -0
- titans_pytorch-0.3.9.dist-info/RECORD +0 -9
- {titans_pytorch-0.3.9.dist-info → titans_pytorch-0.3.11.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.3.9.dist-info → titans_pytorch-0.3.11.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py
CHANGED
@@ -62,6 +62,7 @@ from rotary_embedding_torch import RotaryEmbedding
 # hyper connections / attend from x-transformers, which handles different queries and key lengths better
 
 from x_transformers.attend import Attend
+
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # proposed neural memory
@@ -515,7 +516,7 @@ class MemoryAsContextTransformer(Module):
 
         # hyper conection
 
-        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, disable = num_residual_streams == 1)
 
         self.layers = ModuleList([])
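The new call above tracks the hyper-connections >=0.1.10 factory, which now takes the model dimension (and an optional learned stream embedding) up front, so later init_hyper_conn(...) calls no longer pass dim themselves. A minimal sketch of the new signature, with illustrative sizes:

    from hyper_connections import get_init_and_expand_reduce_stream_functions

    num_residual_streams, dim = 4, 512

    # dim and add_stream_embed now live on the factory; disable makes the
    # returned wrappers no-ops when there is only a single residual stream
    init_hyper_conn, expand_streams, reduce_streams = get_init_and_expand_reduce_stream_functions(
        num_residual_streams,
        dim = dim,
        add_stream_embed = True,
        disable = num_residual_streams == 1
    )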
@@ -553,7 +554,7 @@ class MemoryAsContextTransformer(Module):
             mem_hyper_conn = None
 
             if layer in neural_memory_layers:
-                mem_hyper_conn = init_hyper_conn(
+                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)
 
                 mem = NeuralMemory(
                     dim = dim,
@@ -571,8 +572,8 @@ class MemoryAsContextTransformer(Module):
             self.layers.append(ModuleList([
                 mem_hyper_conn,
                 mem,
-                init_hyper_conn(
-                init_hyper_conn(
+                init_hyper_conn(branch = attn),
+                init_hyper_conn(branch = ff)
             ]))
 
         self.norm = nn.RMSNorm(dim)
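Taken together with the layer list above, the wiring follows the hyper-connections README pattern: expand the residual streams once before the layer stack, wrap each branch with init_hyper_conn, and reduce back afterwards. A self-contained sketch with stand-in Linear branches (not the real attention / feedforward modules of this package):

    import torch
    from torch import nn
    from hyper_connections import get_init_and_expand_reduce_stream_functions

    init_hyper_conn, expand_streams, reduce_streams = get_init_and_expand_reduce_stream_functions(2, dim = 32, add_stream_embed = True)

    # each branch gets its own hyper connection, as in self.layers.append(...) above
    layers = nn.ModuleList([init_hyper_conn(branch = nn.Linear(32, 32)) for _ in range(3)])

    x = torch.randn(1, 8, 32)

    x = expand_streams(x)       # fan out into the residual streams
    for layer in layers:
        x = layer(x)            # branch output mixed back in via learned connections
    x = reduce_streams(x)       # collapse streams back to (batch, seq, dim)

    assert x.shape == (1, 8, 32)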
titans_pytorch/neural_memory.py
CHANGED
@@ -3,6 +3,7 @@ from typing import Callable
 
 import math
 from functools import partial
+from itertools import zip_longest
 from collections import namedtuple
 
 import torch
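zip_longest is used further down to walk momentum gates and cached momentum states order by order; when the cache holds fewer orders than configured, the missing entries surface as None, which the scan's prev argument treats as "no previous state". Standard library behavior, for reference:

    from itertools import zip_longest

    gates  = ['g1', 'g2']   # one gate per momentum order
    caches = ['m1']         # only a first-order cache is available

    # pairs are padded with None instead of being dropped
    assert list(zip_longest(gates, caches)) == [('g1', 'm1'), ('g2', None)]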
@@ -21,21 +22,24 @@ from titans_pytorch.memory_models import(
 )
 
 import einx
-from einops import rearrange, repeat, reduce, pack, unpack
+from einops import einsum, rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
 """
 ein notation:
 b - batch
+h - heads
+bh - batch and heads
 n - sequence
 d - feature dimension
 c - intra-chunk
 w - num memory network weight parameters
+o - momentum orders
 """
 
 LinearNoBias = partial(Linear, bias = False)
 
-
+NeuralMemState = namedtuple('NeuralMemState', [
     'seq_index',
     'weights',
     'cache_store_segment',
@@ -224,6 +228,8 @@ class NeuralMemory(Module):
         per_head_learned_parameters = True,
         attn_pool_chunks = False,
         momentum = True,
+        momentum_order = 1,
+        learned_momentum_combine = False,
         pre_rmsnorm = True,
         post_rmsnorm = False,
         qk_rmsnorm = False,
@@ -367,12 +373,7 @@ class NeuralMemory(Module):
         else:
             self.reduce_to_chunk_rep = AttentionPool(dim, chunk_size = chunk_size)
 
-        # learned adaptive learning rate
-
-        self.to_momentum = Sequential(
-            nn.Linear(dim, heads),
-            Rearrange('b n h -> (b h) n 1')
-        ) if momentum else None
+        # learned adaptive learning rate
 
         self.to_adaptive_step = Sequential(
             nn.Linear(dim, heads),
@@ -384,6 +385,26 @@ class NeuralMemory(Module):
 
         self.adaptive_step_transform = adaptive_step_transform
 
+        # momentum related
+
+        self.to_momentum = Sequential(
+            nn.Linear(dim, heads * momentum_order),
+            Rearrange('b n (h o) -> o (b h) n 1', o = momentum_order)
+        ) if momentum else None
+
+        self.momentum_order = momentum_order
+        self.to_learned_momentum_combine = None
+
+        if learned_momentum_combine:
+            assert momentum
+            assert momentum_order > 1, 'only second order momentum allowed for now, but may allow learned combination of zeroth'
+
+            self.to_learned_momentum_combine = Sequential(
+                nn.Linear(dim, heads * momentum_order),
+                nn.Softmax(dim = -1),
+                Rearrange('b n (h o) -> o (b h) n', h = heads)
+            )
+
         # per layer learning rate modulation
 
         self.to_layer_modulation = Sequential(
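A quick shape check of the two new projections, with illustrative sizes: the momentum gates come out as one (batch x heads) map per order, and the combination weights are softmaxed over the flattened (heads x orders) axis before being split back out.

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    dim, heads, momentum_order = 64, 4, 2
    batch, n = 3, 16

    to_momentum = nn.Sequential(
        nn.Linear(dim, heads * momentum_order),
        Rearrange('b n (h o) -> o (b h) n 1', o = momentum_order)
    )

    to_combine = nn.Sequential(
        nn.Linear(dim, heads * momentum_order),
        nn.Softmax(dim = -1),
        Rearrange('b n (h o) -> o (b h) n', h = heads)
    )

    x = torch.randn(batch, n, dim)

    assert to_momentum(x).shape == (momentum_order, batch * heads, n, 1)
    assert to_combine(x).shape  == (momentum_order, batch * heads, n)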
@@ -463,9 +484,9 @@ class NeuralMemory(Module):
         zeros = self.memory_model_parameter_dict.clone().zero_()
 
         if self.per_head_learned_parameters:
-            zeros = repeat_dict_values(zeros, 'h ... -> (b h) ...', b = batch)
+            zeros = repeat_dict_values(zeros, 'h ... -> o (b h) ...', b = batch, o = self.momentum_order)
         else:
-            zeros = repeat_dict_values(zeros, '... -> bh ...', bh = batch * self.heads)
+            zeros = repeat_dict_values(zeros, '... -> o bh ...', bh = batch * self.heads, o = self.momentum_order)
 
         return zeros
 
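repeat_dict_values is a package-internal helper, but the pattern itself is ordinary einops.repeat applied to each parameter tensor: the zero initial state now carries a leading axis with one entry per momentum order. A standalone check of the per-head case (sizes illustrative):

    import torch
    from einops import repeat

    heads, batch, momentum_order = 4, 3, 2

    w = torch.zeros(heads, 32, 32)   # one per-head memory parameter, 'h ...'

    zeros = repeat(w, 'h ... -> o (b h) ...', b = batch, o = momentum_order)
    assert zeros.shape == (momentum_order, batch * heads, 32, 32)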
@@ -518,6 +539,11 @@ class NeuralMemory(Module):
         if has_momentum:
             adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()
 
+            learned_combine = exists(self.to_learned_momentum_combine)
+
+            if learned_combine:
+                combine_momentums = self.to_learned_momentum_combine(chunked_seq)
+
         if need_layer_lr_mod:
             layer_lr_mod = self.to_layer_modulation(chunked_seq) * self.max_mem_layer_modulation
 
@@ -585,7 +611,7 @@ class NeuralMemory(Module):
 
         # negative gradients, adaptive lr already applied as loss weight
 
-        surprises = grads.
+        surprises = grads.mul(-1)
 
         # past states
 
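grads is a TensorDict of per-parameter gradients, so .mul(-1) negates every leaf tensor at once; the sign flip is what the paper calls the surprise (the gradient descent direction). A small standalone check, using only the tensordict package this project already depends on:

    import torch
    from tensordict import TensorDict

    grads = TensorDict({
        'weight': torch.ones(2, 3),
        'bias': torch.ones(3),
    }, batch_size = [])

    surprises = grads.mul(-1)   # elementwise negation across every entry

    assert (surprises['weight'] == -1).all()
    assert (surprises['bias'] == -1).all()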
@@ -603,7 +629,7 @@ class NeuralMemory(Module):
 
         if num_chunks == 0:
             updates = rearrange_dict_values(weights, 'bh ... -> bh 1 ...')
-            next_store_state =
+            next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, past_state, updates)
 
             output = (updates, next_store_state)
 
@@ -611,7 +637,6 @@ class NeuralMemory(Module):
 
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
 
-        next_momentum = TensorDict() if has_momentum else None
         updates = TensorDict()
 
         next_last_update = TensorDict()
@@ -624,32 +649,44 @@ class NeuralMemory(Module):
             # derive momentum with associative scan - eq (10)
 
             if has_momentum:
+                momentum = surprise
+
+                momentums = [] # stores all momentum orders starting with first, to generalize to Nth order momentum
+
                 last_momentum = past_last_momentum[param_name]
-
-                momentum
-
+
+                # go from first order momentum all the way to the Nth
+
+                for one_adaptive_momentum, one_last_momentum in zip_longest(adaptive_momentum, last_momentum):
+                    momentum = self.assoc_scan(one_adaptive_momentum, momentum, prev = one_last_momentum) # momentum is S / surprise in the paper
+
+                    momentums.append(momentum)
+
+                momentums = torch.stack(momentums)
+
+                next_last_momentum[param_name] = momentums[:, :, -1] # momentums shape is Float['o bh n 1']
+
+                if not learned_combine:
+                    update = momentums[-1]
+                else:
+                    update = einsum(combine_momentums, momentums, 'o b n, o b n ... -> b n ...')
 
             # use associative scan again for learned forgetting (weight decay) - eq (13)
 
             update = self.assoc_scan(1. - decay_factor, update, prev = last_update, remove_prev = False)
-            next_last_update[param_name] = update[:, -1]
 
             updates[param_name] = update
-
-            if has_momentum:
-                next_momentum[param_name] = momentum
+            next_last_update[param_name] = update[:, -1]
 
         # determine next state for the storing of memories
 
         next_state = (next_last_update, next_last_momentum)
 
-        next_store_state =
-
-        # returns
+        next_store_state = NeuralMemState(next_seq_len_index, weights, remainder, next_state, updates)
 
-
+        # return updates to neural memory at all chunked timesteps + neural mem cache / state to be fed back
 
-        return
+        return updates, next_store_state
 
     def retrieve_memories(
         self,
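The loop above chains one scan per momentum order: the first scan smooths the raw surprise, the second smooths the first-order momentum, and so on; the final update is either the highest order or a learned per-timestep mix over all orders. A naive sequential reference of that recurrence, assuming assoc_scan computes out[t] = gate[t] * out[t-1] + x[t] (everything else here is illustrative):

    import torch
    from itertools import zip_longest
    from einops import einsum

    def ema_scan(gates, x, prev = None):
        # sequential stand-in for the associative scan above
        carry = prev if prev is not None else torch.zeros_like(x[:, 0])
        outs = []
        for t in range(x.shape[1]):
            carry = gates[:, t] * carry + x[:, t]
            outs.append(carry)
        return torch.stack(outs, dim = 1)

    bh, n, order = 6, 8, 2
    surprise = torch.randn(bh, n, 1)                # negated gradients for one parameter
    adaptive_momentum = torch.rand(order, bh, n, 1) # one gate per order, 'o bh n 1'

    momentum, momentums = surprise, []
    for gate, last in zip_longest(adaptive_momentum, []):   # no cached momentum yet
        momentum = ema_scan(gate, momentum, prev = last)    # order k smooths order k - 1
        momentums.append(momentum)

    momentums = torch.stack(momentums)              # 'o bh n 1'

    update = momentums[-1]                          # default: highest order only

    # or the learned combination over orders, as in the einsum above
    combine = torch.softmax(torch.randn(order, bh, n), dim = 0)
    update = einsum(combine, momentums, 'o b n, o b n ... -> b n ...')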
@@ -746,7 +783,7 @@ class NeuralMemory(Module):
         self,
         seq,
         store_seq = None,
-        state:
+        state: NeuralMemState | None = None,
         prev_weights = None
     ):
         if seq.ndim == 2:
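The widened annotation makes the streaming contract explicit: forward can resume from the NeuralMemState returned by a previous call. A hedged usage sketch, assuming forward returns the retrieved memories together with the next state as in the store path above (hyperparameters illustrative):

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 384, chunk_size = 64)

    seq = torch.randn(1, 256, 384)
    state = None

    # feed the sequence segment by segment, threading the memory state through
    for segment in seq.split(64, dim = 1):
        retrieved, state = mem(segment, state = state)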
{titans_pytorch-0.3.9.dist-info → titans_pytorch-0.3.11.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.9
+Version: 0.3.11
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -38,7 +38,7 @@ Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.10
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
-Requires-Dist: hyper-connections>=0.1.
+Requires-Dist: hyper-connections>=0.1.10
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
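The only dependency change is the raised hyper-connections floor, needed for the factory signature used in mac_transformer.py above; in practice, upgrading with pip install -U titans-pytorch resolves hyper-connections to 0.1.10 or newer automatically.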
titans_pytorch-0.3.11.dist-info/RECORD
ADDED
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
+titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
+titans_pytorch/mac_transformer.py,sha256=EyqA53HBqvAr4UNZUs37LR6IltyEfA7FKEV54YzVYlg,24945
+titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
+titans_pytorch/neural_memory.py,sha256=7YglrQaDpKS2hbpBBwx7PmqhJdjyvFEPZDt_QXmnUMM,28878
+titans_pytorch-0.3.11.dist-info/METADATA,sha256=xAEvavDiCj__5Bl_5UXaG__BycdUB2DzHOud-nwsn1c,6817
+titans_pytorch-0.3.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.3.11.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.3.11.dist-info/RECORD,,
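For reference, each RECORD row is path,sha256=&lt;digest&gt;,&lt;size&gt;, where the digest is the urlsafe base64 of the file's SHA-256 with the trailing padding stripped (PEP 376 / PEP 627). A row can be reproduced from an unpacked wheel like so:

    import base64, hashlib
    from pathlib import Path

    def record_entry(path):
        # rebuild one RECORD row for a file inside the unpacked wheel
        data = Path(path).read_bytes()
        digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=')
        return f'{path},sha256={digest.decode()},{len(data)}'

    # should match the neural_memory.py row above when run from the wheel root
    print(record_entry('titans_pytorch/neural_memory.py'))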
titans_pytorch-0.3.9.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
-titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
-titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
-titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
-titans_pytorch/neural_memory.py,sha256=2Ffq5fob6_vJVUs1jIai3BJhfsfypKffJEb0QwRRdMk,27325
-titans_pytorch-0.3.9.dist-info/METADATA,sha256=EUuuqHl8jPUPIG8m7xV5esN_2yDNaPdD-H8qFeDxWGo,6815
-titans_pytorch-0.3.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.3.9.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.3.9.dist-info/RECORD,,
{titans_pytorch-0.3.9.dist-info → titans_pytorch-0.3.11.dist-info}/WHEEL
File without changes
{titans_pytorch-0.3.9.dist-info → titans_pytorch-0.3.11.dist-info}/licenses/LICENSE
File without changes