titans-pytorch: titans_pytorch-0.3.2-py3-none-any.whl → titans_pytorch-0.3.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/memory_models.py +1 -1
- titans_pytorch/neural_memory.py +26 -20
- {titans_pytorch-0.3.2.dist-info → titans_pytorch-0.3.3.dist-info}/METADATA +1 -1
- titans_pytorch-0.3.3.dist-info/RECORD +9 -0
- titans_pytorch-0.3.2.dist-info/RECORD +0 -9
- {titans_pytorch-0.3.2.dist-info → titans_pytorch-0.3.3.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.3.2.dist-info → titans_pytorch-0.3.3.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/memory_models.py
CHANGED
titans_pytorch/neural_memory.py
CHANGED
@@ -690,16 +690,27 @@ class NeuralMemory(Module):
|
|
690
690
|
def retrieve_memories(
|
691
691
|
self,
|
692
692
|
seq,
|
693
|
-
|
694
|
-
chunk_size = None,
|
695
|
-
need_pad = True
|
693
|
+
weights: dict[str, Tensor],
|
696
694
|
):
|
697
|
-
chunk_size =
|
695
|
+
chunk_size = self.retrieve_chunk_size
|
696
|
+
|
697
|
+
weights_have_expanded_shape = dict_get_shape(weights) != self.init_weight_shape
|
698
|
+
|
698
699
|
batch, seq_len = seq.shape[:2]
|
699
700
|
|
700
|
-
|
701
|
+
# auto infer single token decoding, if there are only 1 set of weights and 1 token
|
702
|
+
|
703
|
+
is_one_token = seq_len == 1
|
704
|
+
is_one_weight = (not weights_have_expanded_shape) or next(iter(weights.values())).shape[1] == 1
|
705
|
+
|
706
|
+
is_single_token_decode = is_one_token and is_one_weight
|
707
|
+
|
708
|
+
if is_single_token_decode:
|
709
|
+
chunk_size = 1
|
710
|
+
|
711
|
+
# padding related, for chunked processing
|
701
712
|
|
702
|
-
need_pad =
|
713
|
+
need_pad = chunk_size > 1 or not is_one_weight
|
703
714
|
|
704
715
|
if need_pad:
|
705
716
|
seq = pad_at_dim(seq, (1, 0), dim = 1)
|
@@ -714,7 +725,11 @@ class NeuralMemory(Module):
|
|
714
725
|
# the parameters of the memory model stores the memories of the key / values
|
715
726
|
# when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
|
716
727
|
|
717
|
-
|
728
|
+
weights = TensorDict(weights)
|
729
|
+
|
730
|
+
# pre norm
|
731
|
+
|
732
|
+
seq = self.retrieve_norm(seq)
|
718
733
|
|
719
734
|
# sequence Float['b n d'] to queries
|
720
735
|
|
@@ -730,14 +745,14 @@ class NeuralMemory(Module):
|
|
730
745
|
|
731
746
|
# fetch values from memory model
|
732
747
|
|
733
|
-
if
|
734
|
-
|
748
|
+
if weights_have_expanded_shape:
|
749
|
+
weights = rearrange_dict_values(weights, 'b n ... -> (b n) ...')
|
735
750
|
|
736
751
|
queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)
|
737
752
|
|
738
753
|
# forward functional call
|
739
754
|
|
740
|
-
values = functional_call(self.memory_model, dict(
|
755
|
+
values = functional_call(self.memory_model, dict(weights), queries)
|
741
756
|
|
742
757
|
# reconstitute batch dimension
|
743
758
|
|
@@ -885,22 +900,13 @@ class NeuralMemory(Module):
|
|
885
900
|
|
886
901
|
# retrieve
|
887
902
|
|
888
|
-
need_pad = True
|
889
|
-
retrieve_chunk_size = None
|
890
|
-
|
891
903
|
if is_single_token:
|
892
|
-
retrieve_chunk_size = 1
|
893
|
-
need_pad = False
|
894
|
-
|
895
904
|
last_update, _ = next_neural_mem_state.states
|
896
|
-
|
897
905
|
updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
|
898
906
|
|
899
907
|
retrieved = self.retrieve_memories(
|
900
908
|
seq,
|
901
|
-
updates
|
902
|
-
chunk_size = retrieve_chunk_size,
|
903
|
-
need_pad = need_pad,
|
909
|
+
updates
|
904
910
|
)
|
905
911
|
|
906
912
|
return retrieved, next_neural_mem_state
|
@@ -0,0 +1,9 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
|
4
|
+
titans_pytorch/memory_models.py,sha256=0KLHZN-y_7lwrhWSnFRaYJ3GiUV3tzVjxS9CxIx_eI8,4843
|
5
|
+
titans_pytorch/neural_memory.py,sha256=Ff-IBv-CCQAP7IYIpokPDoGtsvpzotAJsHB1d_-xd98,27934
|
6
|
+
titans_pytorch-0.3.3.dist-info/METADATA,sha256=CutjohW8xSNycd5W-uyXC4827ubmIpAJCs9xoMbfZzo,6815
|
7
|
+
titans_pytorch-0.3.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.3.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.3.3.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
|
4
|
-
titans_pytorch/memory_models.py,sha256=TJl7b9Rd5BP8aQXK8itap5YN3DyomUVxCRJDgPuRGBk,4843
|
5
|
-
titans_pytorch/neural_memory.py,sha256=QiEnHnZfQ8ptuXNVy4NZf9-XMbMOl2_1PT_YIG1GQBc,27739
|
6
|
-
titans_pytorch-0.3.2.dist-info/METADATA,sha256=Ar1OdcY09w-q3RlVKlxcgrtcVzZE6cRKqnjwQ4F-9Z8,6815
|
7
|
-
titans_pytorch-0.3.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
-
titans_pytorch-0.3.2.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
-
titans_pytorch-0.3.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|