titans-pytorch 0.1.32__py3-none-any.whl → 0.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/mac_transformer.py

@@ -2,6 +2,7 @@ from __future__ import annotations
 from typing import Callable

 from math import ceil
+from copy import deepcopy
 from functools import partial
 from collections import namedtuple

@@ -485,11 +486,13 @@ class MemoryAsContextTransformer(Module):
         heads = 8,
         ff_mult = 4,
         num_residual_streams = 4,
+        neural_memory_model: Module | None = None,
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
         aux_kv_recon_loss_weight = 0.,
         use_flex_attn = False,
-        sliding_window_attn = False
+        sliding_window_attn = False,
+        weight_tie_memory_model = False
     ):
         super().__init__()

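For orientation, here is a hedged usage sketch of the two new constructor arguments; the MemoryMLP memory model and all hyperparameter values below are illustrative assumptions, not values taken from this diff:

import torch
from titans_pytorch import MemoryAsContextTransformer, MemoryMLP  # MemoryMLP assumed exported

# one explicit memory model instance; with weight tying it is shared by every memory layer
memory_model = MemoryMLP(dim = 64, depth = 2)

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 128,
    neural_memory_model = memory_model,          # new in 0.1.34: pass the model explicitly
    weight_tie_memory_model = True,              # new in 0.1.34: share it across layers
    neural_memory_kwargs = dict(dim_head = 64, heads = 4)   # illustrative sizes
)

ids = torch.randint(0, 256, (1, 512))
loss = transformer(ids, return_loss = True)
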
@@ -523,6 +526,15 @@ class MemoryAsContextTransformer(Module):

         neural_memory_layers = default(neural_memory_layers, layers)

+        # weight tying neural memory model
+
+        maybe_copy = deepcopy if not weight_tie_memory_model else identity
+
+        if weight_tie_memory_model:
+            assert exists(neural_memory_model), '`neural_memory_model` must be explicitly set'
+
+        self.weight_tie_memory_model = weight_tie_memory_model
+
         # mem, attn, and feedforward layers

         for layer in layers:
@@ -551,6 +563,7 @@ class MemoryAsContextTransformer(Module):
                 mem = NeuralMemory(
                     dim = dim,
                     chunk_size = self.neural_memory_segment_len,
+                    model = maybe_copy(neural_memory_model),
                     **neural_memory_kwargs
                 )

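The maybe_copy helper decides whether layers share parameters: with weight tying off, each NeuralMemory gets an independent deepcopy of the provided model; with it on, every layer holds the very same module instance. A minimal sketch of that distinction (identity is assumed to be the usual pass-through helper defined elsewhere in the file):

from copy import deepcopy

from torch import nn

def identity(t):
    return t

shared = nn.Linear(4, 4)

# weight_tie_memory_model = False -> independent parameters per layer
copies = [deepcopy(shared) for _ in range(3)]
assert all(c is not shared for c in copies)

# weight_tie_memory_model = True -> one set of parameters shared by all layers
tied = [identity(shared) for _ in range(3)]
assert all(t is shared for t in tied)
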
@@ -683,7 +696,7 @@ class MemoryAsContextTransformer(Module):

         # math

-        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size
+        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size, weight_tie_memory_model = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size, self.weight_tie_memory_model

         seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)

@@ -736,6 +749,10 @@ class MemoryAsContextTransformer(Module):
         next_kv_caches = []
         next_neural_mem_caches = []

+        # weight tied neural memory
+
+        neural_memory_updates = None
+
         # value residual

         value_residual = None
@@ -769,7 +786,8 @@ class MemoryAsContextTransformer(Module):
                 if not is_inferencing:
                     (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
                         mem_input,
-                        return_aux_kv_loss = True
+                        return_aux_kv_loss = True,
+                        prev_layer_updates = neural_memory_updates
                     )

                     kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
@@ -777,9 +795,13 @@ class MemoryAsContextTransformer(Module):
                 else:
                     retrieved, next_neural_mem_cache = mem.forward_inference(
                         mem_input,
-                        state = next(neural_mem_caches, None)
+                        state = next(neural_mem_caches, None),
+                        prev_layer_updates = neural_memory_updates
                     )

+                if weight_tie_memory_model:
+                    neural_memory_updates = next_neural_mem_cache.updates
+
                 if self.gate_attn_output:
                     attn_out_gates = retrieved.sigmoid()
                 else:
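Taken together, the forward-pass changes thread each memory layer's weight updates into the next layer whenever the memory model is weight tied. A toy, self-contained illustration of that pattern (not the library's code; shapes and the update rule are made up):

import torch

base_weights = torch.randn(8, 8)     # stands in for the single, weight tied memory model
x = torch.randn(2, 16, 8)            # (batch, seq, dim), illustrative

prev_layer_updates = None

for _ in range(3):                    # three weight tied memory layers
    # each layer reads the shared weights plus whatever the previous layer learned
    weights = base_weights if prev_layer_updates is None else base_weights + prev_layer_updates

    retrieved = x @ weights                                    # stand-in for retrieval
    prev_layer_updates = -0.01 * (x.mT @ retrieved).mean(0)    # stand-in for the stored update

    x = x + retrieved
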
titans_pytorch/neural_memory.py

@@ -3,6 +3,7 @@ from typing import Callable

 import math
 from functools import partial
+from collections import namedtuple

 import torch
 from torch import nn, cat, Tensor
@@ -33,6 +34,8 @@ w - num memory network weight parameters

 LinearNoBias = partial(Linear, bias = False)

+NeuralMemCache = namedtuple('NeuralMemCache', ['seq', 'cache_store_segment', 'states', 'updates'])
+
 # functions

 def exists(v):
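NeuralMemCache replaces the bare 4-tuple previously used for the memory state, so callers (such as the weight tied path above, which reads .updates) can access fields by name while positional unpacking keeps working. A small sketch:

from collections import namedtuple

NeuralMemCache = namedtuple('NeuralMemCache', ['seq', 'cache_store_segment', 'states', 'updates'])

cache = NeuralMemCache(seq = 128, cache_store_segment = None, states = None, updates = dict())

seq, cache_store_segment, states, updates = cache   # tuple-style unpacking still works
assert cache.updates is updates                     # and named access is now available
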
@@ -605,7 +608,7 @@ class NeuralMemory(Module):
         # improvise (or perhaps correcting to) a solution

         if exists(prev_layer_updates):
-            prev_layer_updates = TensorDict(weights)
+            prev_layer_updates = TensorDict(prev_layer_updates)

             weights = weights + prev_layer_updates

@@ -657,6 +660,11 @@ class NeuralMemory(Module):

         adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)

+        # flatten batch and time if surprise depends on previous layer memory model
+
+        if exists(prev_layer_updates):
+            weights = weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

         grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
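The added rearrange folds the chunk axis of the (previous-layer-adjusted) weights into the batch axis, so the per-sample gradient function sees one weight set per chunk. For example, with an illustrative weight tensor:

import torch
from einops import rearrange

# hypothetical weight entry laid out as (batch, num chunks, *weight shape)
t = torch.randn(2, 4, 8, 8)

flat = rearrange(t, 'b n ... -> (b n) ...')
assert flat.shape == (8, 8, 8)   # batch and chunk axes merged
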
@@ -737,7 +745,8 @@ class NeuralMemory(Module):
         self,
         seq,
         past_weights: dict[str, Tensor],
-        chunk_size = None
+        chunk_size = None,
+        prev_layer_updates: dict[str, Tensor] | None = None
     ):
         chunk_size = default(chunk_size, self.retrieve_chunk_size)
         batch, seq_len = seq.shape[:2]
@@ -760,6 +769,9 @@ class NeuralMemory(Module):

         curr_weights = TensorDict(past_weights)

+        if exists(prev_layer_updates):
+            curr_weights = curr_weights + TensorDict(prev_layer_updates)
+
         # sequence Float['b n d'] to queries

         queries = self.to_queries(seq)
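Conceptually, adding TensorDict(prev_layer_updates) onto curr_weights is a key-wise tensor addition: queries are run against the stored weights plus the previous layer's updates. In plain-dict form (a sketch of the semantics, not the tensordict API):

import torch

curr_weights = {'weights.0': torch.zeros(8, 8), 'weights.1': torch.zeros(8, 8)}
prev_layer_updates = {'weights.0': torch.randn(8, 8), 'weights.1': torch.randn(8, 8)}

# key-wise addition, which is what the TensorDict sum above denotes
combined = {name: curr_weights[name] + prev_layer_updates[name] for name in curr_weights}
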
@@ -810,6 +822,7 @@ class NeuralMemory(Module):
         self,
         token: Tensor,
         state = None,
+        prev_layer_updates: dict[str, Tensor] | None = None
     ):

         # unpack previous state
@@ -838,7 +851,7 @@ class NeuralMemory(Module):
         if curr_seq_len < self.chunk_size:
             empty_mem = self.init_empty_memory_embed(batch, 1)

-            return empty_mem, (curr_seq_len, cache_store_seq, past_states, updates)
+            return empty_mem, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)

         # store if storage sequence cache hits the chunk size

@@ -848,13 +861,20 @@ class NeuralMemory(Module):
         if not exists(updates):
             updates = weights.clone().zero_()
             updates = updates.apply(lambda t: repeat(t, '... -> b 1 ...', b = batch))
+        else:
+            updates = updates.apply(lambda t: t[:, -1:])
+
+        if exists(prev_layer_updates):
+            prev_layer_updates = TensorDict(prev_layer_updates)
+            prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])

         if store_seq_cache_len == self.chunk_size:

             next_updates, next_states, _ = self.store_memories(
                 cache_store_seq,
                 weights,
-                past_state = past_states
+                past_state = past_states,
+                prev_layer_updates = prev_layer_updates,
             )

             updates = next_updates
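The t[:, -1:] slices keep only the most recent entry along the chunk axis, so incremental inference carries a single latest update forward instead of the whole history. For example:

import torch

update_history = torch.randn(2, 5, 8, 8)    # (batch, chunks seen so far, *weight shape), illustrative
latest = update_history[:, -1:]              # (2, 1, 8, 8): only the newest chunk's update survives
assert latest.shape == (2, 1, 8, 8)
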
@@ -866,7 +886,7 @@ class NeuralMemory(Module):

         # next state tuple

-        next_state = (curr_seq_len, cache_store_seq, next_states, updates)
+        next_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)

         return retrieved, next_state

@@ -880,7 +900,8 @@ class NeuralMemory(Module):
         chunk_size = None,
         store_chunk_size = None,
         return_values = False,
-        return_next_state = False
+        return_next_state = False,
+        prev_layer_updates: dict[str, Tensor] | None = None
     ):
         batch, seq_len = seq.shape[:2]

@@ -899,15 +920,30 @@ class NeuralMemory(Module):
         if not exists(mem_model_weights):
             mem_model_weights = self.init_weights()

+        # store
+
         store_seq = default(store_seq, seq)

         store_seq_len = store_seq.shape[-2]
         store_chunk_size = default(store_chunk_size, chunk_size, self.store_chunk_size)
         remainder = store_seq_len % store_chunk_size

-        (updates, next_state, values), aux_kv_recon_loss = self.store_memories(store_seq, mem_model_weights, chunk_size = store_chunk_size, return_aux_kv_loss = True)
+        (updates, next_state, values), aux_kv_recon_loss = self.store_memories(
+            store_seq,
+            mem_model_weights,
+            chunk_size = store_chunk_size,
+            prev_layer_updates = prev_layer_updates,
+            return_aux_kv_loss = True
+        )
+
+        # retrieve

-        retrieved = self.retrieve_memories(seq, mem_model_weights + updates, chunk_size = chunk_size)
+        retrieved = self.retrieve_memories(
+            seq,
+            mem_model_weights + updates,
+            chunk_size = chunk_size,
+            prev_layer_updates = prev_layer_updates
+        )

         # determine state for the storing of memories
         # for transformer-xl like training with neural memory as well as inferencing with initial prompt
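On the NeuralMemory side, forward now accepts prev_layer_updates and threads it into both the store and retrieve phases, returning a NeuralMemCache as its state. A hedged call-pattern sketch (dim, chunk size and shapes are illustrative, not from this diff):

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)   # illustrative sizes

seq = torch.randn(1, 256, 384)

(retrieved, state), aux_loss = mem(seq, return_aux_kv_loss = True)

# a later, weight tied layer can fold the previous layer's updates into its own pass,
# mirroring how the transformer threads neural_memory_updates between layers above
(retrieved_next, state_next), _ = mem(
    seq,
    return_aux_kv_loss = True,
    prev_layer_updates = state.updates
)
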
@@ -917,9 +953,7 @@ class NeuralMemory(Module):
         if remainder > 0:
             cache_store_seq = store_seq[:, -remainder:]

-        updates = updates.apply(lambda t: t[:, -1:])
-
-        next_store_state = (seq_len, cache_store_seq, next_state, updates)
+        next_store_state = NeuralMemCache(seq_len, cache_store_seq, next_state, updates)

         output = (retrieved, next_store_state)

titans_pytorch-0.1.34.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.32
+Version: 0.1.34
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
titans_pytorch-0.1.34.dist-info/RECORD

@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=rMT99CPQFH4Gudp0FmVPWGKfhBf6xksGEaFEcOVdqjs,230
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=JvA4mhQaW9LD4j6boRUfLfjyzDCtjqybIr4Ajeio8n8,25708
+titans_pytorch/neural_memory.py,sha256=nNAxhkubuHCGs3bty_eA_yBhWqepPZJgKKvkWXO6IK4,28653
+titans_pytorch-0.1.34.dist-info/METADATA,sha256=pVgjCX_YTT9_5WPcFfXpoaBvzrg1-esvwS0kPpeJAYU,6826
+titans_pytorch-0.1.34.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.34.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.34.dist-info/RECORD,,
titans_pytorch-0.1.32.dist-info/RECORD

@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=rMT99CPQFH4Gudp0FmVPWGKfhBf6xksGEaFEcOVdqjs,230
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=Cui-hCl6X4UVGmuyoKCSKWbag9Yrc-a2MrfVkHM-z0A,24828
-titans_pytorch/neural_memory.py,sha256=Vfo1z1VztPDDXgFxjkiyOP29daDE7KTdnZeWXifvCJI,27456
-titans_pytorch-0.1.32.dist-info/METADATA,sha256=_HPPht8nhLwH9GzLyZI-fh8JBSEoSxkENCSU2xuU_6A,6826
-titans_pytorch-0.1.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.32.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.32.dist-info/RECORD,,