titans-pytorch: titans_pytorch-0.2.14-py3-none-any.whl → titans_pytorch-0.2.16-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/titans_pytorch/mac_transformer.py
+++ b/titans_pytorch/mac_transformer.py
@@ -491,6 +491,7 @@ class MemoryAsContextTransformer(Module):
         neural_memory_layers: tuple[int, ...] | None = None,
         use_flex_attn = False,
         sliding_window_attn = False,
+        neural_mem_weight_residual = False
     ):
         super().__init__()
 
@@ -524,6 +525,8 @@ class MemoryAsContextTransformer(Module):
 
         neural_memory_layers = default(neural_memory_layers, layers)
 
+        self.neural_mem_weight_residual = neural_mem_weight_residual
+
         # mem, attn, and feedforward layers
 
         for layer in layers:
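
In short, 0.2.16 threads a weight residual between stacked neural memory layers: with neural_mem_weight_residual = True, each memory layer starts its surprise computation from the weight updates produced by the layer before it. A minimal usage sketch, following the constructor arguments visible in this diff and the project README (the hyperparameter values are illustrative, not taken from this release):

    import torch
    from titans_pytorch import MemoryAsContextTransformer

    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 64,
        depth = 2,
        segment_len = 32,
        num_persist_mem_tokens = 4,
        num_longterm_mem_tokens = 4,
        neural_memory_layers = (1, 2),
        neural_mem_weight_residual = True   # new in 0.2.16
    )

    ids = torch.randint(0, 256, (1, 512))

    loss = transformer(ids, return_loss = True)
    loss.backward()
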
@@ -739,6 +742,10 @@ class MemoryAsContextTransformer(Module):
 
         value_residual = None
 
+        # neural mem weight residual
+
+        mem_weight_residual = None
+
         # when inferencing, only do one token at a time
 
         if is_inferencing:
@@ -761,16 +768,14 @@ class MemoryAsContextTransformer(Module):
 
             mem_input, add_residual = mem_hyper_conn(x)
 
-            if not is_inferencing:
-                retrieved, next_neural_mem_cache = mem(
-                    mem_input
-                )
+            retrieved, next_neural_mem_cache = mem.forward(
+                mem_input,
+                state = next(neural_mem_caches, None),
+                prev_weights = mem_weight_residual
+            )
 
-            else:
-                (retrieved, next_neural_mem_cache) = mem.forward_inference(
-                    mem_input,
-                    state = next(neural_mem_caches, None),
-                )
+            if self.neural_mem_weight_residual:
+                mem_weight_residual = next_neural_mem_cache.updates
 
             if self.gate_attn_output:
                 attn_out_gates = retrieved.sigmoid()
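
The separate training and inference branches above collapse into a single call: mem.forward now accepts the cached state (previously only forward_inference did) together with the new prev_weights, and when the residual is enabled each layer hands the updates carried in its returned NeuralMemCache down to the next memory layer. Distilled to two stacked NeuralMemory modules (dim and chunk_size values are illustrative, per the README usage):

    import torch
    from titans_pytorch import NeuralMemory

    mem1 = NeuralMemory(dim = 64, chunk_size = 16)
    mem2 = NeuralMemory(dim = 64, chunk_size = 16)

    seq = torch.randn(2, 128, 64)

    # first memory layer stores and retrieves as before
    retrieved1, cache1 = mem1.forward(seq, prev_weights = None)

    # second layer starts its surprise computation from the first
    # layer's per-chunk weight updates, carried in cache1.updates
    retrieved2, cache2 = mem2.forward(seq, prev_weights = cache1.updates)
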
--- a/titans_pytorch/neural_memory.py
+++ b/titans_pytorch/neural_memory.py
@@ -494,7 +494,8 @@ class NeuralMemory(Module):
         seq,
         weights: dict[str, Tensor] | None = None,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-        seq_index = 0
+        seq_index = 0,
+        prev_weights = None
     ):
         batch, seq_len, heads, chunk_size = *seq.shape[:2], self.heads, self.store_chunk_size
 
@@ -560,6 +561,12 @@ class NeuralMemory(Module):
 
         adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
 
+        # maybe add previous layer weight
+
+        if exists(prev_weights):
+            prev_weights = prev_weights.apply(lambda t: t[:, -1:])
+            weights_for_surprise = weights_for_surprise + prev_weights
+
         # flatten batch and time if surprise depends on previous layer memory model
 
         weights_for_surprise = rearrange_dict_values(weights_for_surprise, 'b n ... -> (b n) ...')
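
Inside store_memories, only the previous layer's most recent chunk of updates is kept (t[:, -1:], which then broadcasts over this layer's chunks) before being added onto weights_for_surprise. The same residual written with ordinary dicts of tensors standing in for the library's tensor-dict container:

    import torch

    def add_prev_weight_residual(weights_for_surprise, prev_weights):
        # keep only the previous layer's last chunk of updates: (b, n, ...) -> (b, 1, ...)
        prev_last = {k: t[:, -1:] for k, t in prev_weights.items()}
        # broadcast-add onto this layer's per-chunk weights
        return {k: v + prev_last[k] for k, v in weights_for_surprise.items()}
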
@@ -655,15 +662,19 @@ class NeuralMemory(Module):
         self,
         seq,
         past_weights: dict[str, Tensor],
+        chunk_size = None,
+        need_pad = True
     ):
-        chunk_size = self.retrieve_chunk_size
+        chunk_size = default(chunk_size, self.retrieve_chunk_size)
         batch, seq_len = seq.shape[:2]
 
         seq = self.retrieve_norm(seq)
 
-        needs_pad = chunk_size > 1
+        need_pad = need_pad or chunk_size > 1
+
+        if need_pad:
+            seq = pad_at_dim(seq, (1, 0), dim = 1)
 
-        seq = pad_at_dim(seq, (1, 0), dim = 1)
         seq_len_plus_one = seq.shape[-2]
 
         next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)
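
retrieve_memories can now be told both the chunk size and whether to left-pad. For chunked retrieval the sequence is shifted right by one position (so each token retrieves only from memories committed before it) and rounded up to a chunk multiple; the single-token decode path skips both. The arithmetic in isolation, using plain pytorch in place of the library's pad_at_dim and round_up_multiple helpers:

    import torch
    import torch.nn.functional as F

    def pad_for_retrieval(seq, chunk_size, need_pad):
        if need_pad:
            seq = F.pad(seq, (0, 0, 1, 0))  # prepend one empty position along the sequence

        seq_len = seq.shape[-2]
        next_len = ((seq_len + chunk_size - 1) // chunk_size) * chunk_size
        return F.pad(seq, (0, 0, 0, next_len - seq_len))  # round up to a chunk multiple

    x = torch.randn(1, 10, 64)
    print(pad_for_retrieval(x, chunk_size = 4, need_pad = True).shape)   # (1, 12, 64)
    print(pad_for_retrieval(x, chunk_size = 1, need_pad = False).shape)  # (1, 10, 64)
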
@@ -718,94 +729,37 @@ class NeuralMemory(Module):
 
         # restore, pad with empty memory embed
 
-        values = values[:, 1:(seq_len + 1)]
-
-        return values
-
-    @torch.no_grad()
-    def forward_inference(
-        self,
-        token: Tensor,
-        state: NeuralMemCache | None = None,
-    ):
-        # unpack previous state
-
-        if not exists(state):
-            state = (0, None, None, None, None)
-
-        seq_index, weights, cache_store_seq, past_states, updates = state
-
-        curr_seq_len = seq_index + 1
-        batch = token.shape[0]
-
-        if token.ndim == 2:
-            token = rearrange(token, 'b d -> b 1 d')
-
-        assert token.shape[1] == 1
-
-        # increment the sequence cache which is at most the chunk size
-
-        cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)
-
-        # early return empty memory, when no memories are stored for steps < first chunk size
-
-        if curr_seq_len < self.chunk_size:
-            retrieve = self.retrieve_memories(token, weights, chunk_size = 1)
-
-            output = retrieve, NeuralMemCache(curr_seq_len, weights, cache_store_seq, past_states, updates)
-
-            return output
-
-        # store if storage sequence cache hits the chunk size
+        if need_pad:
+            values = values[:, 1:]
 
-        next_states = past_states
-        store_seq_cache_len = cache_store_seq.shape[-2]
-
-        if not exists(updates):
-            updates = weights.clone().zero_()
-            updates = repeat_dict_values(updates, '... -> b 1 ...', b = batch)
-        else:
-            updates = updates.apply(lambda t: t[:, -1:])
-
-        if store_seq_cache_len == self.chunk_size:
-
-            next_updates, store_state = self.store_memories(
-                cache_store_seq,
-                weights,
-                past_state = past_states,
-            )
-
-            updates = next_updates
-            cache_store_seq = None
-            next_states = store_state.states
-
-        # retrieve
-
-        retrieved = self.retrieve_memories(token, updates, chunk_size = 1)
-
-        # next state tuple
-
-        next_store_state = NeuralMemCache(curr_seq_len, weights, cache_store_seq, next_states, updates)
-
-        return retrieved, next_store_state
+        return values[:, :seq_len]
 
     def forward(
         self,
         seq,
         store_seq = None,
         state: NeuralMemCache | None = None,
+        prev_weights = None
     ):
+        if seq.ndim == 2:
+            seq = rearrange(seq, 'b d -> b 1 d')
+
+        is_single_token = seq.shape[1] == 1
+
         if not exists(state):
            state = (0, None, None, None, None)
 
        seq_index, weights, cache_store_seq, past_state, updates = state
 
-        assert not exists(cache_store_seq) or is_empty_tensor(cache_store_seq)
-
         # store
 
         store_seq = default(store_seq, seq)
 
+        # take care of cache
+
+        if exists(cache_store_seq):
+            store_seq = safe_cat((cache_store_seq, store_seq))
+
         # functions
 
         # compute split sizes of sequence
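
With forward_inference removed, incremental decoding now reuses forward: a single incoming token is detected up front, tokens accumulate in cache_store_seq until a full chunk can be stored, and the old assertion that the cache be empty is gone. A decode-loop sketch (sizes are illustrative; the cache returned by each call is threaded into the next):

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 64, chunk_size = 16)

    state = None

    for _ in range(64):
        token = torch.randn(1, 1, 64)             # one token -> the is_single_token path
        retrieved, state = mem(token, state = state)
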
@@ -861,6 +815,7 @@ class NeuralMemory(Module):
             weights,
             seq_index = seq_index,
             past_state = past_state,
+            prev_weights = prev_weights
         )
 
         seq_index = next_neural_mem_state.seq_index
@@ -883,9 +838,21 @@ class NeuralMemory(Module):
 
         # retrieve
 
+        need_pad = True
+        retrieve_chunk_size = None
+
+        if is_single_token:
+            retrieve_chunk_size = 1
+            need_pad = False
+
+            last_update, _ = past_state
+            updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
+
         retrieved = self.retrieve_memories(
             seq,
-            updates
+            updates,
+            chunk_size = retrieve_chunk_size,
+            need_pad = need_pad,
         )
 
         return retrieved, next_neural_mem_state
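
For a single decoded token, retrieval bypasses the per-chunk updates entirely: the last committed weights in past_state are given a chunk axis of length one and queried with chunk_size = 1 and no left pad. The reshape, spelled out with an ordinary dict standing in for the library's tensor-dict container:

    import torch

    last_update = {'w': torch.randn(2, 64, 64)}  # hypothetical memory weights, (batch, ...)

    # equivalent of rearrange_dict_values(last_update, 'b ... -> b 1 ...')
    updates = {k: t.unsqueeze(1) for k, t in last_update.items()}  # (batch, 1, ...)
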
--- a/titans_pytorch-0.2.14.dist-info/METADATA
+++ b/titans_pytorch-0.2.16.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.14
+Version: 0.2.16
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
--- /dev/null
+++ b/titans_pytorch-0.2.16.dist-info/RECORD
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=QhdSPgWntfWILMJ1t0xLKgvZfPZWu9vhzZWaesftaPg,24724
+titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
+titans_pytorch/neural_memory.py,sha256=pOoDhYrUgQ5EaUyiEvwDu7mfcaH6Jqqod5NwIFLbD9U,25798
+titans_pytorch-0.2.16.dist-info/METADATA,sha256=TxyjTuJmP0o2NhrHmlzJCU3JivOA1rTY-xQp3Ir_igY,6812
+titans_pytorch-0.2.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.2.16.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.2.16.dist-info/RECORD,,
--- a/titans_pytorch-0.2.14.dist-info/RECORD
+++ /dev/null
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=RfJ1SvQH5_4PmlB7g-13wPAqYtCCUJxfmtaL0oBrRCU,24563
-titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
-titans_pytorch/neural_memory.py,sha256=BGAzdxxZoKo8SmuPpSx1h1Yexf0oRje5qsCjl5r4FSA,26833
-titans_pytorch-0.2.14.dist-info/METADATA,sha256=SE27_ln7ludwaO2J1uX7TJ99MbnCl4GLorYpbBIK920,6812
-titans_pytorch-0.2.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.2.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.2.14.dist-info/RECORD,,