titans-pytorch 0.3.24__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -46,7 +46,7 @@ def create_mac_block_mask(seq_len, window_size, persist_mem_len, sliding = False
46
46
 
47
47
  # einstein notation related
48
48
 
49
- from einops import repeat, rearrange, pack, unpack
49
+ from einops import repeat, rearrange, pack, unpack, einsum
50
50
  from einops.layers.torch import Rearrange
51
51
 
52
52
  # b - batch
@@ -521,9 +521,7 @@ class MemoryAsContextTransformer(Module):
521
521
  self.sliding_window_attn = sliding_window_attn
522
522
  self.attn_window_size = segment_len + num_longterm_mem_tokens
523
523
 
524
- # hyper conection
525
-
526
- assert not (num_residual_streams <= 1 and neural_memory_qkv_receives_diff_views), 'allow neural memory queries, keys, values to be derived from different combinations of the residual streams can only work if hyper connections has greater than 1 residual stream'
524
+ # hyper connection
527
525
 
528
526
  init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, disable = num_residual_streams == 1)
529
527
 
@@ -560,17 +558,28 @@ class MemoryAsContextTransformer(Module):
560
558
  )
561
559
 
562
560
  mem = None
561
+ mem_qkv_layer_selector = None
563
562
  mem_hyper_conn = None
564
563
 
565
564
  if layer in neural_memory_layers:
566
- mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output, num_input_views = 3 if neural_memory_qkv_receives_diff_views else 1)
565
+ mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)
566
+
567
+ if not is_first and neural_memory_qkv_receives_diff_views:
568
+ num_layer_choices = (layer - 1) * 4 + 1 # for each layer, have memory input select from attn inp, attn out, ff inp, and ff out - plus one for the current point in the residual stream (memory input)
569
+
570
+ mem_qkv_layer_selector = nn.Sequential(
571
+ nn.RMSNorm(dim),
572
+ nn.Linear(dim, 3 * num_layer_choices),
573
+ Rearrange('... (views layers) -> views ... layers', views = 3),
574
+ nn.Softmax(dim = -1)
575
+ )
567
576
 
568
577
  mem = NeuralMemory(
569
578
  dim = dim,
570
579
  chunk_size = self.neural_memory_segment_len,
571
580
  batch_size = neural_memory_batch_size,
572
581
  model = deepcopy(neural_memory_model),
573
- qkv_receives_diff_views = neural_memory_qkv_receives_diff_views,
582
+ qkv_receives_diff_views = True,
574
583
  accept_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
575
584
  **neural_memory_kwargs
576
585
  )
@@ -581,9 +590,12 @@ class MemoryAsContextTransformer(Module):
581
590
 
582
591
  self.layers.append(ModuleList([
583
592
  mem_hyper_conn,
593
+ init_hyper_conn(),
594
+ init_hyper_conn(),
595
+ mem_qkv_layer_selector,
584
596
  mem,
585
- init_hyper_conn(branch = attn),
586
- init_hyper_conn(branch = ff)
597
+ attn,
598
+ ff,
587
599
  ]))
588
600
 
589
601
  self.norm = nn.RMSNorm(dim)
@@ -763,6 +775,10 @@ class MemoryAsContextTransformer(Module):
763
775
 
764
776
  mem_weight_residual = None
765
777
 
778
+ # layers for the neural mem to select the qkv inputs from
779
+
780
+ mem_input_layers = []
781
+
766
782
  # when inferencing, only do one token at a time
767
783
 
768
784
  if is_inferencing:
@@ -773,7 +789,7 @@ class MemoryAsContextTransformer(Module):
773
789
 
774
790
  x = self.expand_streams(x)
775
791
 
776
- for mem_hyper_conn, mem, attn, ff in self.layers:
792
+ for mem_hyper_conn, attn_hyper_conn, ff_hyper_conn, mem_qkv_layer_selector, mem, attn, ff in self.layers:
777
793
 
778
794
  retrieved = None
779
795
  attn_out_gates = None
@@ -785,8 +801,19 @@ class MemoryAsContextTransformer(Module):
785
801
 
786
802
  mem_input, add_residual = mem_hyper_conn(x)
787
803
 
804
+ if not exists(mem_qkv_layer_selector):
805
+ qkv_mem_input = stack((mem_input, mem_input, mem_input))
806
+ else:
807
+ layers_to_choose_from = stack((mem_input, *mem_input_layers))
808
+
809
+ # let the current `mem_input` select the 3 layers for qkv
810
+
811
+ selected = mem_qkv_layer_selector(mem_input)
812
+
813
+ qkv_mem_input = einsum(layers_to_choose_from, selected, 'l b n d, v b n l -> v b n d')
814
+
788
815
  retrieved, next_neural_mem_cache = mem.forward(
789
- mem_input,
816
+ qkv_mem_input,
790
817
  state = next(neural_mem_caches, None),
791
818
  prev_weights = mem_weight_residual
792
819
  )
@@ -801,8 +828,12 @@ class MemoryAsContextTransformer(Module):
801
828
 
802
829
  # attention
803
830
 
804
- x, (values, next_kv_cache) = attn(
805
- x,
831
+ attn_in, add_residual = attn_hyper_conn(x)
832
+
833
+ mem_input_layers.append(attn_in)
834
+
835
+ attn_out, (values, next_kv_cache) = attn(
836
+ attn_in,
806
837
  value_residual = value_residual,
807
838
  disable_flex_attn = disable_flex_attn,
808
839
  flex_attn_fn = flex_attn_fn,
@@ -810,8 +841,12 @@ class MemoryAsContextTransformer(Module):
810
841
  cache = next(kv_caches, None)
811
842
  )
812
843
 
844
+ mem_input_layers.append(attn_out)
845
+
813
846
  value_residual = default(value_residual, values)
814
847
 
848
+ x = add_residual(attn_out)
849
+
815
850
  # caches
816
851
 
817
852
  next_kv_caches.append(next_kv_cache)
@@ -819,7 +854,15 @@ class MemoryAsContextTransformer(Module):
819
854
 
820
855
  # feedforward
821
856
 
822
- x = ff(x)
857
+ ff_in, add_ff_residual = ff_hyper_conn(x)
858
+
859
+ mem_input_layers.append(ff_in)
860
+
861
+ ff_out = ff(ff_in)
862
+
863
+ mem_input_layers.append(ff_out)
864
+
865
+ x = add_ff_residual(ff_out)
823
866
 
824
867
  # taking care of cache first
825
868
  # for early return when processing long term mem tokens during inference
@@ -524,7 +524,8 @@ class NeuralMemory(Module):
524
524
  weights: dict[str, Tensor] | None = None,
525
525
  past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
526
526
  seq_index = 0,
527
- prev_weights = None
527
+ prev_weights = None,
528
+ mask: Tensor | None = None,
528
529
  ):
529
530
  if self.qkv_receives_diff_views:
530
531
  _, batch, seq_len = seq.shape[:3]
@@ -612,6 +613,14 @@ class NeuralMemory(Module):
612
613
 
613
614
  adaptive_lr = rearrange(adaptive_lr, 'b (n c u) -> (b n) (c u)', c = chunk_size, u = num_updates)
614
615
 
616
+ # optionally a storing memories mask can be passed in. if False, will set the learning rate to 0. for those positions
617
+
618
+ if exists(mask):
619
+ mask = mask[..., :round_down_seq_len]
620
+ mask = repeat(mask, 'b (n c) -> (b h n) (c u)', h = heads, u = num_updates, c = chunk_size)
621
+
622
+ adaptive_lr = torch.where(mask, adaptive_lr, 0.)
623
+
615
624
  # maybe add previous layer weight
616
625
 
617
626
  assert xnor(exists(self.to_learned_weight_residual_mix), exists(prev_weights))
@@ -833,7 +842,8 @@ class NeuralMemory(Module):
833
842
  seq,
834
843
  store_seq = None,
835
844
  state: NeuralMemState | None = None,
836
- prev_weights = None
845
+ prev_weights = None,
846
+ store_mask: Tensor | None = None
837
847
  ):
838
848
  is_multi_input = self.qkv_receives_diff_views
839
849
 
@@ -910,6 +920,11 @@ class NeuralMemory(Module):
910
920
 
911
921
  store_seqs = store_seq.split(split_sizes, dim = -2)
912
922
 
923
+ if exists(store_mask):
924
+ store_masks = store_mask.split(split_sizes, dim = -1)
925
+ else:
926
+ store_masks = (None,) * len(split_sizes)
927
+
913
928
  # whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch
914
929
 
915
930
  gate = None
@@ -917,7 +932,7 @@ class NeuralMemory(Module):
917
932
  if exists(self.transition_gate):
918
933
  gate = self.transition_gate.sigmoid()
919
934
 
920
- for ind, store_seq_chunk in enumerate(store_seqs):
935
+ for ind, (store_seq_chunk, maybe_store_mask) in enumerate(zip(store_seqs, store_masks)):
921
936
  is_last = ind == (len(store_seqs) - 1)
922
937
 
923
938
  # store
@@ -927,7 +942,8 @@ class NeuralMemory(Module):
927
942
  weights,
928
943
  seq_index = seq_index,
929
944
  past_state = past_state,
930
- prev_weights = prev_weights
945
+ prev_weights = prev_weights,
946
+ mask = maybe_store_mask
931
947
  )
932
948
 
933
949
  weights = next_neural_mem_state.weights
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.3.24
3
+ Version: 0.4.0
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,9 @@
1
+ titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
2
+ titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
3
+ titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
4
+ titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
5
+ titans_pytorch/neural_memory.py,sha256=uh5NbtAAzfPeZPFe7uhgnpUF6qyP0zjP0eXPIgY5pfc,31929
6
+ titans_pytorch-0.4.0.dist-info/METADATA,sha256=uOklaPv-y-eSpgnvrgVZ-ZL4TpeBg7r_EJxwJbdKyO0,6816
7
+ titans_pytorch-0.4.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.4.0.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.4.0.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
2
- titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
3
- titans_pytorch/mac_transformer.py,sha256=grD327B3OCIy7d23jNUWIoUo1bIgXUqD26dXWCjdi28,25565
4
- titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
5
- titans_pytorch/neural_memory.py,sha256=haSepQdfGsQdoo9Yk5agvRR91kTu8kgkXpBmBZaH8WI,31237
6
- titans_pytorch-0.3.24.dist-info/METADATA,sha256=0-WHTKNXZpESWfOMSOO8MiWddqDoSRP1lifsfgHmewo,6817
7
- titans_pytorch-0.3.24.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.3.24.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.3.24.dist-info/RECORD,,