titans-pytorch 0.3.24__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.3.24 → titans_pytorch-0.4.0}/PKG-INFO +1 -1
- {titans_pytorch-0.3.24 → titans_pytorch-0.4.0}/pyproject.toml +1 -1
- {titans_pytorch-0.3.24 → titans_pytorch-0.4.0}/tests/test_titans.py +10 -2
- {titans_pytorch-0.3.24 → titans_pytorch-0.4.0}/titans_pytorch/mac_transformer.py +56 -13
- {titans_pytorch-0.3.24 → titans_pytorch-0.4.0}/titans_pytorch/neural_memory.py +20 -4
- {titans_pytorch-0.3.24 → titans_pytorch-0.4.0}/train_mac.py +2 -0
- {titans_pytorch-0.3.24 → titans_pytorch-0.4.0}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.3.24 → titans_pytorch-0.4.0}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.3.24 → titans_pytorch-0.4.0}/.gitignore +0 -0
- {titans_pytorch-0.3.24 → titans_pytorch-0.4.0}/LICENSE +0 -0
- {titans_pytorch-0.3.24 → titans_pytorch-0.4.0}/README.md +0 -0
- {titans_pytorch-0.3.24 → titans_pytorch-0.4.0}/data/README.md +0 -0
- {titans_pytorch-0.3.24 → titans_pytorch-0.4.0}/data/enwik8.gz +0 -0
- {titans_pytorch-0.3.24 → titans_pytorch-0.4.0}/fig1.png +0 -0
- {titans_pytorch-0.3.24 → titans_pytorch-0.4.0}/fig2.png +0 -0
- {titans_pytorch-0.3.24 → titans_pytorch-0.4.0}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.3.24 → titans_pytorch-0.4.0}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.3.24 → titans_pytorch-0.4.0}/titans_pytorch/memory_models.py +0 -0
```diff
--- titans_pytorch-0.3.24/tests/test_titans.py
+++ titans_pytorch-0.4.0/tests/test_titans.py
@@ -34,6 +34,7 @@ def torch_default_dtype(dtype):
 @pytest.mark.parametrize('num_kv_per_token', (1, 2))
 @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
 @pytest.mark.parametrize('per_head_learned_parameters', (False, True))
+@pytest.mark.parametrize('test_store_mask', (False, True))
 def test_titans(
     seq_len,
     silu,
@@ -45,7 +46,8 @@ def test_titans(
     max_grad_norm,
     num_kv_per_token,
     per_parameter_lr_modulation,
-    per_head_learned_parameters
+    per_head_learned_parameters,
+    test_store_mask
 ):
     mem = NeuralMemory(
         dim = 16,
@@ -62,7 +64,13 @@ def test_titans(
     )
 
     seq = torch.randn(2, seq_len, 16)
-    retrieved, _ = mem(seq)
+
+    store_mask = None
+
+    if test_store_mask:
+        store_mask = torch.randint(0, 2, (2, seq_len)).bool()
+
+    retrieved, _ = mem(seq, store_mask = store_mask)
 
     assert seq.shape == retrieved.shape
```
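`NeuralMemory` now accepts an optional `store_mask` that marks which positions may be written into memory, while retrieval still happens at every position. A minimal usage sketch mirroring the new test; the `chunk_size` value and sequence length below are illustrative and not taken from the diff:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 16, chunk_size = 64)  # chunk_size chosen arbitrarily for this sketch

seq = torch.randn(2, 128, 16)                      # (batch, seq_len, dim)
store_mask = torch.randint(0, 2, (2, 128)).bool()  # True = store this position, False = skip storing it

retrieved, _ = mem(seq, store_mask = store_mask)   # retrieval is unaffected by the mask
assert retrieved.shape == seq.shape
```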
```diff
--- titans_pytorch-0.3.24/titans_pytorch/mac_transformer.py
+++ titans_pytorch-0.4.0/titans_pytorch/mac_transformer.py
@@ -46,7 +46,7 @@ def create_mac_block_mask(seq_len, window_size, persist_mem_len, sliding = False
 
 # einstein notation related
 
-from einops import repeat, rearrange, pack, unpack
+from einops import repeat, rearrange, pack, unpack, einsum
 from einops.layers.torch import Rearrange
 
 # b - batch
@@ -521,9 +521,7 @@ class MemoryAsContextTransformer(Module):
         self.sliding_window_attn = sliding_window_attn
         self.attn_window_size = segment_len + num_longterm_mem_tokens
 
-        # hyper
-
-        assert not (num_residual_streams <= 1 and neural_memory_qkv_receives_diff_views), 'allow neural memory queries, keys, values to be derived from different combinations of the residual streams can only work if hyper connections has greater than 1 residual stream'
+        # hyper connection
 
        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, disable = num_residual_streams == 1)
```
```diff
--- titans_pytorch-0.3.24/titans_pytorch/mac_transformer.py
+++ titans_pytorch-0.4.0/titans_pytorch/mac_transformer.py
@@ -560,17 +558,28 @@ class MemoryAsContextTransformer(Module):
             )
 
             mem = None
+            mem_qkv_layer_selector = None
             mem_hyper_conn = None
 
             if layer in neural_memory_layers:
-                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output
+                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)
+
+                if not is_first and neural_memory_qkv_receives_diff_views:
+                    num_layer_choices = (layer - 1) * 4 + 1 # for each layer, have memory input select from attn inp, attn out, ff inp, and ff out - plus one for the current point in the residual stream (memory input)
+
+                    mem_qkv_layer_selector = nn.Sequential(
+                        nn.RMSNorm(dim),
+                        nn.Linear(dim, 3 * num_layer_choices),
+                        Rearrange('... (views layers) -> views ... layers', views = 3),
+                        nn.Softmax(dim = -1)
+                    )
 
                 mem = NeuralMemory(
                     dim = dim,
                     chunk_size = self.neural_memory_segment_len,
                     batch_size = neural_memory_batch_size,
                     model = deepcopy(neural_memory_model),
-                    qkv_receives_diff_views =
+                    qkv_receives_diff_views = True,
                     accept_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
                     **neural_memory_kwargs
                 )
@@ -581,9 +590,12 @@ class MemoryAsContextTransformer(Module):
 
             self.layers.append(ModuleList([
                 mem_hyper_conn,
+                init_hyper_conn(),
+                init_hyper_conn(),
+                mem_qkv_layer_selector,
                 mem,
-
-
+                attn,
+                ff,
             ]))
 
         self.norm = nn.RMSNorm(dim)
```
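The selector built above produces, for every token, a softmax over the candidate views that a given layer's neural memory can draw its queries, keys and values from: the four per-layer taps (attention input/output, feedforward input/output) of every earlier layer plus the current memory input, hence `(layer - 1) * 4 + 1` choices. A shape-only sketch with hypothetical `dim`, `layer`, batch and sequence sizes:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, layer = 512, 3                      # hypothetical values for the sketch
num_layer_choices = (layer - 1) * 4 + 1  # attn in/out + ff in/out of prior layers, plus the current mem input

mem_qkv_layer_selector = nn.Sequential(
    nn.RMSNorm(dim),                     # nn.RMSNorm requires a recent PyTorch (>= 2.4)
    nn.Linear(dim, 3 * num_layer_choices),
    Rearrange('... (views layers) -> views ... layers', views = 3),
    nn.Softmax(dim = -1)
)

mem_input = torch.randn(2, 128, dim)          # (batch, seq, dim)
selected = mem_qkv_layer_selector(mem_input)  # (3, batch, seq, num_layer_choices), one simplex per token per view
assert selected.shape == (3, 2, 128, num_layer_choices)
```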
```diff
--- titans_pytorch-0.3.24/titans_pytorch/mac_transformer.py
+++ titans_pytorch-0.4.0/titans_pytorch/mac_transformer.py
@@ -763,6 +775,10 @@ class MemoryAsContextTransformer(Module):
 
         mem_weight_residual = None
 
+        # layers for the neural mem to select the qkv inputs from
+
+        mem_input_layers = []
+
         # when inferencing, only do one token at a time
 
         if is_inferencing:
@@ -773,7 +789,7 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for mem_hyper_conn, mem, attn, ff in self.layers:
+        for mem_hyper_conn, attn_hyper_conn, ff_hyper_conn, mem_qkv_layer_selector, mem, attn, ff in self.layers:
 
             retrieved = None
             attn_out_gates = None
@@ -785,8 +801,19 @@ class MemoryAsContextTransformer(Module):
 
                 mem_input, add_residual = mem_hyper_conn(x)
 
+                if not exists(mem_qkv_layer_selector):
+                    qkv_mem_input = stack((mem_input, mem_input, mem_input))
+                else:
+                    layers_to_choose_from = stack((mem_input, *mem_input_layers))
+
+                    # let the current `mem_input` select the 3 layers for qkv
+
+                    selected = mem_qkv_layer_selector(mem_input)
+
+                    qkv_mem_input = einsum(layers_to_choose_from, selected, 'l b n d, v b n l -> v b n d')
+
                 retrieved, next_neural_mem_cache = mem.forward(
-                    mem_input,
+                    qkv_mem_input,
                     state = next(neural_mem_caches, None),
                     prev_weights = mem_weight_residual
                 )
```
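The `einsum` above forms, for each of the three views (queries, keys, values), a per-token convex combination of the candidate layer outputs. A standalone sketch of that contraction; `num_choices`, `b`, `n` and `d` are made-up sizes:

```python
import torch
from einops import einsum

b, n, d = 2, 128, 512  # hypothetical batch, sequence and model dims
num_choices = 9        # e.g. (layer - 1) * 4 + 1 with layer = 3

layers_to_choose_from = torch.randn(num_choices, b, n, d)        # stacked candidate views, (l, b, n, d)
selected = torch.randn(3, b, n, num_choices).softmax(dim = -1)   # per-token mixing weights for q, k, v

# convex combination of the candidate layers, computed independently for each of the 3 views
qkv_mem_input = einsum(layers_to_choose_from, selected, 'l b n d, v b n l -> v b n d')
assert qkv_mem_input.shape == (3, b, n, d)
```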
```diff
--- titans_pytorch-0.3.24/titans_pytorch/mac_transformer.py
+++ titans_pytorch-0.4.0/titans_pytorch/mac_transformer.py
@@ -801,8 +828,12 @@ class MemoryAsContextTransformer(Module):
 
             # attention
 
-
-
+            attn_in, add_residual = attn_hyper_conn(x)
+
+            mem_input_layers.append(attn_in)
+
+            attn_out, (values, next_kv_cache) = attn(
+                attn_in,
                 value_residual = value_residual,
                 disable_flex_attn = disable_flex_attn,
                 flex_attn_fn = flex_attn_fn,
@@ -810,8 +841,12 @@ class MemoryAsContextTransformer(Module):
                 cache = next(kv_caches, None)
             )
 
+            mem_input_layers.append(attn_out)
+
             value_residual = default(value_residual, values)
 
+            x = add_residual(attn_out)
+
             # caches
 
             next_kv_caches.append(next_kv_cache)
@@ -819,7 +854,15 @@ class MemoryAsContextTransformer(Module):
 
             # feedforward
 
-
+            ff_in, add_ff_residual = ff_hyper_conn(x)
+
+            mem_input_layers.append(ff_in)
+
+            ff_out = ff(ff_in)
+
+            mem_input_layers.append(ff_out)
+
+            x = add_ff_residual(ff_out)
 
             # taking care of cache first
             # for early return when processing long term mem tokens during inference
```
```diff
--- titans_pytorch-0.3.24/titans_pytorch/neural_memory.py
+++ titans_pytorch-0.4.0/titans_pytorch/neural_memory.py
@@ -524,7 +524,8 @@ class NeuralMemory(Module):
         weights: dict[str, Tensor] | None = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
         seq_index = 0,
-        prev_weights = None
+        prev_weights = None,
+        mask: Tensor | None = None,
     ):
         if self.qkv_receives_diff_views:
             _, batch, seq_len = seq.shape[:3]
@@ -612,6 +613,14 @@ class NeuralMemory(Module):
 
         adaptive_lr = rearrange(adaptive_lr, 'b (n c u) -> (b n) (c u)', c = chunk_size, u = num_updates)
 
+        # optionally a storing memories mask can be passed in. if False, will set the learning rate to 0. for those positions
+
+        if exists(mask):
+            mask = mask[..., :round_down_seq_len]
+            mask = repeat(mask, 'b (n c) -> (b h n) (c u)', h = heads, u = num_updates, c = chunk_size)
+
+            adaptive_lr = torch.where(mask, adaptive_lr, 0.)
+
         # maybe add previous layer weight
 
         assert xnor(exists(self.to_learned_weight_residual_mix), exists(prev_weights))
```
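Inside `store_memories`, the mask is broadcast to the layout of the per-token adaptive learning rates and zeroes them out, so masked positions never contribute to a memory weight update. A small sketch of that masking with hypothetical shapes (in the library, the leading batch dimension of `adaptive_lr` already has the heads folded in at this point):

```python
import torch
from einops import repeat

batch, heads, seq_len = 2, 4, 8
chunk_size, num_updates = 4, 1
n = seq_len // chunk_size

# per-token adaptive learning rates, already shaped ((b h n), (c u)) as in the diff
adaptive_lr = torch.rand(batch * heads * n, chunk_size * num_updates)

# boolean store mask over the original (batch, seq_len) positions
mask = torch.randint(0, 2, (batch, seq_len)).bool()
mask = repeat(mask, 'b (n c) -> (b h n) (c u)', h = heads, c = chunk_size, u = num_updates)

# masked-out positions get a learning rate of zero, so nothing is written for them
adaptive_lr = torch.where(mask, adaptive_lr, 0.)
```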
```diff
--- titans_pytorch-0.3.24/titans_pytorch/neural_memory.py
+++ titans_pytorch-0.4.0/titans_pytorch/neural_memory.py
@@ -833,7 +842,8 @@ class NeuralMemory(Module):
         seq,
         store_seq = None,
         state: NeuralMemState | None = None,
-        prev_weights = None
+        prev_weights = None,
+        store_mask: Tensor | None = None
     ):
         is_multi_input = self.qkv_receives_diff_views
 
@@ -910,6 +920,11 @@ class NeuralMemory(Module):
 
         store_seqs = store_seq.split(split_sizes, dim = -2)
 
+        if exists(store_mask):
+            store_masks = store_mask.split(split_sizes, dim = -1)
+        else:
+            store_masks = (None,) * len(split_sizes)
+
         # whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch
 
         gate = None
@@ -917,7 +932,7 @@ class NeuralMemory(Module):
         if exists(self.transition_gate):
             gate = self.transition_gate.sigmoid()
 
-        for ind, store_seq_chunk in enumerate(store_seqs):
+        for ind, (store_seq_chunk, maybe_store_mask) in enumerate(zip(store_seqs, store_masks)):
             is_last = ind == (len(store_seqs) - 1)
 
             # store
@@ -927,7 +942,8 @@ class NeuralMemory(Module):
                 weights,
                 seq_index = seq_index,
                 past_state = past_state,
-                prev_weights = prev_weights
+                prev_weights = prev_weights,
+                mask = maybe_store_mask
             )
 
             weights = next_neural_mem_state.weights
```
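When the stored sequence is processed in several chunks, the boolean mask is split with the same `split_sizes` and zipped with the sequence chunks, so each `store_memories` call sees only its own slice. A toy illustration of that split-and-zip pattern, with all sizes hypothetical:

```python
import torch

batch, seq_len, dim = 2, 10, 16
split_sizes = (4, 4, 2)  # hypothetical chunk sizes along the sequence

store_seq = torch.randn(batch, seq_len, dim)
store_mask = torch.randint(0, 2, (batch, seq_len)).bool()

store_seqs = store_seq.split(split_sizes, dim = -2)    # split the (b, n, d) sequence
store_masks = store_mask.split(split_sizes, dim = -1)  # split the (b, n) mask the same way

for store_seq_chunk, maybe_store_mask in zip(store_seqs, store_masks):
    # each sequence chunk is paired with the matching slice of the mask
    assert store_seq_chunk.shape[1] == maybe_store_mask.shape[1]
```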
```diff
--- titans_pytorch-0.3.24/train_mac.py
+++ titans_pytorch-0.4.0/train_mac.py
@@ -48,6 +48,7 @@ SLIDING_WINDOWS = True
 STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
 MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
 NEURAL_MEM_WEIGHT_RESIDUAL = True # learning to accept contributions from the weights of the previous neural mem layer brings about significant improvements. this was improvised and not in the paper, but inspired by the value residual learning free lunch paper
+NEURAL_MEM_QKV_RECEIVES_DIFF_VIEW = True # will allow the neural memory to select what layers from which to derive queries / keys / values, effectively allowing it to graft itself to the transformer in any way to be beneficial. this is to address an issue from a phd student who noted that the mem network is learning nothing more than wk @ wv. this also generalizes all possible ways to connect the neural memory to a transformer, a sort of NAS
 
 # experiment related
 
@@ -107,6 +108,7 @@ model = MemoryAsContextTransformer(
     neural_memory_batch_size = NEURAL_MEM_BATCH_SIZE,
     neural_mem_gate_attn_output = NEURAL_MEM_GATE_ATTN_OUTPUT,
     neural_mem_weight_residual = NEURAL_MEM_WEIGHT_RESIDUAL,
+    neural_memory_qkv_receives_diff_views = NEURAL_MEM_QKV_RECEIVES_DIFF_VIEW,
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
     neural_memory_model = neural_memory_model,
```