titans-pytorch 0.1.34__py3-none-any.whl → 0.1.35__py3-none-any.whl

titans_pytorch/mac_transformer.py

@@ -479,7 +479,8 @@ class MemoryAsContextTransformer(Module):
         depth,
         segment_len,
         neural_memory_segment_len = None,
-        neural_mem_gate_attn_output = True,
+        neural_mem_gate_attn_output = False,
+        neural_memory_add_value_residual = False,
         num_longterm_mem_tokens = 0,
         num_persist_mem_tokens = 0,
         dim_head = 64,
@@ -535,6 +536,11 @@ class MemoryAsContextTransformer(Module):
 
         self.weight_tie_memory_model = weight_tie_memory_model
 
+        # value residual learning for neural memory
+
+        is_first_mem = True
+        self.mem_add_value_residual = neural_memory_add_value_residual
+
         # mem, attn, and feedforward layers
 
         for layer in layers:
@@ -564,9 +570,11 @@ class MemoryAsContextTransformer(Module):
                     dim = dim,
                     chunk_size = self.neural_memory_segment_len,
                     model = maybe_copy(neural_memory_model),
+                    accept_value_residual = not is_first_mem and neural_memory_add_value_residual,
                     **neural_memory_kwargs
                 )
 
+                is_first_mem = False
 
             ff = FeedForward(dim = dim, mult = ff_mult)
 
@@ -757,6 +765,8 @@ class MemoryAsContextTransformer(Module):
 
         value_residual = None
 
+        mem_value_residual = None
+
         # aux losses
 
         kv_recon_losses = self.zero
@@ -784,21 +794,28 @@ class MemoryAsContextTransformer(Module):
                 mem_input, add_residual = mem_hyper_conn(x)
 
                 if not is_inferencing:
-                    (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
+                    (retrieved, next_neural_mem_cache, next_mem_value_residual), mem_kv_aux_loss = mem(
                         mem_input,
                         return_aux_kv_loss = True,
+                        return_values = True,
+                        value_residual = mem_value_residual,
                         prev_layer_updates = neural_memory_updates
                     )
 
                     kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
 
                 else:
-                    retrieved, next_neural_mem_cache = mem.forward_inference(
+                    (retrieved, next_neural_mem_cache, next_mem_value_residual) = mem.forward_inference(
                         mem_input,
                         state = next(neural_mem_caches, None),
+                        return_values = True,
+                        value_residual = mem_value_residual,
                         prev_layer_updates = neural_memory_updates
                     )
 
+                if self.mem_add_value_residual:
+                    mem_value_residual = next_mem_value_residual
+
                 if weight_tie_memory_model:
                     neural_memory_updates = next_neural_mem_cache.updates
 
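Taken together, the mac_transformer.py hunks add an opt-in value residual for the neural memory layers: values produced by earlier memory layers can be fed to later ones, threaded through the forward loop via mem_value_residual. A minimal usage sketch follows; it is not part of the diff, and num_tokens, dim and return_loss are assumed from the library's usual MemoryAsContextTransformer interface rather than shown here.

import torch
from titans_pytorch import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,                         # assumed from the library's README, not in this diff
    dim = 384,                                # assumed
    depth = 4,
    segment_len = 32,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 4,
    neural_memory_add_value_residual = True,  # new flag in 0.1.35, defaults to False
)

ids = torch.randint(0, 256, (1, 1024))
loss = transformer(ids, return_loss = True)   # return_loss assumed from the README usage
loss.backward()

Note that the same release also flips the default of neural_mem_gate_attn_output from True to False, so callers relying on the old gating behavior now need to pass it explicitly.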
titans_pytorch/neural_memory.py

@@ -822,7 +822,9 @@ class NeuralMemory(Module):
         self,
         token: Tensor,
         state = None,
-        prev_layer_updates: dict[str, Tensor] | None = None
+        prev_layer_updates: dict[str, Tensor] | None = None,
+        return_values = False,
+        value_residual = None,
     ):
 
         # unpack previous state
@@ -870,11 +872,12 @@ class NeuralMemory(Module):
 
         if store_seq_cache_len == self.chunk_size:
 
-            next_updates, next_states, _ = self.store_memories(
+            next_updates, next_states, values = self.store_memories(
                 cache_store_seq,
                 weights,
                 past_state = past_states,
                 prev_layer_updates = prev_layer_updates,
+                value_residual = value_residual
             )
 
             updates = next_updates
@@ -888,7 +891,12 @@ class NeuralMemory(Module):
 
         next_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)
 
-        return retrieved, next_state
+        output = (retrieved, next_state)
+
+        if return_values:
+            output = (*output, values)
+
+        return output
 
     def forward(
         self,
@@ -900,6 +908,7 @@ class NeuralMemory(Module):
         chunk_size = None,
         store_chunk_size = None,
         return_values = False,
+        value_residual = None,
         return_next_state = False,
         prev_layer_updates: dict[str, Tensor] | None = None
     ):
@@ -933,6 +942,7 @@ class NeuralMemory(Module):
             mem_model_weights,
             chunk_size = store_chunk_size,
             prev_layer_updates = prev_layer_updates,
+            value_residual = value_residual,
             return_aux_kv_loss = True
         )
 
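On the neural_memory.py side, forward gains a value_residual keyword, while forward_inference gains both return_values and value_residual and keeps its (retrieved, next_state) return unless return_values = True is passed, in which case the stored values tensor is appended; that tensor is what the transformer threads back in as value_residual for the next memory layer. The sketch below mirrors the call sites visible in the hunks above; the dim and chunk_size values are arbitrary, and accept_value_residual is the constructor flag wired in by mac_transformer.py.

import torch
from titans_pytorch import NeuralMemory

first_mem = NeuralMemory(dim = 384, chunk_size = 64)
later_mem = NeuralMemory(dim = 384, chunk_size = 64, accept_value_residual = True)

seq = torch.randn(2, 128, 384)

# first memory layer: the value residual starts out as None, as in the transformer's forward
(retrieved, next_state, values), aux_kv_loss = first_mem(
    seq,
    return_aux_kv_loss = True,
    return_values = True,
    value_residual = None
)

# a later memory layer receives the earlier layer's values as its value residual
(retrieved, next_state, _), aux_kv_loss = later_mem(
    seq,
    return_aux_kv_loss = True,
    return_values = True,
    value_residual = values
)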
titans_pytorch-0.1.35.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.34
+Version: 0.1.35
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.1.35.dist-info/RECORD (new file)

@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=rMT99CPQFH4Gudp0FmVPWGKfhBf6xksGEaFEcOVdqjs,230
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=KBwo-Fr_fDzVaAa7xg1ggEpNlE4vRUoGMEjB-I2ZWTU,26463
+titans_pytorch/neural_memory.py,sha256=wFOLFe3ViXiQfNvUiAGJ6BfiaDhr0BYDRDnLNMHWQhU,28938
+titans_pytorch-0.1.35.dist-info/METADATA,sha256=5e5qPt4hAOhxhDWqdjutjJuUmht44zYq_KqgagKjqxE,6826
+titans_pytorch-0.1.35.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.35.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.35.dist-info/RECORD,,
titans_pytorch-0.1.34.dist-info/RECORD (removed file)

@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=rMT99CPQFH4Gudp0FmVPWGKfhBf6xksGEaFEcOVdqjs,230
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=JvA4mhQaW9LD4j6boRUfLfjyzDCtjqybIr4Ajeio8n8,25708
-titans_pytorch/neural_memory.py,sha256=nNAxhkubuHCGs3bty_eA_yBhWqepPZJgKKvkWXO6IK4,28653
-titans_pytorch-0.1.34.dist-info/METADATA,sha256=pVgjCX_YTT9_5WPcFfXpoaBvzrg1-esvwS0kPpeJAYU,6826
-titans_pytorch-0.1.34.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.34.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.34.dist-info/RECORD,,