titans-pytorch 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/memory_models.py +6 -2
- titans_pytorch/neural_memory.py +28 -21
- {titans_pytorch-0.3.1.dist-info → titans_pytorch-0.3.3.dist-info}/METADATA +1 -1
- titans_pytorch-0.3.3.dist-info/RECORD +9 -0
- titans_pytorch-0.3.1.dist-info/RECORD +0 -9
- {titans_pytorch-0.3.1.dist-info → titans_pytorch-0.3.3.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.3.1.dist-info → titans_pytorch-0.3.3.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/memory_models.py
CHANGED
@@ -36,10 +36,14 @@ class MemoryMLP(Module):
|
|
36
36
|
def __init__(
|
37
37
|
self,
|
38
38
|
dim,
|
39
|
-
depth
|
39
|
+
depth,
|
40
|
+
expansion_factor = 2.
|
40
41
|
):
|
41
42
|
super().__init__()
|
42
|
-
|
43
|
+
dim_hidden = int(dim * expansion_factor)
|
44
|
+
dims = (dim, *((dim_hidden,) * (depth - 1)), dim)
|
45
|
+
|
46
|
+
self.weights = ParameterList([Parameter(torch.randn(dim_in, dim_out)) for dim_in, dim_out in zip(dims[:-1], dims[1:])])
|
43
47
|
|
44
48
|
self.ln = LayerNorm(dim)
|
45
49
|
|
titans_pytorch/neural_memory.py
CHANGED
@@ -299,7 +299,8 @@ class NeuralMemory(Module):
|
|
299
299
|
accept_weight_residual = False,
|
300
300
|
gated_transition = False,
|
301
301
|
default_model_kwargs: dict = dict(
|
302
|
-
depth = 2
|
302
|
+
depth = 2,
|
303
|
+
expansion_factor = 4.
|
303
304
|
)
|
304
305
|
):
|
305
306
|
super().__init__()
|
@@ -689,16 +690,27 @@ class NeuralMemory(Module):
|
|
689
690
|
def retrieve_memories(
|
690
691
|
self,
|
691
692
|
seq,
|
692
|
-
|
693
|
-
chunk_size = None,
|
694
|
-
need_pad = True
|
693
|
+
weights: dict[str, Tensor],
|
695
694
|
):
|
696
|
-
chunk_size =
|
695
|
+
chunk_size = self.retrieve_chunk_size
|
696
|
+
|
697
|
+
weights_have_expanded_shape = dict_get_shape(weights) != self.init_weight_shape
|
698
|
+
|
697
699
|
batch, seq_len = seq.shape[:2]
|
698
700
|
|
699
|
-
|
701
|
+
# auto infer single token decoding, if there is only 1 set of weights and 1 token
|
702
|
+
|
703
|
+
is_one_token = seq_len == 1
|
704
|
+
is_one_weight = (not weights_have_expanded_shape) or next(iter(weights.values())).shape[1] == 1
|
705
|
+
|
706
|
+
is_single_token_decode = is_one_token and is_one_weight
|
707
|
+
|
708
|
+
if is_single_token_decode:
|
709
|
+
chunk_size = 1
|
710
|
+
|
711
|
+
# padding related, for chunked processing
|
700
712
|
|
701
|
-
need_pad =
|
713
|
+
need_pad = chunk_size > 1 or not is_one_weight
|
702
714
|
|
703
715
|
if need_pad:
|
704
716
|
seq = pad_at_dim(seq, (1, 0), dim = 1)
|
@@ -713,7 +725,11 @@ class NeuralMemory(Module):
|
|
713
725
|
# the parameters of the memory model stores the memories of the key / values
|
714
726
|
# when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / Schmidhuber's paper
|
715
727
|
|
716
|
-
|
728
|
+
weights = TensorDict(weights)
|
729
|
+
|
730
|
+
# pre norm
|
731
|
+
|
732
|
+
seq = self.retrieve_norm(seq)
|
717
733
|
|
718
734
|
# sequence Float['b n d'] to queries
|
719
735
|
|
@@ -729,14 +745,14 @@ class NeuralMemory(Module):
|
|
729
745
|
|
730
746
|
# fetch values from memory model
|
731
747
|
|
732
|
-
if
|
733
|
-
|
748
|
+
if weights_have_expanded_shape:
|
749
|
+
weights = rearrange_dict_values(weights, 'b n ... -> (b n) ...')
|
734
750
|
|
735
751
|
queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)
|
736
752
|
|
737
753
|
# forward functional call
|
738
754
|
|
739
|
-
values = functional_call(self.memory_model, dict(
|
755
|
+
values = functional_call(self.memory_model, dict(weights), queries)
|
740
756
|
|
741
757
|
# reconstitute batch dimension
|
742
758
|
|
@@ -884,22 +900,13 @@ class NeuralMemory(Module):
|
|
884
900
|
|
885
901
|
# retrieve
|
886
902
|
|
887
|
-
need_pad = True
|
888
|
-
retrieve_chunk_size = None
|
889
|
-
|
890
903
|
if is_single_token:
|
891
|
-
retrieve_chunk_size = 1
|
892
|
-
need_pad = False
|
893
|
-
|
894
904
|
last_update, _ = next_neural_mem_state.states
|
895
|
-
|
896
905
|
updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
|
897
906
|
|
898
907
|
retrieved = self.retrieve_memories(
|
899
908
|
seq,
|
900
|
-
updates
|
901
|
-
chunk_size = retrieve_chunk_size,
|
902
|
-
need_pad = need_pad,
|
909
|
+
updates
|
903
910
|
)
|
904
911
|
|
905
912
|
return retrieved, next_neural_mem_state
|
@@ -0,0 +1,9 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
|
4
|
+
titans_pytorch/memory_models.py,sha256=0KLHZN-y_7lwrhWSnFRaYJ3GiUV3tzVjxS9CxIx_eI8,4843
|
5
|
+
titans_pytorch/neural_memory.py,sha256=Ff-IBv-CCQAP7IYIpokPDoGtsvpzotAJsHB1d_-xd98,27934
|
6
|
+
titans_pytorch-0.3.3.dist-info/METADATA,sha256=CutjohW8xSNycd5W-uyXC4827ubmIpAJCs9xoMbfZzo,6815
|
7
|
+
titans_pytorch-0.3.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.3.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.3.3.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
|
4
|
-
titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
|
5
|
-
titans_pytorch/neural_memory.py,sha256=K1z7wtv366Y6-eEyXMFZ_j7D2frPl5RxfSgxzFYoFMc,27704
|
6
|
-
titans_pytorch-0.3.1.dist-info/METADATA,sha256=ZAxucKq2DZBtW-BI_O2sUQ5RXy11a7eu48yPpwnanpw,6815
|
7
|
-
titans_pytorch-0.3.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
-
titans_pytorch-0.3.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
-
titans_pytorch-0.3.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|