titans-pytorch 0.1.33__tar.gz → 0.1.35__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/PKG-INFO +1 -1
- {titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/pyproject.toml +1 -1
- {titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/tests/test_titans.py +35 -3
- {titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/titans_pytorch/mac_transformer.py +46 -7
- {titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/titans_pytorch/neural_memory.py +19 -3
- {titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/train_mac.py +14 -6
- {titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/.gitignore +0 -0
- {titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/LICENSE +0 -0
- {titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/README.md +0 -0
- {titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/data/README.md +0 -0
- {titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/fig1.png +0 -0
- {titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/fig2.png +0 -0
- {titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/titans_pytorch/associative_scan.py +0 -0
{titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/tests/test_titans.py

```diff
@@ -86,25 +86,32 @@ def test_retrieve_store_diff_seq():
     assert retrieve_seq.shape == retrieved.shape
 
 def test_weight_tied_mlp_neural_mem():
+    from titans_pytorch import MemoryMLP
+
+    mlp = MemoryMLP(64, depth = 2)
+
     mem = NeuralMemory(
         dim = 384,
         dim_head = 64,
         heads = 2,
-        chunk_size = 2
+        chunk_size = 2,
+        model = mlp
     )
 
     mem2 = NeuralMemory(
         dim = 384,
         dim_head = 64,
         heads = 2,
-        chunk_size = 2
+        chunk_size = 2,
+        model = mlp
     )
 
     mem3 = NeuralMemory(
         dim = 384,
         dim_head = 64,
         heads = 2,
-        chunk_size = 2
+        chunk_size = 2,
+        model = mlp
     )
 
     seq = torch.randn(2, 128, 384)
@@ -113,6 +120,31 @@ def test_weight_tied_mlp_neural_mem():
     seq, cache2 = mem2(seq, prev_layer_updates = cache.updates)
     seq, cache3 = mem3(seq, prev_layer_updates = cache2.updates)
 
+def test_mac_with_weight_tied_neural_mem():
+    from titans_pytorch import MemoryMLP, MemoryAsContextTransformer
+
+    transformer = MemoryAsContextTransformer(
+        num_tokens = 256,
+        dim = 256,
+        depth = 2,
+        segment_len = 2,
+        num_persist_mem_tokens = 0,
+        num_longterm_mem_tokens = 2,
+        neural_memory_segment_len = 2,
+        sliding_window_attn = True,
+        neural_memory_layers = (1, 2),
+        neural_memory_model = MemoryMLP(256, depth = 1),
+        num_residual_streams = 4,
+        weight_tie_memory_model = True,
+        neural_mem_gate_attn_output = True,
+    )
+
+
+    ids = torch.randint(0, 256, (1, 1023))
+    logits = transformer(ids)
+
+    assert logits.shape == (1, 1023, 256)
+
 def test_overriding_chunk_size():
     mem = NeuralMemory(
         dim = 384,
```
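The updated test above passes one `MemoryMLP` instance as the `model` of all three `NeuralMemory` layers, which is what weight tying means here: every layer reads and writes the same underlying parameters. A minimal sketch of that sharing principle in plain PyTorch (the `Wrapper` class below is illustrative only, not part of titans-pytorch):

```python
from torch import nn

# one module instance handed to several parents = weight tying
shared_mlp = nn.Sequential(nn.Linear(64, 64), nn.SiLU(), nn.Linear(64, 64))

class Wrapper(nn.Module):
    # stand-in for a layer that is given a memory model to manage
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model  # stores a reference, not a copy

layer1, layer2 = Wrapper(shared_mlp), Wrapper(shared_mlp)

# both layers see the exact same parameter tensors, so optimizing
# through either one updates the shared weights
assert all(p1 is p2 for p1, p2 in zip(layer1.parameters(), layer2.parameters()))
```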
{titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/titans_pytorch/mac_transformer.py

```diff
@@ -2,6 +2,7 @@ from __future__ import annotations
 from typing import Callable
 
 from math import ceil
+from copy import deepcopy
 from functools import partial
 from collections import namedtuple
 
@@ -478,18 +479,21 @@ class MemoryAsContextTransformer(Module):
         depth,
         segment_len,
         neural_memory_segment_len = None,
-        neural_mem_gate_attn_output = True,
+        neural_mem_gate_attn_output = False,
+        neural_memory_add_value_residual = False,
         num_longterm_mem_tokens = 0,
         num_persist_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
         num_residual_streams = 4,
+        neural_memory_model: Module | None = None,
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
         aux_kv_recon_loss_weight = 0.,
         use_flex_attn = False,
-        sliding_window_attn = False
+        sliding_window_attn = False,
+        weight_tie_memory_model = False
     ):
         super().__init__()
 
@@ -523,6 +527,20 @@ class MemoryAsContextTransformer(Module):
 
         neural_memory_layers = default(neural_memory_layers, layers)
 
+        # weight tying neural memory model
+
+        maybe_copy = deepcopy if not weight_tie_memory_model else identity
+
+        if weight_tie_memory_model:
+            assert exists(neural_memory_model), '`neural_memory_model` must be explicitly set'
+
+        self.weight_tie_memory_model = weight_tie_memory_model
+
+        # value residual learning for neural memory
+
+        is_first_mem = True
+        self.mem_add_value_residual = neural_memory_add_value_residual
+
         # mem, attn, and feedforward layers
 
         for layer in layers:
```
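The `maybe_copy = deepcopy if not weight_tie_memory_model else identity` line above is what decides whether each layer gets its own copy of `neural_memory_model` or the very same instance. A small sketch of the distinction, assuming the library's `identity` helper is a plain pass-through that returns its argument unchanged:

```python
from copy import deepcopy
from torch import nn

def identity(t):
    # assumed pass-through, mirroring the helper used in mac_transformer.py
    return t

template = nn.Linear(4, 4)

# untied: every layer receives an independent deep copy
maybe_copy = deepcopy
a, b = maybe_copy(template), maybe_copy(template)
assert a is not b and a.weight is not b.weight

# tied: every layer receives the very same module object
maybe_copy = identity
c, d = maybe_copy(template), maybe_copy(template)
assert c is d
```

This is also why the constructor asserts that `neural_memory_model` is explicitly provided when weight tying is requested: there has to be a concrete module to share.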
{titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/titans_pytorch/mac_transformer.py (continued)

```diff
@@ -551,9 +569,12 @@ class MemoryAsContextTransformer(Module):
             mem = NeuralMemory(
                 dim = dim,
                 chunk_size = self.neural_memory_segment_len,
+                model = maybe_copy(neural_memory_model),
+                accept_value_residual = not is_first_mem and neural_memory_add_value_residual,
                 **neural_memory_kwargs
             )
 
+            is_first_mem = False
 
             ff = FeedForward(dim = dim, mult = ff_mult)
 
@@ -683,7 +704,7 @@ class MemoryAsContextTransformer(Module):
 
         # math
 
-        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size
+        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size, weight_tie_memory_model = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size, self.weight_tie_memory_model
 
         seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)
 
@@ -736,10 +757,16 @@ class MemoryAsContextTransformer(Module):
         next_kv_caches = []
         next_neural_mem_caches = []
 
+        # weight tied neural memory
+
+        neural_memory_updates = None
+
         # value residual
 
         value_residual = None
 
+        mem_value_residual = None
+
         # aux losses
 
         kv_recon_losses = self.zero
@@ -767,19 +794,31 @@ class MemoryAsContextTransformer(Module):
                 mem_input, add_residual = mem_hyper_conn(x)
 
                 if not is_inferencing:
-                    (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
+                    (retrieved, next_neural_mem_cache, next_mem_value_residual), mem_kv_aux_loss = mem(
                         mem_input,
-                        return_aux_kv_loss = True
+                        return_aux_kv_loss = True,
+                        return_values = True,
+                        value_residual = mem_value_residual,
+                        prev_layer_updates = neural_memory_updates
                     )
 
                     kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
 
                 else:
-                    retrieved, next_neural_mem_cache = mem.forward_inference(
+                    (retrieved, next_neural_mem_cache, next_mem_value_residual) = mem.forward_inference(
                         mem_input,
-                        state = next(neural_mem_caches, None)
+                        state = next(neural_mem_caches, None),
+                        return_values = True,
+                        value_residual = mem_value_residual,
+                        prev_layer_updates = neural_memory_updates
                     )
 
+                if self.mem_add_value_residual:
+                    mem_value_residual = next_mem_value_residual
+
+                if weight_tie_memory_model:
+                    neural_memory_updates = next_neural_mem_cache.updates
+
                 if self.gate_attn_output:
                     attn_out_gates = retrieved.sigmoid()
                 else:
```
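The forward-pass changes above thread two pieces of state across the memory layers: `mem_value_residual` (values from an earlier memory layer, reused as a residual when `neural_memory_add_value_residual` is enabled) and `neural_memory_updates` (the previous layer's weight updates, forwarded when the memory model is weight tied). A schematic of that carry pattern, with a hypothetical `toy_memory_layer` standing in for the real `NeuralMemory` call:

```python
import torch

def toy_memory_layer(x, value_residual = None, prev_layer_updates = None):
    # stand-in for a NeuralMemory call: returns (retrieved, updates, values)
    values = x * 2.0
    updates = {'weights': x.mean()}
    retrieved = x + (value_residual if value_residual is not None else 0.)
    return retrieved, updates, values

x = torch.randn(2, 8)

mem_value_residual = None      # value residual carried between memory layers
neural_memory_updates = None   # previous layer's updates, used when weight tied

for _ in range(3):
    retrieved, layer_updates, values = toy_memory_layer(
        x,
        value_residual = mem_value_residual,
        prev_layer_updates = neural_memory_updates
    )

    mem_value_residual = values            # kept only when add-value-residual is enabled
    neural_memory_updates = layer_updates  # kept only when the memory model is weight tied
    x = retrieved
```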
{titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/titans_pytorch/neural_memory.py

```diff
@@ -822,6 +822,9 @@ class NeuralMemory(Module):
         self,
         token: Tensor,
         state = None,
+        prev_layer_updates: dict[str, Tensor] | None = None,
+        return_values = False,
+        value_residual = None,
     ):
 
         # unpack previous state
@@ -863,12 +866,18 @@ class NeuralMemory(Module):
         else:
             updates = updates.apply(lambda t: t[:, -1:])
 
+        if exists(prev_layer_updates):
+            prev_layer_updates = TensorDict(prev_layer_updates)
+            prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])
+
         if store_seq_cache_len == self.chunk_size:
 
-            next_updates, next_states = self.store_memories(
+            next_updates, next_states, values = self.store_memories(
                 cache_store_seq,
                 weights,
-                past_state = past_states
+                past_state = past_states,
+                prev_layer_updates = prev_layer_updates,
+                value_residual = value_residual
             )
 
             updates = next_updates
@@ -882,7 +891,12 @@ class NeuralMemory(Module):
 
         next_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)
 
-        return retrieved, next_state
+        output = (retrieved, next_state)
+
+        if return_values:
+            output = (*output, values)
+
+        return output
 
     def forward(
         self,
@@ -894,6 +908,7 @@ class NeuralMemory(Module):
         chunk_size = None,
         store_chunk_size = None,
         return_values = False,
+        value_residual = None,
         return_next_state = False,
         prev_layer_updates: dict[str, Tensor] | None = None
     ):
@@ -927,6 +942,7 @@ class NeuralMemory(Module):
             mem_model_weights,
             chunk_size = store_chunk_size,
             prev_layer_updates = prev_layer_updates,
+            value_residual = value_residual,
             return_aux_kv_loss = True
         )
 
```
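With these changes, `NeuralMemory.forward` (and `forward_inference`) can return the values it computed and accept a `value_residual` and `prev_layer_updates` from an earlier layer. A minimal sketch assembled from the updated test and the call sites in `mac_transformer.py` above; the return nesting follows those call sites rather than formal API documentation:

```python
import torch
from titans_pytorch import NeuralMemory, MemoryMLP

mlp = MemoryMLP(64, depth = 2)

mem1 = NeuralMemory(dim = 384, dim_head = 64, heads = 2, chunk_size = 2, model = mlp)
mem2 = NeuralMemory(dim = 384, dim_head = 64, heads = 2, chunk_size = 2, model = mlp)

seq = torch.randn(2, 128, 384)

# first memory layer: also return its values and the aux kv reconstruction loss,
# mirroring the `(retrieved, cache, values), loss` unpacking in mac_transformer.py
(retrieved, cache, values), aux_kv_loss = mem1(
    seq,
    return_aux_kv_loss = True,
    return_values = True
)

# a later, weight tied layer consumes the previous layer's updates; the MAC
# transformer additionally passes `value_residual = values` to layers that were
# constructed with `accept_value_residual = True`
retrieved2, cache2 = mem2(retrieved, prev_layer_updates = cache.updates)
```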
{titans_pytorch-0.1.33 → titans_pytorch-0.1.35}/train_mac.py

```diff
@@ -9,7 +9,8 @@ from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset
 
 from adam_atan2_pytorch import AdoptAtan2
-from titans_pytorch import MemoryAsContextTransformer
+
+from titans_pytorch import MemoryAsContextTransformer, MemoryMLP
 
 # constants
 
@@ -29,13 +30,16 @@ SEQ_LEN = 512
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
-NEURAL_MEM_LAYERS = (2, 4)
-NEURAL_MEM_GATE_ATTN_OUTPUT = True
+NEURAL_MEM_LAYERS = (2, 4, 6)                   # layers 2, 4, 6 have neural memory, can add more
+NEURAL_MEM_GATE_ATTN_OUTPUT = False
 NEURAL_MEM_MOMENTUM = True
+NEURAL_MEM_ADD_VALUE_RESIDUAL = True,
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2       # set smaller for more granularity for learning rate / momentum etc
 SLIDING_WINDOWS = True
+WEIGHT_TIE_MEMORY_MODEL = True                  # set to have memory MLP shared across layers
 STORE_ATTN_POOL_CHUNKS = True                   # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
+MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
 KV_RECON_LOSS_WEIGHT = 0.
 
 # experiment related
@@ -84,15 +88,19 @@ model = MemoryAsContextTransformer(
     aux_kv_recon_loss_weight = KV_RECON_LOSS_WEIGHT,
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
+    weight_tie_memory_model = WEIGHT_TIE_MEMORY_MODEL,
+    neural_memory_add_value_residual = NEURAL_MEM_ADD_VALUE_RESIDUAL,
+    neural_memory_model = MemoryMLP(
+        dim = 64,
+        depth = NEURAL_MEMORY_DEPTH
+    ),
     neural_memory_kwargs = dict(
         dim_head = 64,
         heads = 4,
         attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
         momentum = NEURAL_MEM_MOMENTUM,
         use_accelerated_scan = USE_ACCELERATED_SCAN,
-        default_model_kwargs = dict(
-            depth = NEURAL_MEMORY_DEPTH,
-        )
+        per_parameter_lr_modulation = MEMORY_MODEL_PER_LAYER_LEARNED_LR
     )
 ).cuda()
 
```