titans-pytorch 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -37,7 +37,7 @@ class MemoryMLP(Module):
37
37
  self,
38
38
  dim,
39
39
  depth,
40
- expansion_factor = 4.
40
+ expansion_factor = 2.
41
41
  ):
42
42
  super().__init__()
43
43
  dim_hidden = int(dim * expansion_factor)
@@ -690,16 +690,27 @@ class NeuralMemory(Module):
690
690
  def retrieve_memories(
691
691
  self,
692
692
  seq,
693
- past_weights: dict[str, Tensor],
694
- chunk_size = None,
695
- need_pad = True
693
+ weights: dict[str, Tensor],
696
694
  ):
697
- chunk_size = default(chunk_size, self.retrieve_chunk_size)
695
+ chunk_size = self.retrieve_chunk_size
696
+
697
+ weights_have_expanded_shape = dict_get_shape(weights) != self.init_weight_shape
698
+
698
699
  batch, seq_len = seq.shape[:2]
699
700
 
700
- seq = self.retrieve_norm(seq)
701
+ # auto infer single token decoding, if there are only 1 set of weights and 1 token
702
+
703
+ is_one_token = seq_len == 1
704
+ is_one_weight = (not weights_have_expanded_shape) or next(iter(weights.values())).shape[1] == 1
705
+
706
+ is_single_token_decode = is_one_token and is_one_weight
707
+
708
+ if is_single_token_decode:
709
+ chunk_size = 1
710
+
711
+ # padding related, for chunked processing
701
712
 
702
- need_pad = need_pad or chunk_size > 1
713
+ need_pad = chunk_size > 1 or not is_one_weight
703
714
 
704
715
  if need_pad:
705
716
  seq = pad_at_dim(seq, (1, 0), dim = 1)
@@ -714,7 +725,11 @@ class NeuralMemory(Module):
714
725
  # the parameters of the memory model stores the memories of the key / values
715
726
  # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
716
727
 
717
- curr_weights = TensorDict(past_weights)
728
+ weights = TensorDict(weights)
729
+
730
+ # pre norm
731
+
732
+ seq = self.retrieve_norm(seq)
718
733
 
719
734
  # sequence Float['b n d'] to queries
720
735
 
@@ -730,14 +745,14 @@ class NeuralMemory(Module):
730
745
 
731
746
  # fetch values from memory model
732
747
 
733
- if dict_get_shape(curr_weights) != self.init_weight_shape:
734
- curr_weights = rearrange_dict_values(curr_weights, 'b n ... -> (b n) ...')
748
+ if weights_have_expanded_shape:
749
+ weights = rearrange_dict_values(weights, 'b n ... -> (b n) ...')
735
750
 
736
751
  queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)
737
752
 
738
753
  # forward functional call
739
754
 
740
- values = functional_call(self.memory_model, dict(curr_weights), queries)
755
+ values = functional_call(self.memory_model, dict(weights), queries)
741
756
 
742
757
  # reconstitute batch dimension
743
758
 
@@ -885,22 +900,13 @@ class NeuralMemory(Module):
885
900
 
886
901
  # retrieve
887
902
 
888
- need_pad = True
889
- retrieve_chunk_size = None
890
-
891
903
  if is_single_token:
892
- retrieve_chunk_size = 1
893
- need_pad = False
894
-
895
904
  last_update, _ = next_neural_mem_state.states
896
-
897
905
  updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
898
906
 
899
907
  retrieved = self.retrieve_memories(
900
908
  seq,
901
- updates,
902
- chunk_size = retrieve_chunk_size,
903
- need_pad = need_pad,
909
+ updates
904
910
  )
905
911
 
906
912
  return retrieved, next_neural_mem_state
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.3.2
3
+ Version: 0.3.3
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,9 @@
1
+ titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
4
+ titans_pytorch/memory_models.py,sha256=0KLHZN-y_7lwrhWSnFRaYJ3GiUV3tzVjxS9CxIx_eI8,4843
5
+ titans_pytorch/neural_memory.py,sha256=Ff-IBv-CCQAP7IYIpokPDoGtsvpzotAJsHB1d_-xd98,27934
6
+ titans_pytorch-0.3.3.dist-info/METADATA,sha256=CutjohW8xSNycd5W-uyXC4827ubmIpAJCs9xoMbfZzo,6815
7
+ titans_pytorch-0.3.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.3.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.3.3.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
4
- titans_pytorch/memory_models.py,sha256=TJl7b9Rd5BP8aQXK8itap5YN3DyomUVxCRJDgPuRGBk,4843
5
- titans_pytorch/neural_memory.py,sha256=QiEnHnZfQ8ptuXNVy4NZf9-XMbMOl2_1PT_YIG1GQBc,27739
6
- titans_pytorch-0.3.2.dist-info/METADATA,sha256=Ar1OdcY09w-q3RlVKlxcgrtcVzZE6cRKqnjwQ4F-9Z8,6815
7
- titans_pytorch-0.3.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.3.2.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.3.2.dist-info/RECORD,,