titans-pytorch 0.1.34__tar.gz → 0.1.36__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/PKG-INFO +1 -1
- {titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/pyproject.toml +1 -1
- {titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/titans_pytorch/mac_transformer.py +20 -3
- {titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/titans_pytorch/neural_memory.py +16 -3
- {titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/train_mac.py +14 -6
- {titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/.gitignore +0 -0
- {titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/LICENSE +0 -0
- {titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/README.md +0 -0
- {titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/data/README.md +0 -0
- {titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/fig1.png +0 -0
- {titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/fig2.png +0 -0
- {titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/tests/test_titans.py +0 -0
- {titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/titans_pytorch/associative_scan.py +0 -0
{titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/titans_pytorch/mac_transformer.py

@@ -479,7 +479,8 @@ class MemoryAsContextTransformer(Module):
         depth,
         segment_len,
         neural_memory_segment_len = None,
-        neural_mem_gate_attn_output =
+        neural_mem_gate_attn_output = False,
+        neural_memory_add_value_residual = False,
         num_longterm_mem_tokens = 0,
         num_persist_mem_tokens = 0,
         dim_head = 64,
@@ -535,6 +536,11 @@ class MemoryAsContextTransformer(Module):
 
         self.weight_tie_memory_model = weight_tie_memory_model
 
+        # value residual learning for neural memory
+
+        is_first_mem = True
+        self.mem_add_value_residual = neural_memory_add_value_residual
+
         # mem, attn, and feedforward layers
 
         for layer in layers:
@@ -564,9 +570,11 @@ class MemoryAsContextTransformer(Module):
                 dim = dim,
                 chunk_size = self.neural_memory_segment_len,
                 model = maybe_copy(neural_memory_model),
+                accept_value_residual = not is_first_mem and neural_memory_add_value_residual,
                 **neural_memory_kwargs
             )
 
+            is_first_mem = False
 
             ff = FeedForward(dim = dim, mult = ff_mult)
 
@@ -757,6 +765,8 @@ class MemoryAsContextTransformer(Module):
 
         value_residual = None
 
+        mem_value_residual = None
+
         # aux losses
 
         kv_recon_losses = self.zero
@@ -784,21 +794,28 @@ class MemoryAsContextTransformer(Module):
             mem_input, add_residual = mem_hyper_conn(x)
 
             if not is_inferencing:
-                (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
+                (retrieved, next_neural_mem_cache, next_mem_value_residual), mem_kv_aux_loss = mem(
                     mem_input,
                     return_aux_kv_loss = True,
+                    return_values = True,
+                    value_residual = mem_value_residual,
                     prev_layer_updates = neural_memory_updates
                 )
 
                 kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
 
             else:
-                retrieved, next_neural_mem_cache = mem.forward_inference(
+                (retrieved, next_neural_mem_cache, next_mem_value_residual) = mem.forward_inference(
                     mem_input,
                     state = next(neural_mem_caches, None),
+                    return_values = True,
+                    value_residual = mem_value_residual,
                     prev_layer_updates = neural_memory_updates
                 )
 
+            if self.mem_add_value_residual:
+                mem_value_residual = next_mem_value_residual
+
             if weight_tie_memory_model:
                 neural_memory_updates = next_neural_mem_cache.updates
 
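The new neural_memory_add_value_residual option threads the values produced by earlier neural memory layers into later ones, in the spirit of value residual learning: the first memory layer does not accept a residual (accept_value_residual is False for it), and each later layer mixes its own values with the residual carried through mem_value_residual. The following is a minimal, self-contained sketch of that mixing pattern, assuming a learned per-token sigmoid gate; the ValueResidualMixer name is made up here and the actual mixing inside NeuralMemory may differ.

import torch
from torch import nn

class ValueResidualMixer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # learned per-token gate in [0, 1] deciding how much of the layer's own values to keep
        self.to_mix = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, values, value_residual):
        mix = self.to_mix(values)
        return values * mix + value_residual * (1. - mix)

# usage: a deeper layer mixes its values with a residual carried from an earlier layer
mixer = ValueResidualMixer(dim = 64)

earlier_values = torch.randn(2, 128, 64)   # values carried forward from an earlier memory layer
later_values   = torch.randn(2, 128, 64)   # values produced by the current memory layer

mixed = mixer(later_values, earlier_values)
assert mixed.shape == later_values.shape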
{titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/titans_pytorch/neural_memory.py

@@ -822,7 +822,9 @@ class NeuralMemory(Module):
         self,
         token: Tensor,
         state = None,
-        prev_layer_updates: dict[str, Tensor] | None = None
+        prev_layer_updates: dict[str, Tensor] | None = None,
+        return_values = False,
+        value_residual = None,
     ):
 
         # unpack previous state
@@ -870,11 +872,12 @@ class NeuralMemory(Module):
 
         if store_seq_cache_len == self.chunk_size:
 
-            next_updates, next_states,
+            next_updates, next_states, values = self.store_memories(
                 cache_store_seq,
                 weights,
                 past_state = past_states,
                 prev_layer_updates = prev_layer_updates,
+                value_residual = value_residual
             )
 
             updates = next_updates
@@ -888,7 +891,12 @@ class NeuralMemory(Module):
 
         next_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)
 
-
+        output = (retrieved, next_state)
+
+        if return_values:
+            output = (*output, values)
+
+        return output
 
     def forward(
         self,
@@ -900,6 +908,7 @@ class NeuralMemory(Module):
         chunk_size = None,
         store_chunk_size = None,
         return_values = False,
+        value_residual = None,
         return_next_state = False,
         prev_layer_updates: dict[str, Tensor] | None = None
     ):
@@ -912,6 +921,9 @@ class NeuralMemory(Module):
 
         out = (out, next_store_state)
 
+        if return_values:
+            out = (*out, self.zero)
+
        if not return_aux_kv_loss:
            return out
 
@@ -933,6 +945,7 @@ class NeuralMemory(Module):
             mem_model_weights,
             chunk_size = store_chunk_size,
             prev_layer_updates = prev_layer_updates,
+            value_residual = value_residual,
             return_aux_kv_loss = True
         )
 
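With these changes, NeuralMemory.forward and forward_inference accept return_values and value_residual keyword arguments, so the caller can take the values returned by one memory layer and pass them into the next. A rough sketch of that calling pattern is below; DummyMemory is a stand-in with a simple fixed 0.5 mix, not the real NeuralMemory, and is used only to illustrate how the residual is threaded through a stack of layers.

import torch
from torch import nn

class DummyMemory(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.to_values = nn.Linear(dim, dim)

    def forward(self, x, value_residual = None, return_values = False):
        values = self.to_values(x)

        if value_residual is not None:
            values = 0.5 * (values + value_residual)  # fixed mix, purely for the sketch

        retrieved = values  # stand-in for the actual retrieval step

        if return_values:
            return retrieved, values

        return retrieved

layers = nn.ModuleList([DummyMemory(64) for _ in range(3)])

x = torch.randn(1, 32, 64)
mem_value_residual = None

for mem in layers:
    x, next_values = mem(x, value_residual = mem_value_residual, return_values = True)
    mem_value_residual = next_values  # carried forward, mirroring the layer loop in mac_transformer.py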
{titans_pytorch-0.1.34 → titans_pytorch-0.1.36}/train_mac.py

@@ -9,7 +9,8 @@ from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset
 
 from adam_atan2_pytorch import AdoptAtan2
-
+
+from titans_pytorch import MemoryAsContextTransformer, MemoryMLP
 
 # constants
 
@@ -29,13 +30,16 @@ SEQ_LEN = 512
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
-NEURAL_MEM_LAYERS = (2, 4)
-NEURAL_MEM_GATE_ATTN_OUTPUT =
+NEURAL_MEM_LAYERS = (2, 4, 6)  # layers 2, 4, 6 have neural memory, can add more
+NEURAL_MEM_GATE_ATTN_OUTPUT = False
 NEURAL_MEM_MOMENTUM = True
+NEURAL_MEM_ADD_VALUE_RESIDUAL = True,
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2  # set smaller for more granularity for learning rate / momentum etc
 SLIDING_WINDOWS = True
+WEIGHT_TIE_MEMORY_MODEL = True  # set to have memory MLP shared across layers
 STORE_ATTN_POOL_CHUNKS = True  # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
+MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
 KV_RECON_LOSS_WEIGHT = 0.
 
 # experiment related
@@ -84,15 +88,19 @@ model = MemoryAsContextTransformer(
     aux_kv_recon_loss_weight = KV_RECON_LOSS_WEIGHT,
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
+    weight_tie_memory_model = WEIGHT_TIE_MEMORY_MODEL,
+    neural_memory_add_value_residual = NEURAL_MEM_ADD_VALUE_RESIDUAL,
+    neural_memory_model = MemoryMLP(
+        dim = 64,
+        depth = NEURAL_MEMORY_DEPTH
+    ),
     neural_memory_kwargs = dict(
         dim_head = 64,
         heads = 4,
         attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
         momentum = NEURAL_MEM_MOMENTUM,
         use_accelerated_scan = USE_ACCELERATED_SCAN,
-
-        depth = NEURAL_MEMORY_DEPTH,
-    )
+        per_parameter_lr_modulation = MEMORY_MODEL_PER_LAYER_LEARNED_LR
     )
 ).cuda()
 
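For context, the updated training configuration could be assembled roughly as in the sketch below. Only the keyword arguments that appear in the hunks above are confirmed by this diff; num_tokens, dim, depth, segment_len, neural_memory_layers, and return_loss are assumptions taken from the training script's constants and the project README, and may differ from the released code.

import torch
from titans_pytorch import MemoryAsContextTransformer, MemoryMLP

model = MemoryAsContextTransformer(
    num_tokens = 256,                           # assumed byte-level vocab, not shown in this diff
    dim = 384,                                  # assumed
    depth = 8,                                  # assumed
    segment_len = 32,                           # WINDOW_SIZE
    num_persist_mem_tokens = 4,                 # NUM_PERSIST_MEM
    num_longterm_mem_tokens = 4,                # NUM_LONGTERM_MEM
    neural_memory_layers = (2, 4, 6),           # NEURAL_MEM_LAYERS (kwarg name assumed)
    neural_memory_segment_len = 16,             # WINDOW_SIZE // 2
    neural_mem_gate_attn_output = False,
    weight_tie_memory_model = True,             # WEIGHT_TIE_MEMORY_MODEL
    neural_memory_add_value_residual = True,    # NEURAL_MEM_ADD_VALUE_RESIDUAL
    neural_memory_model = MemoryMLP(
        dim = 64,
        depth = 2                               # NEURAL_MEMORY_DEPTH
    ),
    neural_memory_kwargs = dict(
        dim_head = 64,
        heads = 4,
        attn_pool_chunks = True,
        momentum = True,
        per_parameter_lr_modulation = True      # MEMORY_MODEL_PER_LAYER_LEARNED_LR
    )
)

ids = torch.randint(0, 256, (1, 512))
loss = model(ids, return_loss = True)           # return_loss usage assumed from the README
loss.backward()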