titans-pytorch 0.1.32__py3-none-any.whl → 0.1.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +26 -4
- titans_pytorch/neural_memory.py +45 -11
- {titans_pytorch-0.1.32.dist-info → titans_pytorch-0.1.34.dist-info}/METADATA +1 -1
- titans_pytorch-0.1.34.dist-info/RECORD +8 -0
- titans_pytorch-0.1.32.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.32.dist-info → titans_pytorch-0.1.34.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.32.dist-info → titans_pytorch-0.1.34.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py
CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
 from typing import Callable

 from math import ceil
+from copy import deepcopy
 from functools import partial
 from collections import namedtuple

@@ -485,11 +486,13 @@ class MemoryAsContextTransformer(Module):
     heads = 8,
     ff_mult = 4,
     num_residual_streams = 4,
+    neural_memory_model: Module | None = None,
     neural_memory_kwargs: dict = dict(),
     neural_memory_layers: tuple[int, ...] | None = None,
     aux_kv_recon_loss_weight = 0.,
     use_flex_attn = False,
-    sliding_window_attn = False
+    sliding_window_attn = False,
+    weight_tie_memory_model = False
 ):
     super().__init__()

@@ -523,6 +526,15 @@ class MemoryAsContextTransformer(Module):

 neural_memory_layers = default(neural_memory_layers, layers)

+# weight tying neural memory model
+
+maybe_copy = deepcopy if not weight_tie_memory_model else identity
+
+if weight_tie_memory_model:
+    assert exists(neural_memory_model), '`neural_memory_model` must be explicitly set'
+
+self.weight_tie_memory_model = weight_tie_memory_model
+
 # mem, attn, and feedforward layers

 for layer in layers:
@@ -551,6 +563,7 @@ class MemoryAsContextTransformer(Module):
 mem = NeuralMemory(
     dim = dim,
     chunk_size = self.neural_memory_segment_len,
+    model = maybe_copy(neural_memory_model),
     **neural_memory_kwargs
 )

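The two hunks above cover the construction-time side of weight tying: `maybe_copy` is `deepcopy` when each layer should own an independent memory model and `identity` when one instance is shared across layers. Below is a minimal, library-free sketch of that behaviour; the module, feature size, and layer count are placeholders, not the package's defaults.

```python
# Sketch only: shows what `maybe_copy = deepcopy if not weight_tie_memory_model else identity`
# achieves when building one memory model per layer.
from copy import deepcopy
import torch.nn as nn

def identity(x):
    return x

weight_tie_memory_model = True

# placeholder memory model, not the package's default architecture
neural_memory_model = nn.Sequential(nn.Linear(64, 64), nn.SiLU(), nn.Linear(64, 64))

maybe_copy = deepcopy if not weight_tie_memory_model else identity

# one memory model per memory layer (4 is an arbitrary example depth)
mem_models = [maybe_copy(neural_memory_model) for _ in range(4)]

# tied: every layer holds the same instance; untied: each layer holds its own deep copy
assert all(m is neural_memory_model for m in mem_models) == weight_tie_memory_model
```

Note that with `weight_tie_memory_model = True` the constructor now requires `neural_memory_model` to be passed explicitly, per the assert added above, since there is no per-layer default left to copy.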
@@ -683,7 +696,7 @@ class MemoryAsContextTransformer(Module):

 # math

-batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size
+batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size, weight_tie_memory_model = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size, self.weight_tie_memory_model

 seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)

@@ -736,6 +749,10 @@ class MemoryAsContextTransformer(Module):
 next_kv_caches = []
 next_neural_mem_caches = []

+# weight tied neural memory
+
+neural_memory_updates = None
+
 # value residual

 value_residual = None
@@ -769,7 +786,8 @@ class MemoryAsContextTransformer(Module):
 if not is_inferencing:
     (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
         mem_input,
-        return_aux_kv_loss = True
+        return_aux_kv_loss = True,
+        prev_layer_updates = neural_memory_updates
     )

     kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
@@ -777,9 +795,13 @@ class MemoryAsContextTransformer(Module):
 else:
     retrieved, next_neural_mem_cache = mem.forward_inference(
         mem_input,
-        state = next(neural_mem_caches, None)
+        state = next(neural_mem_caches, None),
+        prev_layer_updates = neural_memory_updates
     )

+if weight_tie_memory_model:
+    neural_memory_updates = next_neural_mem_cache.updates
+
 if self.gate_attn_output:
     attn_out_gates = retrieved.sigmoid()
 else:
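In the forward pass, the diff threads a single `neural_memory_updates` value through the layer loop: each memory layer receives the previous layer's accumulated updates as `prev_layer_updates`, and when the memory model is weight tied it publishes its own `next_neural_mem_cache.updates` for the next layer. A toy, self-contained sketch of that threading follows; the layer function is a stand-in, not `NeuralMemory`.

```python
# Sketch only: the update-threading pattern from the hunks above, with a toy memory layer.
from collections import namedtuple
import torch

NeuralMemCache = namedtuple('NeuralMemCache', ['seq', 'cache_store_segment', 'states', 'updates'])

def toy_memory_layer(x, prev_layer_updates = None):
    # stand-in for NeuralMemory: "updates" here are just a running per-token statistic
    updates = x.mean(dim = -1, keepdim = True)
    if prev_layer_updates is not None:
        updates = updates + prev_layer_updates
    retrieved = x  # stand-in for retrieved memories
    return retrieved, NeuralMemCache(x.shape[1], None, None, updates)

weight_tie_memory_model = True
x = torch.randn(2, 16, 8)

neural_memory_updates = None  # threaded across layers, as in the diff

for _ in range(4):  # four toy "layers"
    retrieved, cache = toy_memory_layer(x, prev_layer_updates = neural_memory_updates)

    if weight_tie_memory_model:
        # the next layer's memory starts from this layer's accumulated updates
        neural_memory_updates = cache.updates
```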
titans_pytorch/neural_memory.py
CHANGED
@@ -3,6 +3,7 @@ from typing import Callable

 import math
 from functools import partial
+from collections import namedtuple

 import torch
 from torch import nn, cat, Tensor
@@ -33,6 +34,8 @@ w - num memory network weight parameters

 LinearNoBias = partial(Linear, bias = False)

+NeuralMemCache = namedtuple('NeuralMemCache', ['seq', 'cache_store_segment', 'states', 'updates'])
+
 # functions

 def exists(v):
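Replacing the bare state tuple with a namedtuple lets callers address cache fields by name, which is what the MAC transformer now relies on via `next_neural_mem_cache.updates`. A quick illustration:

```python
# Sketch only: named access to the memory cache fields.
from collections import namedtuple

NeuralMemCache = namedtuple('NeuralMemCache', ['seq', 'cache_store_segment', 'states', 'updates'])

cache = NeuralMemCache(seq = 128, cache_store_segment = None, states = None, updates = {'w1': None})

# name and position refer to the same field, so existing tuple-unpacking code keeps working
assert cache.updates is cache[3]
```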
@@ -605,7 +608,7 @@ class NeuralMemory(Module):
 # improvise (or perhaps correcting to) a solution

 if exists(prev_layer_updates):
-    prev_layer_updates = TensorDict(
+    prev_layer_updates = TensorDict(prev_layer_updates)

     weights = weights + prev_layer_updates

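When `prev_layer_updates` is supplied, the previous layer's per-parameter deltas are added onto the current memory weights before surprise gradients are computed. In the library both operands are `TensorDict`s; the plain-dict sketch below shows the equivalent key-wise addition, with illustrative shapes.

```python
# Sketch only: key-wise addition of previous-layer updates onto the current memory weights.
import torch

weights = {'w1': torch.zeros(4, 4), 'w2': torch.zeros(4)}
prev_layer_updates = {'w1': torch.ones(4, 4) * 0.1, 'w2': torch.ones(4) * 0.1}

# equivalent of `weights + prev_layer_updates` on TensorDicts: add parameter by parameter
weights = {name: weights[name] + prev_layer_updates[name] for name in weights}
```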
@@ -657,6 +660,11 @@ class NeuralMemory(Module):

 adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)

+# flatten batch and time if surprise depends on previous layer memory model
+
+if exists(prev_layer_updates):
+    weights = weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+
 # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

 grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
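The flattening step merges the batch and chunk dimensions so the per-sample gradient function sees one entry per (batch, chunk) pair once the weights carry a time axis from the previous layer's updates. A shape-only sketch with hypothetical sizes:

```python
# Sketch only: effect of rearrange(t, 'b n ... -> (b n) ...') on a per-chunk weight tensor.
import torch
from einops import rearrange

w = torch.randn(2, 3, 16, 16)            # (batch, chunks, *param_shape) -- hypothetical sizes
flat = rearrange(w, 'b n ... -> (b n) ...')
print(flat.shape)                          # torch.Size([6, 16, 16])
```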
@@ -737,7 +745,8 @@ class NeuralMemory(Module):
     self,
     seq,
     past_weights: dict[str, Tensor],
-    chunk_size = None
+    chunk_size = None,
+    prev_layer_updates: dict[str, Tensor] | None = None
 ):
     chunk_size = default(chunk_size, self.retrieve_chunk_size)
     batch, seq_len = seq.shape[:2]
@@ -760,6 +769,9 @@ class NeuralMemory(Module):

 curr_weights = TensorDict(past_weights)

+if exists(prev_layer_updates):
+    curr_weights = curr_weights + TensorDict(prev_layer_updates)
+
 # sequence Float['b n d'] to queries

 queries = self.to_queries(seq)
@@ -810,6 +822,7 @@ class NeuralMemory(Module):
     self,
     token: Tensor,
     state = None,
+    prev_layer_updates: dict[str, Tensor] | None = None
 ):

     # unpack previous state
@@ -838,7 +851,7 @@ class NeuralMemory(Module):
 if curr_seq_len < self.chunk_size:
     empty_mem = self.init_empty_memory_embed(batch, 1)

-    return empty_mem, (curr_seq_len, cache_store_seq, past_states, updates)
+    return empty_mem, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)

 # store if storage sequence cache hits the chunk size

@@ -848,13 +861,20 @@ class NeuralMemory(Module):
 if not exists(updates):
     updates = weights.clone().zero_()
     updates = updates.apply(lambda t: repeat(t, '... -> b 1 ...', b = batch))
+else:
+    updates = updates.apply(lambda t: t[:, -1:])
+
+if exists(prev_layer_updates):
+    prev_layer_updates = TensorDict(prev_layer_updates)
+    prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])

 if store_seq_cache_len == self.chunk_size:

     next_updates, next_states, _ = self.store_memories(
         cache_store_seq,
         weights,
-        past_state = past_states
+        past_state = past_states,
+        prev_layer_updates = prev_layer_updates,
     )

     updates = next_updates
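At inference time, previously accumulated updates (and any incoming `prev_layer_updates`) are sliced down to their most recent chunk before the next store step. A shape-only sketch with hypothetical sizes:

```python
# Sketch only: effect of .apply(lambda t: t[:, -1:]) on accumulated updates.
import torch

updates = torch.randn(2, 5, 16, 16)   # (batch, chunks stored so far, *param_shape) -- hypothetical
latest = updates[:, -1:]               # keep only the most recent chunk's update
print(latest.shape)                    # torch.Size([2, 1, 16, 16])
```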
@@ -866,7 +886,7 @@ class NeuralMemory(Module):

 # next state tuple

-next_state = (curr_seq_len, cache_store_seq, next_states, updates)
+next_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)

 return retrieved, next_state

@@ -880,7 +900,8 @@ class NeuralMemory(Module):
     chunk_size = None,
     store_chunk_size = None,
     return_values = False,
-    return_next_state = False
+    return_next_state = False,
+    prev_layer_updates: dict[str, Tensor] | None = None
 ):
     batch, seq_len = seq.shape[:2]

@@ -899,15 +920,30 @@ class NeuralMemory(Module):
 if not exists(mem_model_weights):
     mem_model_weights = self.init_weights()

+# store
+
 store_seq = default(store_seq, seq)

 store_seq_len = store_seq.shape[-2]
 store_chunk_size = default(store_chunk_size, chunk_size, self.store_chunk_size)
 remainder = store_seq_len % store_chunk_size

-(updates, next_state, values), aux_kv_recon_loss = self.store_memories(
+(updates, next_state, values), aux_kv_recon_loss = self.store_memories(
+    store_seq,
+    mem_model_weights,
+    chunk_size = store_chunk_size,
+    prev_layer_updates = prev_layer_updates,
+    return_aux_kv_loss = True
+)
+
+# retrieve

-retrieved = self.retrieve_memories(
+retrieved = self.retrieve_memories(
+    seq,
+    mem_model_weights + updates,
+    chunk_size = chunk_size,
+    prev_layer_updates = prev_layer_updates
+)

 # determine state for the storing of memories
 # for transformer-xl like training with neural memory as well as inferencing with initial prompt
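The reorganized `forward` stores first and then retrieves against the base weights plus the updates just written, passing `prev_layer_updates` into both steps. A toy, library-free sketch of that ordering follows; the two helpers are stand-ins, not the package API.

```python
# Sketch only: store-then-retrieve ordering, with previous-layer updates folded into both steps.
import torch

def toy_store(seq, weights, prev_layer_updates = None):
    # fold the previous layer's deltas into the weights before computing "surprise" updates
    base = weights if prev_layer_updates is None else {
        k: weights[k] + prev_layer_updates[k] for k in weights
    }
    # stand-in update: a tiny constant delta per parameter
    return {k: torch.full_like(v, 0.01) for k, v in base.items()}

def toy_retrieve(seq, weights):
    return seq  # stand-in for retrieval from the memory model

weights = {'w1': torch.zeros(8, 8)}
prev = {'w1': torch.ones(8, 8) * 0.5}
seq = torch.randn(2, 16, 8)

updates = toy_store(seq, weights, prev_layer_updates = prev)

# retrieve against base weights plus the freshly stored updates (mem_model_weights + updates)
retrieved = toy_retrieve(seq, {k: weights[k] + updates[k] for k in weights})
```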
@@ -917,9 +953,7 @@ class NeuralMemory(Module):
 if remainder > 0:
     cache_store_seq = store_seq[:, -remainder:]

-
-
-next_store_state = (seq_len, cache_store_seq, next_state, updates)
+next_store_state = NeuralMemCache(seq_len, cache_store_seq, next_state, updates)

 output = (retrieved, next_store_state)

titans_pytorch-0.1.34.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=rMT99CPQFH4Gudp0FmVPWGKfhBf6xksGEaFEcOVdqjs,230
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=JvA4mhQaW9LD4j6boRUfLfjyzDCtjqybIr4Ajeio8n8,25708
+titans_pytorch/neural_memory.py,sha256=nNAxhkubuHCGs3bty_eA_yBhWqepPZJgKKvkWXO6IK4,28653
+titans_pytorch-0.1.34.dist-info/METADATA,sha256=pVgjCX_YTT9_5WPcFfXpoaBvzrg1-esvwS0kPpeJAYU,6826
+titans_pytorch-0.1.34.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.34.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.34.dist-info/RECORD,,
titans_pytorch-0.1.32.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=rMT99CPQFH4Gudp0FmVPWGKfhBf6xksGEaFEcOVdqjs,230
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=Cui-hCl6X4UVGmuyoKCSKWbag9Yrc-a2MrfVkHM-z0A,24828
-titans_pytorch/neural_memory.py,sha256=Vfo1z1VztPDDXgFxjkiyOP29daDE7KTdnZeWXifvCJI,27456
-titans_pytorch-0.1.32.dist-info/METADATA,sha256=_HPPht8nhLwH9GzLyZI-fh8JBSEoSxkENCSU2xuU_6A,6826
-titans_pytorch-0.1.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.32.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.32.dist-info/RECORD,,
{titans_pytorch-0.1.32.dist-info → titans_pytorch-0.1.34.dist-info}/WHEEL
File without changes

{titans_pytorch-0.1.32.dist-info → titans_pytorch-0.1.34.dist-info}/licenses/LICENSE
File without changes