titans-pytorch 0.3.25__tar.gz → 0.4.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.3.25 → titans_pytorch-0.4.1}/PKG-INFO +1 -1
- {titans_pytorch-0.3.25 → titans_pytorch-0.4.1}/pyproject.toml +1 -1
- {titans_pytorch-0.3.25 → titans_pytorch-0.4.1}/tests/test_titans.py +15 -0
- {titans_pytorch-0.3.25 → titans_pytorch-0.4.1}/titans_pytorch/mac_transformer.py +56 -13
- {titans_pytorch-0.3.25 → titans_pytorch-0.4.1}/titans_pytorch/neural_memory.py +30 -9
- {titans_pytorch-0.3.25 → titans_pytorch-0.4.1}/train_mac.py +2 -0
- {titans_pytorch-0.3.25 → titans_pytorch-0.4.1}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.3.25 → titans_pytorch-0.4.1}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.3.25 → titans_pytorch-0.4.1}/.gitignore +0 -0
- {titans_pytorch-0.3.25 → titans_pytorch-0.4.1}/LICENSE +0 -0
- {titans_pytorch-0.3.25 → titans_pytorch-0.4.1}/README.md +0 -0
- {titans_pytorch-0.3.25 → titans_pytorch-0.4.1}/data/README.md +0 -0
- {titans_pytorch-0.3.25 → titans_pytorch-0.4.1}/data/enwik8.gz +0 -0
- {titans_pytorch-0.3.25 → titans_pytorch-0.4.1}/fig1.png +0 -0
- {titans_pytorch-0.3.25 → titans_pytorch-0.4.1}/fig2.png +0 -0
- {titans_pytorch-0.3.25 → titans_pytorch-0.4.1}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.3.25 → titans_pytorch-0.4.1}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.3.25 → titans_pytorch-0.4.1}/titans_pytorch/memory_models.py +0 -0
tests/test_titans.py

```diff
@@ -74,6 +74,21 @@ def test_titans(
 
     assert seq.shape == retrieved.shape
 
+def test_return_surprises():
+
+    mem = NeuralMemory(
+        dim = 384,
+        chunk_size = 2,
+        dim_head = 64,
+        heads = 4,
+    )
+
+    seq = torch.randn(4, 64, 384)
+
+    _, _, surprises = mem(seq, return_surprises = True)
+
+    assert surprises.shape == (4, 4, 64)
+
 @pytest.mark.parametrize('learned_momentum_combine', (False, True))
 @pytest.mark.parametrize('learned_combine_include_zeroth', (False, True))
 def test_titans_second_order_momentum(
```
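The new test doubles as the contract for the surprise output: surprises come back as `(batch, heads, seq)`, here `(4, 4, 64)`. A minimal usage sketch, assuming `titans-pytorch >= 0.4.0` is installed; the per-head averaging at the end is illustrative and not part of the library:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 2,
    dim_head = 64,
    heads = 4,
)

seq = torch.randn(4, 64, 384)

# surprises are the unweighted memory-model losses |M(k) - v|², per head and token
retrieved, state, surprises = mem(seq, return_surprises = True)

assert surprises.shape == (4, 4, 64)  # (batch, heads, seq)

per_token_surprise = surprises.mean(dim = 1)  # e.g. average over heads for monitoring
```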
titans_pytorch/mac_transformer.py

```diff
@@ -46,7 +46,7 @@ def create_mac_block_mask(seq_len, window_size, persist_mem_len, sliding = False
 
 # einstein notation related
 
-from einops import repeat, rearrange, pack, unpack
+from einops import repeat, rearrange, pack, unpack, einsum
 from einops.layers.torch import Rearrange
 
 # b - batch
```
```diff
@@ -521,9 +521,7 @@ class MemoryAsContextTransformer(Module):
         self.sliding_window_attn = sliding_window_attn
         self.attn_window_size = segment_len + num_longterm_mem_tokens
 
-        # hyper
-
-        assert not (num_residual_streams <= 1 and neural_memory_qkv_receives_diff_views), 'allow neural memory queries, keys, values to be derived from different combinations of the residual streams can only work if hyper connections has greater than 1 residual stream'
+        # hyper connection
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, disable = num_residual_streams == 1)
 
```
```diff
@@ -560,17 +558,28 @@ class MemoryAsContextTransformer(Module):
             )
 
             mem = None
+            mem_qkv_layer_selector = None
             mem_hyper_conn = None
 
             if layer in neural_memory_layers:
-                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output
+                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)
+
+                if not is_first and neural_memory_qkv_receives_diff_views:
+                    num_layer_choices = (layer - 1) * 4 + 1 # for each layer, have memory input select from attn inp, attn out, ff inp, and ff out - plus one for the current point in the residual stream (memory input)
+
+                    mem_qkv_layer_selector = nn.Sequential(
+                        nn.RMSNorm(dim),
+                        nn.Linear(dim, 3 * num_layer_choices),
+                        Rearrange('... (views layers) -> views ... layers', views = 3),
+                        nn.Softmax(dim = -1)
+                    )
 
                 mem = NeuralMemory(
                     dim = dim,
                     chunk_size = self.neural_memory_segment_len,
                     batch_size = neural_memory_batch_size,
                     model = deepcopy(neural_memory_model),
-                    qkv_receives_diff_views = neural_memory_qkv_receives_diff_views,
+                    qkv_receives_diff_views = True,
                     accept_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
                     **neural_memory_kwargs
                 )
```
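For intuition, the selector built above maps each memory-input token to a softmax distribution over the candidate layers, one distribution per view (queries, keys, values). A standalone sketch of just that module with illustrative sizes (`nn.RMSNorm` needs PyTorch 2.4+):

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, layer = 384, 3
num_layer_choices = (layer - 1) * 4 + 1  # attn in/out + ff in/out per earlier layer, plus the current residual

selector = nn.Sequential(
    nn.RMSNorm(dim),
    nn.Linear(dim, 3 * num_layer_choices),
    Rearrange('... (views layers) -> views ... layers', views = 3),
    nn.Softmax(dim = -1)
)

mem_input = torch.randn(2, 16, dim)  # (batch, seq, dim)
selected = selector(mem_input)       # one weight per candidate layer, per view and token

assert selected.shape == (3, 2, 16, num_layer_choices)
```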
```diff
@@ -581,9 +590,12 @@ class MemoryAsContextTransformer(Module):
 
             self.layers.append(ModuleList([
                 mem_hyper_conn,
+                init_hyper_conn(),
+                init_hyper_conn(),
+                mem_qkv_layer_selector,
                 mem,
-                init_hyper_conn(branch = attn),
-                init_hyper_conn(branch = ff)
+                attn,
+                ff,
             ]))
 
         self.norm = nn.RMSNorm(dim)
```
```diff
@@ -763,6 +775,10 @@ class MemoryAsContextTransformer(Module):
 
         mem_weight_residual = None
 
+        # layers for the neural mem to select the qkv inputs from
+
+        mem_input_layers = []
+
         # when inferencing, only do one token at a time
 
         if is_inferencing:
```
```diff
@@ -773,7 +789,7 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for mem_hyper_conn, mem, attn, ff in self.layers:
+        for mem_hyper_conn, attn_hyper_conn, ff_hyper_conn, mem_qkv_layer_selector, mem, attn, ff in self.layers:
 
             retrieved = None
             attn_out_gates = None
```
```diff
@@ -785,8 +801,19 @@ class MemoryAsContextTransformer(Module):
 
                 mem_input, add_residual = mem_hyper_conn(x)
 
+                if not exists(mem_qkv_layer_selector):
+                    qkv_mem_input = stack((mem_input, mem_input, mem_input))
+                else:
+                    layers_to_choose_from = stack((mem_input, *mem_input_layers))
+
+                    # let the current `mem_input` select the 3 layers for qkv
+
+                    selected = mem_qkv_layer_selector(mem_input)
+
+                    qkv_mem_input = einsum(layers_to_choose_from, selected, 'l b n d, v b n l -> v b n d')
+
                 retrieved, next_neural_mem_cache = mem.forward(
-                    mem_input,
+                    qkv_mem_input,
                     state = next(neural_mem_caches, None),
                     prev_weights = mem_weight_residual
                 )
```
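Each transformer layer appends four tensors to `mem_input_layers` (attention input and output, feedforward input and output), which is where `num_layer_choices = (layer - 1) * 4 + 1` comes from: four views per earlier layer plus the memory's own `mem_input`. The einsum then builds, for each of the three views, a per-token convex combination of the candidates. A small shape demo of that mixing step, with made-up sizes:

```python
import torch
from einops import einsum

l, b, n, d = 5, 2, 16, 384  # candidate layers, batch, seq, dim

layers_to_choose_from = torch.randn(l, b, n, d)
selected = torch.randn(3, b, n, l).softmax(dim = -1)  # stand-in for the selector output

# per view v and token n: qkv_mem_input[v, b, n] = sum over l of selected[v, b, n, l] * layers[l, b, n]
qkv_mem_input = einsum(layers_to_choose_from, selected, 'l b n d, v b n l -> v b n d')

assert qkv_mem_input.shape == (3, b, n, d)
```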
```diff
@@ -801,8 +828,12 @@ class MemoryAsContextTransformer(Module):
 
             # attention
 
-            x, (values, next_kv_cache) = attn(
-                x,
+            attn_in, add_residual = attn_hyper_conn(x)
+
+            mem_input_layers.append(attn_in)
+
+            attn_out, (values, next_kv_cache) = attn(
+                attn_in,
                 value_residual = value_residual,
                 disable_flex_attn = disable_flex_attn,
                 flex_attn_fn = flex_attn_fn,
```
```diff
@@ -810,8 +841,12 @@ class MemoryAsContextTransformer(Module):
                 cache = next(kv_caches, None)
             )
 
+            mem_input_layers.append(attn_out)
+
             value_residual = default(value_residual, values)
 
+            x = add_residual(attn_out)
+
             # caches
 
             next_kv_caches.append(next_kv_cache)
```
```diff
@@ -819,7 +854,15 @@ class MemoryAsContextTransformer(Module):
 
             # feedforward
 
-            x = ff(x)
+            ff_in, add_ff_residual = ff_hyper_conn(x)
+
+            mem_input_layers.append(ff_in)
+
+            ff_out = ff(ff_in)
+
+            mem_input_layers.append(ff_out)
+
+            x = add_ff_residual(ff_out)
 
             # taking care of cache first
             # for early return when processing long term mem tokens during inference
```
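The attention and feedforward branches now use the explicit hyper-connection calling convention: pull a branch input off the residual streams, run the branch, then add the residual back, exposing the intermediate tensors for the memory's layer selector along the way. A schematic of that convention with a stand-in for the hyper connection (the real one comes from `init_hyper_conn`; `fake_hyper_conn` here is hypothetical):

```python
import torch
from torch import nn

# stand-in for a hyper connection: returns (branch input, add_residual closure)
def fake_hyper_conn(x):
    def add_residual(branch_out):
        return x + branch_out
    return x, add_residual

ff = nn.Sequential(nn.Linear(384, 768), nn.GELU(), nn.Linear(768, 384))

x = torch.randn(2, 16, 384)
mem_input_layers = []

ff_in, add_ff_residual = fake_hyper_conn(x)
mem_input_layers.append(ff_in)   # exposed as a qkv candidate for later memory layers

ff_out = ff(ff_in)
mem_input_layers.append(ff_out)  # likewise

x = add_ff_residual(ff_out)
```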
titans_pytorch/neural_memory.py

```diff
@@ -353,11 +353,11 @@ class NeuralMemory(Module):
             pred = functional_call(self.memory_model, params, inputs)
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
             weighted_loss = loss * loss_weights
-            return weighted_loss.sum()
+            return weighted_loss.sum(), loss
 
         # two functions
 
-        grad_fn = grad(forward_and_loss)
+        grad_fn = grad(forward_and_loss, has_aux = True)
 
         self.per_sample_grad_fn = vmap(grad_fn, in_dims = (0, 0, 0, 0))
 
```
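`has_aux = True` is the standard `torch.func` idiom for returning extra tensors alongside the scalar being differentiated: `grad_fn` now yields `(grads, aux)`, letting the unweighted loss escape the per-sample gradient computation without changing the gradients. A toy sketch of the pattern with a hypothetical linear memory model:

```python
import torch
from torch.func import grad, vmap

def forward_and_loss(w, x, target):
    pred = x @ w
    loss = (pred - target) ** 2
    return loss.sum(), loss  # (scalar to differentiate, auxiliary per-element loss)

grad_fn = grad(forward_and_loss, has_aux = True)
per_sample_grad_fn = vmap(grad_fn, in_dims = (0, 0, 0))

w = torch.randn(8, 4, 4)  # one weight matrix per sample
x = torch.randn(8, 4)
target = torch.randn(8, 4)

grads, unweighted_loss = per_sample_grad_fn(w, x, target)
assert grads.shape == (8, 4, 4) and unweighted_loss.shape == (8, 4)
```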
```diff
@@ -526,6 +526,7 @@ class NeuralMemory(Module):
         seq_index = 0,
         prev_weights = None,
         mask: Tensor | None = None,
+        return_surprises = True
     ):
         if self.qkv_receives_diff_views:
             _, batch, seq_len = seq.shape[:3]
```
```diff
@@ -645,10 +646,14 @@ class NeuralMemory(Module):
 
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)
+        grads, unweighted_mem_model_loss = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)
 
         grads = TensorDict(grads)
 
+        # surprises
+
+        unweighted_mem_model_loss = rearrange(unweighted_mem_model_loss, '(b h n) c -> b h (n c)', b = batch, h = heads)
+
         # maybe softclamp grad norm
 
         if exists(self.max_grad_norm):
```
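The rearrange regroups the vmapped per-sample losses into per-batch, per-head, per-token surprises; with `chunk_size = 2` and a 64-token sequence this is exactly the `(4, 4, 64)` asserted in `test_return_surprises`. A shape-only sketch:

```python
import torch
from einops import rearrange

b, h, n, c = 4, 4, 32, 2  # batch, heads, num chunks, chunk size
per_sample_loss = torch.randn(b * h * n, c)  # one row per vmapped (batch, head, chunk) sample

surprises = rearrange(per_sample_loss, '(b h n) c -> b h (n c)', b = b, h = h)
assert surprises.shape == (4, 4, 64)
```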
```diff
@@ -687,7 +692,10 @@ class NeuralMemory(Module):
 
             output = (updates, next_store_state)
 
-            return output
+            if not return_surprises:
+                return output
+
+            return (*output, unweighted_mem_model_loss)
 
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
 
```
```diff
@@ -744,7 +752,10 @@ class NeuralMemory(Module):
 
         # return updates to neural memory at all chunked timesteps + neural mem cache / state to be fed back
 
-        return updates, next_store_state
+        if not return_surprises:
+            return updates, next_store_state
+
+        return updates, next_store_state, unweighted_mem_model_loss
 
     def retrieve_memories(
         self,
```
```diff
@@ -843,7 +854,8 @@ class NeuralMemory(Module):
         store_seq = None,
         state: NeuralMemState | None = None,
         prev_weights = None,
-        store_mask: Tensor | None = None
+        store_mask: Tensor | None = None,
+        return_surprises = False
     ):
         is_multi_input = self.qkv_receives_diff_views
 
```
```diff
@@ -927,6 +939,7 @@ class NeuralMemory(Module):
 
         # whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch
 
+        surprises = None
         gate = None
 
         if exists(self.transition_gate):
```
```diff
@@ -937,13 +950,14 @@ class NeuralMemory(Module):
 
             # store
 
-            next_updates, next_neural_mem_state = self.store_memories(
+            next_updates, next_neural_mem_state, chunk_surprises = self.store_memories(
                 store_seq_chunk,
                 weights,
                 seq_index = seq_index,
                 past_state = past_state,
                 prev_weights = prev_weights,
-                mask = maybe_store_mask
+                mask = maybe_store_mask,
+                return_surprises = True
             )
 
             weights = next_neural_mem_state.weights
```
```diff
@@ -952,6 +966,8 @@ class NeuralMemory(Module):
 
             updates = accum_updates(updates, next_updates)
 
+            surprises = safe_cat((surprises, chunk_surprises), dim = -1)
+
             if is_last and not update_after_final_store:
                 continue
 
```
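`safe_cat` lets the first chunk seed the accumulator while `surprises` is still `None`. The helper itself is not shown in this diff; a plausible minimal implementation matching how it is called here (hypothetical, not necessarily the package's exact code):

```python
import torch

def safe_cat(tensors, dim = -1):
    # concatenate along dim, skipping Nones, so a None accumulator is simply replaced
    tensors = [t for t in tensors if t is not None]

    if len(tensors) == 0:
        return None

    if len(tensors) == 1:
        return tensors[0]

    return torch.cat(tensors, dim = dim)
```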
```diff
@@ -986,4 +1002,9 @@ class NeuralMemory(Module):
             updates
         )
 
-        return retrieved, next_neural_mem_state
+        # returning
+
+        if not return_surprises:
+            return retrieved, next_neural_mem_state
+
+        return retrieved, next_neural_mem_state, surprises
```
train_mac.py

```diff
@@ -48,6 +48,7 @@ SLIDING_WINDOWS = True
 STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
 MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
 NEURAL_MEM_WEIGHT_RESIDUAL = True # learning to accept contributions from the weights of the previous neural mem layer brings about significant improvements. this was improvised and not in the paper, but inspired by the value residual learning free lunch paper
+NEURAL_MEM_QKV_RECEIVES_DIFF_VIEW = True # will allow the neural memory to select what layers from which to derive queries / keys / values, effectively allowing it to graft itself to the transformer in any way to be beneficial. this is to address an issue from a phd student who noted that the mem network is learning nothing more than wk @ wv. this also generalizes all possible ways to connect the neural memory to a transformer, a sort of NAS
 
 # experiment related
 
```
```diff
@@ -107,6 +108,7 @@ model = MemoryAsContextTransformer(
     neural_memory_batch_size = NEURAL_MEM_BATCH_SIZE,
     neural_mem_gate_attn_output = NEURAL_MEM_GATE_ATTN_OUTPUT,
     neural_mem_weight_residual = NEURAL_MEM_WEIGHT_RESIDUAL,
+    neural_memory_qkv_receives_diff_views = NEURAL_MEM_QKV_RECEIVES_DIFF_VIEW,
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
     neural_memory_model = neural_memory_model,
```