titans-pytorch 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -36,10 +36,14 @@ class MemoryMLP(Module):
36
36
  def __init__(
37
37
  self,
38
38
  dim,
39
- depth
39
+ depth,
40
+ expansion_factor = 2.
40
41
  ):
41
42
  super().__init__()
42
- self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])
43
+ dim_hidden = int(dim * expansion_factor)
44
+ dims = (dim, *((dim_hidden,) * (depth - 1)), dim)
45
+
46
+ self.weights = ParameterList([Parameter(torch.randn(dim_in, dim_out)) for dim_in, dim_out in zip(dims[:-1], dims[1:])])
43
47
 
44
48
  self.ln = LayerNorm(dim)
45
49
 
@@ -299,7 +299,8 @@ class NeuralMemory(Module):
299
299
  accept_weight_residual = False,
300
300
  gated_transition = False,
301
301
  default_model_kwargs: dict = dict(
302
- depth = 2
302
+ depth = 2,
303
+ expansion_factor = 4.
303
304
  )
304
305
  ):
305
306
  super().__init__()
@@ -689,16 +690,27 @@ class NeuralMemory(Module):
689
690
  def retrieve_memories(
690
691
  self,
691
692
  seq,
692
- past_weights: dict[str, Tensor],
693
- chunk_size = None,
694
- need_pad = True
693
+ weights: dict[str, Tensor],
695
694
  ):
696
- chunk_size = default(chunk_size, self.retrieve_chunk_size)
695
+ chunk_size = self.retrieve_chunk_size
696
+
697
+ weights_have_expanded_shape = dict_get_shape(weights) != self.init_weight_shape
698
+
697
699
  batch, seq_len = seq.shape[:2]
698
700
 
699
- seq = self.retrieve_norm(seq)
701
+ # auto infer single token decoding, if there is only 1 set of weights and 1 token
702
+
703
+ is_one_token = seq_len == 1
704
+ is_one_weight = (not weights_have_expanded_shape) or next(iter(weights.values())).shape[1] == 1
705
+
706
+ is_single_token_decode = is_one_token and is_one_weight
707
+
708
+ if is_single_token_decode:
709
+ chunk_size = 1
710
+
711
+ # padding related, for chunked processing
700
712
 
701
- need_pad = need_pad or chunk_size > 1
713
+ need_pad = chunk_size > 1 or not is_one_weight
702
714
 
703
715
  if need_pad:
704
716
  seq = pad_at_dim(seq, (1, 0), dim = 1)
@@ -713,7 +725,11 @@ class NeuralMemory(Module):
713
725
  # the parameters of the memory model stores the memories of the key / values
714
726
  # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
715
727
 
716
- curr_weights = TensorDict(past_weights)
728
+ weights = TensorDict(weights)
729
+
730
+ # pre norm
731
+
732
+ seq = self.retrieve_norm(seq)
717
733
 
718
734
  # sequence Float['b n d'] to queries
719
735
 
@@ -729,14 +745,14 @@ class NeuralMemory(Module):
729
745
 
730
746
  # fetch values from memory model
731
747
 
732
- if dict_get_shape(curr_weights) != self.init_weight_shape:
733
- curr_weights = rearrange_dict_values(curr_weights, 'b n ... -> (b n) ...')
748
+ if weights_have_expanded_shape:
749
+ weights = rearrange_dict_values(weights, 'b n ... -> (b n) ...')
734
750
 
735
751
  queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)
736
752
 
737
753
  # forward functional call
738
754
 
739
- values = functional_call(self.memory_model, dict(curr_weights), queries)
755
+ values = functional_call(self.memory_model, dict(weights), queries)
740
756
 
741
757
  # reconstitute batch dimension
742
758
 
@@ -884,22 +900,13 @@ class NeuralMemory(Module):
884
900
 
885
901
  # retrieve
886
902
 
887
- need_pad = True
888
- retrieve_chunk_size = None
889
-
890
903
  if is_single_token:
891
- retrieve_chunk_size = 1
892
- need_pad = False
893
-
894
904
  last_update, _ = next_neural_mem_state.states
895
-
896
905
  updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
897
906
 
898
907
  retrieved = self.retrieve_memories(
899
908
  seq,
900
- updates,
901
- chunk_size = retrieve_chunk_size,
902
- need_pad = need_pad,
909
+ updates
903
910
  )
904
911
 
905
912
  return retrieved, next_neural_mem_state
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.3.1
3
+ Version: 0.3.3
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,9 @@
1
+ titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
4
+ titans_pytorch/memory_models.py,sha256=0KLHZN-y_7lwrhWSnFRaYJ3GiUV3tzVjxS9CxIx_eI8,4843
5
+ titans_pytorch/neural_memory.py,sha256=Ff-IBv-CCQAP7IYIpokPDoGtsvpzotAJsHB1d_-xd98,27934
6
+ titans_pytorch-0.3.3.dist-info/METADATA,sha256=CutjohW8xSNycd5W-uyXC4827ubmIpAJCs9xoMbfZzo,6815
7
+ titans_pytorch-0.3.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.3.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.3.3.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
4
- titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
5
- titans_pytorch/neural_memory.py,sha256=K1z7wtv366Y6-eEyXMFZ_j7D2frPl5RxfSgxzFYoFMc,27704
6
- titans_pytorch-0.3.1.dist-info/METADATA,sha256=ZAxucKq2DZBtW-BI_O2sUQ5RXy11a7eu48yPpwnanpw,6815
7
- titans_pytorch-0.3.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.3.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.3.1.dist-info/RECORD,,