titans-pytorch 0.1.32__tar.gz → 0.1.33__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.1.32 → titans_pytorch-0.1.33}/PKG-INFO +1 -1
- {titans_pytorch-0.1.32 → titans_pytorch-0.1.33}/pyproject.toml +1 -1
- {titans_pytorch-0.1.32 → titans_pytorch-0.1.33}/tests/test_titans.py +28 -0
- {titans_pytorch-0.1.32 → titans_pytorch-0.1.33}/titans_pytorch/neural_memory.py +38 -10
- {titans_pytorch-0.1.32 → titans_pytorch-0.1.33}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.32 → titans_pytorch-0.1.33}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.32 → titans_pytorch-0.1.33}/.gitignore +0 -0
- {titans_pytorch-0.1.32 → titans_pytorch-0.1.33}/LICENSE +0 -0
- {titans_pytorch-0.1.32 → titans_pytorch-0.1.33}/README.md +0 -0
- {titans_pytorch-0.1.32 → titans_pytorch-0.1.33}/data/README.md +0 -0
- {titans_pytorch-0.1.32 → titans_pytorch-0.1.33}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.32 → titans_pytorch-0.1.33}/fig1.png +0 -0
- {titans_pytorch-0.1.32 → titans_pytorch-0.1.33}/fig2.png +0 -0
- {titans_pytorch-0.1.32 → titans_pytorch-0.1.33}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.1.32 → titans_pytorch-0.1.33}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.1.32 → titans_pytorch-0.1.33}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.1.32 → titans_pytorch-0.1.33}/train_mac.py +0 -0
tests/test_titans.py
@@ -85,6 +85,34 @@ def test_retrieve_store_diff_seq():
 
     assert retrieve_seq.shape == retrieved.shape
 
+def test_weight_tied_mlp_neural_mem():
+    mem = NeuralMemory(
+        dim = 384,
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 2
+    )
+
+    mem2 = NeuralMemory(
+        dim = 384,
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 2
+    )
+
+    mem3 = NeuralMemory(
+        dim = 384,
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 2
+    )
+
+    seq = torch.randn(2, 128, 384)
+
+    seq, cache = mem(seq)
+    seq, cache2 = mem2(seq, prev_layer_updates = cache.updates)
+    seq, cache3 = mem3(seq, prev_layer_updates = cache2.updates)
+
 def test_overriding_chunk_size():
     mem = NeuralMemory(
         dim = 384,
titans_pytorch/neural_memory.py
@@ -3,6 +3,7 @@ from typing import Callable
 
 import math
 from functools import partial
+from collections import namedtuple
 
 import torch
 from torch import nn, cat, Tensor
@@ -33,6 +34,8 @@ w - num memory network weight parameters
 
 LinearNoBias = partial(Linear, bias = False)
 
+NeuralMemCache = namedtuple('NeuralMemCache', ['seq', 'cache_store_segment', 'states', 'updates'])
+
 # functions
 
 def exists(v):
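For orientation, here is a minimal sketch (not part of the diff) of what the new NeuralMemCache namedtuple changes for callers: the second value returned by NeuralMemory can now be unpacked by field name instead of by tuple position. It assumes titans-pytorch 0.1.33 is installed and that NeuralMemory is exported at the package top level as in the README; constructor arguments mirror the new test, and the variable names are illustrative.

```python
# minimal sketch, assuming titans-pytorch 0.1.33; variable names are illustrative
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    dim_head = 64,
    heads = 2,
    chunk_size = 2
)

seq = torch.randn(2, 128, 384)

retrieved, cache = mem(seq)

# cache is now a NeuralMemCache with fields
# (seq, cache_store_segment, states, updates) rather than a bare tuple
print(cache.seq)       # sequence length seen so far
print(cache.updates)   # per-parameter memory updates, usable as prev_layer_updates
```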
@@ -605,7 +608,7 @@ class NeuralMemory(Module):
         # improvise (or perhaps correcting to) a solution
 
         if exists(prev_layer_updates):
-            prev_layer_updates = TensorDict(
+            prev_layer_updates = TensorDict(prev_layer_updates)
 
             weights = weights + prev_layer_updates
 
@@ -657,6 +660,11 @@ class NeuralMemory(Module):
 
         adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
 
+        # flatten batch and time if surprise depends on previous layer memory model
+
+        if exists(prev_layer_updates):
+            weights = weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
         grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
@@ -737,7 +745,8 @@ class NeuralMemory(Module):
         self,
         seq,
         past_weights: dict[str, Tensor],
-        chunk_size = None
+        chunk_size = None,
+        prev_layer_updates: dict[str, Tensor] | None = None
     ):
         chunk_size = default(chunk_size, self.retrieve_chunk_size)
         batch, seq_len = seq.shape[:2]
@@ -760,6 +769,9 @@ class NeuralMemory(Module):
 
         curr_weights = TensorDict(past_weights)
 
+        if exists(prev_layer_updates):
+            curr_weights = curr_weights + TensorDict(prev_layer_updates)
+
         # sequence Float['b n d'] to queries
 
         queries = self.to_queries(seq)
@@ -838,7 +850,7 @@ class NeuralMemory(Module):
         if curr_seq_len < self.chunk_size:
             empty_mem = self.init_empty_memory_embed(batch, 1)
 
-            return empty_mem, (curr_seq_len, cache_store_seq, past_states, updates)
+            return empty_mem, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)
 
         # store if storage sequence cache hits the chunk size
 
@@ -848,6 +860,8 @@ class NeuralMemory(Module):
         if not exists(updates):
             updates = weights.clone().zero_()
             updates = updates.apply(lambda t: repeat(t, '... -> b 1 ...', b = batch))
+        else:
+            updates = updates.apply(lambda t: t[:, -1:])
 
         if store_seq_cache_len == self.chunk_size:
 
@@ -866,7 +880,7 @@ class NeuralMemory(Module):
 
         # next state tuple
 
-        next_state = (curr_seq_len, cache_store_seq, next_states, updates)
+        next_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)
 
         return retrieved, next_state
 
@@ -880,7 +894,8 @@ class NeuralMemory(Module):
         chunk_size = None,
         store_chunk_size = None,
         return_values = False,
-        return_next_state = False
+        return_next_state = False,
+        prev_layer_updates: dict[str, Tensor] | None = None
     ):
         batch, seq_len = seq.shape[:2]
 
@@ -899,15 +914,30 @@ class NeuralMemory(Module):
         if not exists(mem_model_weights):
             mem_model_weights = self.init_weights()
 
+        # store
+
         store_seq = default(store_seq, seq)
 
         store_seq_len = store_seq.shape[-2]
         store_chunk_size = default(store_chunk_size, chunk_size, self.store_chunk_size)
         remainder = store_seq_len % store_chunk_size
 
-        (updates, next_state, values), aux_kv_recon_loss = self.store_memories(
+        (updates, next_state, values), aux_kv_recon_loss = self.store_memories(
+            store_seq,
+            mem_model_weights,
+            chunk_size = store_chunk_size,
+            prev_layer_updates = prev_layer_updates,
+            return_aux_kv_loss = True
+        )
 
-
+        # retrieve
+
+        retrieved = self.retrieve_memories(
+            seq,
+            mem_model_weights + updates,
+            chunk_size = chunk_size,
+            prev_layer_updates = prev_layer_updates
+        )
 
         # determine state for the storing of memories
         # for transformer-xl like training with neural memory as well as inferencing with initial prompt
@@ -917,9 +947,7 @@ class NeuralMemory(Module):
         if remainder > 0:
             cache_store_seq = store_seq[:, -remainder:]
 
-
-
-        next_store_state = (seq_len, cache_store_seq, next_state, updates)
+        next_store_state = NeuralMemCache(seq_len, cache_store_seq, next_state, updates)
 
         output = (retrieved, next_store_state)
 
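Putting the hunks together, here is a minimal usage sketch (mirroring the new tests/test_titans.py::test_weight_tied_mlp_neural_mem, not additional package code) of how the threaded prev_layer_updates keyword lets a later NeuralMemory layer fold in the memory updates produced by an earlier one. Per the hunks above, those updates are added onto the later layer's weights in both store_memories and retrieve_memories.

```python
# minimal sketch mirroring the new test; variable names are illustrative
import torch
from titans_pytorch import NeuralMemory

mem1 = NeuralMemory(dim = 384, dim_head = 64, heads = 2, chunk_size = 2)
mem2 = NeuralMemory(dim = 384, dim_head = 64, heads = 2, chunk_size = 2)

seq = torch.randn(2, 128, 384)

seq, cache1 = mem1(seq)

# cache1.updates (per-parameter updates from the first memory layer) is passed
# down and added to the second layer's weights when storing and retrieving
seq, cache2 = mem2(seq, prev_layer_updates = cache1.updates)
```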