titans-pytorch 0.3.25__tar.gz → 0.4.0__tar.gz

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.25
+Version: 0.4.0
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.3.25"
+version = "0.4.0"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -46,7 +46,7 @@ def create_mac_block_mask(seq_len, window_size, persist_mem_len, sliding = False
 
 # einstein notation related
 
-from einops import repeat, rearrange, pack, unpack
+from einops import repeat, rearrange, pack, unpack, einsum
 from einops.layers.torch import Rearrange
 
 # b - batch
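
Note on the import change above: einops' `einsum` takes the operand tensors first and the pattern string last (the reverse of `torch.einsum`'s argument order), with space-separated axis names. A minimal sketch of the call style, not code from the package:

    import torch
    from einops import einsum

    a = torch.randn(2, 3, 4)                 # (batch, seq, dim)
    w = torch.randn(2, 3)                    # (batch, seq)
    out = einsum(a, w, 'b n d, b n -> b d')  # sum over seq, weighted by w
    assert out.shape == (2, 4)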
@@ -521,9 +521,7 @@ class MemoryAsContextTransformer(Module):
         self.sliding_window_attn = sliding_window_attn
         self.attn_window_size = segment_len + num_longterm_mem_tokens
 
-        # hyper conection
-
-        assert not (num_residual_streams <= 1 and neural_memory_qkv_receives_diff_views), 'allow neural memory queries, keys, values to be derived from different combinations of the residual streams can only work if hyper connections has greater than 1 residual stream'
+        # hyper connection
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, disable = num_residual_streams == 1)
 
@@ -560,17 +558,28 @@ class MemoryAsContextTransformer(Module):
             )
 
             mem = None
+            mem_qkv_layer_selector = None
             mem_hyper_conn = None
 
             if layer in neural_memory_layers:
-                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output, num_input_views = 3 if neural_memory_qkv_receives_diff_views else 1)
+                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)
+
+                if not is_first and neural_memory_qkv_receives_diff_views:
+                    num_layer_choices = (layer - 1) * 4 + 1 # for each layer, have memory input select from attn inp, attn out, ff inp, and ff out - plus one for the current point in the residual stream (memory input)
+
+                    mem_qkv_layer_selector = nn.Sequential(
+                        nn.RMSNorm(dim),
+                        nn.Linear(dim, 3 * num_layer_choices),
+                        Rearrange('... (views layers) -> views ... layers', views = 3),
+                        nn.Softmax(dim = -1)
+                    )
 
                 mem = NeuralMemory(
                     dim = dim,
                     chunk_size = self.neural_memory_segment_len,
                     batch_size = neural_memory_batch_size,
                     model = deepcopy(neural_memory_model),
-                    qkv_receives_diff_views = neural_memory_qkv_receives_diff_views,
+                    qkv_receives_diff_views = True,
                     accept_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
                     **neural_memory_kwargs
                 )
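
The `mem_qkv_layer_selector` added in this hunk is a small gating head: an RMSNorm, a linear layer producing `3 * num_layer_choices` logits, a rearrange splitting out the three views (one each for queries, keys and values), and a softmax over the candidate layers. A standalone sketch with toy dimensions (the dim, layer index and batch shape are assumptions for illustration; nn.RMSNorm requires torch >= 2.4):

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    dim, layer = 16, 3                       # toy values
    num_layer_choices = (layer - 1) * 4 + 1  # attn in/out + ff in/out per earlier layer, plus the current residual

    selector = nn.Sequential(
        nn.RMSNorm(dim),
        nn.Linear(dim, 3 * num_layer_choices),
        Rearrange('... (views layers) -> views ... layers', views = 3),
        nn.Softmax(dim = -1)
    )

    mem_input = torch.randn(2, 10, dim)      # (batch, seq, dim)
    selected = selector(mem_input)
    assert selected.shape == (3, 2, 10, num_layer_choices)  # one distribution over layers per view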
@@ -581,9 +590,12 @@ class MemoryAsContextTransformer(Module):
 
             self.layers.append(ModuleList([
                 mem_hyper_conn,
+                init_hyper_conn(),
+                init_hyper_conn(),
+                mem_qkv_layer_selector,
                 mem,
-                init_hyper_conn(branch = attn),
-                init_hyper_conn(branch = ff)
+                attn,
+                ff,
             ]))
 
         self.norm = nn.RMSNorm(dim)
@@ -763,6 +775,10 @@ class MemoryAsContextTransformer(Module):
 
         mem_weight_residual = None
 
+        # layers for the neural mem to select the qkv inputs from
+
+        mem_input_layers = []
+
         # when inferencing, only do one token at a time
 
         if is_inferencing:
@@ -773,7 +789,7 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for mem_hyper_conn, mem, attn, ff in self.layers:
+        for mem_hyper_conn, attn_hyper_conn, ff_hyper_conn, mem_qkv_layer_selector, mem, attn, ff in self.layers:
 
             retrieved = None
             attn_out_gates = None
@@ -785,8 +801,19 @@ class MemoryAsContextTransformer(Module):
 
                 mem_input, add_residual = mem_hyper_conn(x)
 
+                if not exists(mem_qkv_layer_selector):
+                    qkv_mem_input = stack((mem_input, mem_input, mem_input))
+                else:
+                    layers_to_choose_from = stack((mem_input, *mem_input_layers))
+
+                    # let the current `mem_input` select the 3 layers for qkv
+
+                    selected = mem_qkv_layer_selector(mem_input)
+
+                    qkv_mem_input = einsum(layers_to_choose_from, selected, 'l b n d, v b n l -> v b n d')
+
                 retrieved, next_neural_mem_cache = mem.forward(
-                    mem_input,
+                    qkv_mem_input,
                     state = next(neural_mem_caches, None),
                     prev_weights = mem_weight_residual
                 )
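
The einsum in this hunk is where the selection happens: for each view (query, key, value), the selector's softmax weights form a convex combination over the `l` stacked candidate layers. A hedged sketch with toy shapes, checking the einsum against the explicit weighted sum:

    import torch
    from einops import einsum

    l, b, n, d = 5, 2, 10, 16                             # candidates, batch, seq, dim (toy values)
    layers_to_choose_from = torch.randn(l, b, n, d)
    selected = torch.randn(3, b, n, l).softmax(dim = -1)  # per-view weights over the candidates

    qkv_mem_input = einsum(layers_to_choose_from, selected, 'l b n d, v b n l -> v b n d')
    assert qkv_mem_input.shape == (3, b, n, d)

    # the same mixing written out: weight each candidate layer and sum over candidates, per view
    manual = (selected.permute(0, 3, 1, 2).unsqueeze(-1) * layers_to_choose_from.unsqueeze(0)).sum(dim = 1)
    assert torch.allclose(qkv_mem_input, manual, atol = 1e-5)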
@@ -801,8 +828,12 @@ class MemoryAsContextTransformer(Module):
 
             # attention
 
-            x, (values, next_kv_cache) = attn(
-                x,
+            attn_in, add_residual = attn_hyper_conn(x)
+
+            mem_input_layers.append(attn_in)
+
+            attn_out, (values, next_kv_cache) = attn(
+                attn_in,
                 value_residual = value_residual,
                 disable_flex_attn = disable_flex_attn,
                 flex_attn_fn = flex_attn_fn,
@@ -810,8 +841,12 @@ class MemoryAsContextTransformer(Module):
                 cache = next(kv_caches, None)
             )
 
+            mem_input_layers.append(attn_out)
+
             value_residual = default(value_residual, values)
 
+            x = add_residual(attn_out)
+
             # caches
 
             next_kv_caches.append(next_kv_cache)
@@ -819,7 +854,15 @@ class MemoryAsContextTransformer(Module):
 
             # feedforward
 
-            x = ff(x)
+            ff_in, add_ff_residual = ff_hyper_conn(x)
+
+            mem_input_layers.append(ff_in)
+
+            ff_out = ff(ff_in)
+
+            mem_input_layers.append(ff_out)
+
+            x = add_ff_residual(ff_out)
 
             # taking care of cache first
             # for early return when processing long term mem tokens during inference
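
The bookkeeping in the last three hunks lines up with the constructor: each layer's forward pass appends four tensors to `mem_input_layers` (`attn_in`, `attn_out`, `ff_in`, `ff_out`), so when layer `k` builds its memory input it stacks `(k - 1) * 4` earlier views plus the current `mem_input`, which is exactly the `num_layer_choices = (layer - 1) * 4 + 1` the selector was sized for. A tiny check of that invariant (illustration only, not repo code):

    # simulate the pool the forward loop maintains
    mem_input_layers = []
    for layer in range(1, 6):  # 1-indexed layer
        # at this point, layer `layer` would stack mem_input plus all collected views
        assert len(mem_input_layers) + 1 == (layer - 1) * 4 + 1
        mem_input_layers += ['attn_in', 'attn_out', 'ff_in', 'ff_out']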
@@ -48,6 +48,7 @@ SLIDING_WINDOWS = True
 STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
 MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
 NEURAL_MEM_WEIGHT_RESIDUAL = True # learning to accept contributions from the weights of the previous neural mem layer brings about significant improvements. this was improvised and not in the paper, but inspired by the value residual learning free lunch paper
+NEURAL_MEM_QKV_RECEIVES_DIFF_VIEW = True # will allow the neural memory to select what layers from which to derive queries / keys / values, effectively allowing it to graft itself to the transformer in any way to be beneficial. this is to address an issue from a phd student who noted that the mem network is learning nothing more than wk @ wv. this also generalizes all possible ways to connect the neural memory to a transformer, a sort of NAS
 
 # experiment related
 
@@ -107,6 +108,7 @@ model = MemoryAsContextTransformer(
     neural_memory_batch_size = NEURAL_MEM_BATCH_SIZE,
     neural_mem_gate_attn_output = NEURAL_MEM_GATE_ATTN_OUTPUT,
     neural_mem_weight_residual = NEURAL_MEM_WEIGHT_RESIDUAL,
+    neural_memory_qkv_receives_diff_views = NEURAL_MEM_QKV_RECEIVES_DIFF_VIEW,
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
     neural_memory_model = neural_memory_model,
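
For reference, turning the new behavior on from user code would look like the sketch below. The `neural_memory_qkv_receives_diff_views` flag, `segment_len`, `num_longterm_mem_tokens` and `num_residual_streams` are confirmed by this diff; the remaining constructor arguments are assumptions based on the project README and may differ:

    from titans_pytorch import MemoryAsContextTransformer

    model = MemoryAsContextTransformer(
        num_tokens = 256,              # assumed from the README
        dim = 256,                     # assumed
        depth = 4,                     # assumed
        segment_len = 32,              # kwarg confirmed in this diff
        num_longterm_mem_tokens = 16,
        num_residual_streams = 4,      # the selector pairs naturally with > 1 residual stream
        neural_memory_qkv_receives_diff_views = True,  # the flag wired up in this diff
    )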