titans-pytorch 0.2.14__py3-none-any.whl → 0.2.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +4 -10
- titans_pytorch/neural_memory.py +33 -75
- {titans_pytorch-0.2.14.dist-info → titans_pytorch-0.2.15.dist-info}/METADATA +1 -1
- titans_pytorch-0.2.15.dist-info/RECORD +9 -0
- titans_pytorch-0.2.14.dist-info/RECORD +0 -9
- {titans_pytorch-0.2.14.dist-info → titans_pytorch-0.2.15.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.2.14.dist-info → titans_pytorch-0.2.15.dist-info}/licenses/LICENSE +0 -0
@@ -761,16 +761,10 @@ class MemoryAsContextTransformer(Module):
|
|
761
761
|
|
762
762
|
mem_input, add_residual = mem_hyper_conn(x)
|
763
763
|
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
else:
|
770
|
-
(retrieved, next_neural_mem_cache) = mem.forward_inference(
|
771
|
-
mem_input,
|
772
|
-
state = next(neural_mem_caches, None),
|
773
|
-
)
|
764
|
+
retrieved, next_neural_mem_cache = mem.forward(
|
765
|
+
mem_input,
|
766
|
+
state = next(neural_mem_caches, None),
|
767
|
+
)
|
774
768
|
|
775
769
|
if self.gate_attn_output:
|
776
770
|
attn_out_gates = retrieved.sigmoid()
|
titans_pytorch/neural_memory.py
CHANGED
@@ -655,15 +655,19 @@ class NeuralMemory(Module):
|
|
655
655
|
self,
|
656
656
|
seq,
|
657
657
|
past_weights: dict[str, Tensor],
|
658
|
+
chunk_size = None,
|
659
|
+
need_pad = True
|
658
660
|
):
|
659
|
-
chunk_size = self.retrieve_chunk_size
|
661
|
+
chunk_size = default(chunk_size, self.retrieve_chunk_size)
|
660
662
|
batch, seq_len = seq.shape[:2]
|
661
663
|
|
662
664
|
seq = self.retrieve_norm(seq)
|
663
665
|
|
664
|
-
|
666
|
+
need_pad = need_pad or chunk_size > 1
|
667
|
+
|
668
|
+
if need_pad:
|
669
|
+
seq = pad_at_dim(seq, (1, 0), dim = 1)
|
665
670
|
|
666
|
-
seq = pad_at_dim(seq, (1, 0), dim = 1)
|
667
671
|
seq_len_plus_one = seq.shape[-2]
|
668
672
|
|
669
673
|
next_seq_len = round_up_multiple(seq_len_plus_one, chunk_size)
|
@@ -718,76 +722,10 @@ class NeuralMemory(Module):
|
|
718
722
|
|
719
723
|
# restore, pad with empty memory embed
|
720
724
|
|
721
|
-
|
722
|
-
|
723
|
-
return values
|
724
|
-
|
725
|
-
@torch.no_grad()
|
726
|
-
def forward_inference(
|
727
|
-
self,
|
728
|
-
token: Tensor,
|
729
|
-
state: NeuralMemCache | None = None,
|
730
|
-
):
|
731
|
-
# unpack previous state
|
732
|
-
|
733
|
-
if not exists(state):
|
734
|
-
state = (0, None, None, None, None)
|
735
|
-
|
736
|
-
seq_index, weights, cache_store_seq, past_states, updates = state
|
737
|
-
|
738
|
-
curr_seq_len = seq_index + 1
|
739
|
-
batch = token.shape[0]
|
725
|
+
if need_pad:
|
726
|
+
values = values[:, 1:]
|
740
727
|
|
741
|
-
|
742
|
-
token = rearrange(token, 'b d -> b 1 d')
|
743
|
-
|
744
|
-
assert token.shape[1] == 1
|
745
|
-
|
746
|
-
# increment the sequence cache which is at most the chunk size
|
747
|
-
|
748
|
-
cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)
|
749
|
-
|
750
|
-
# early return empty memory, when no memories are stored for steps < first chunk size
|
751
|
-
|
752
|
-
if curr_seq_len < self.chunk_size:
|
753
|
-
retrieve = self.retrieve_memories(token, weights, chunk_size = 1)
|
754
|
-
|
755
|
-
output = retrieve, NeuralMemCache(curr_seq_len, weights, cache_store_seq, past_states, updates)
|
756
|
-
|
757
|
-
return output
|
758
|
-
|
759
|
-
# store if storage sequence cache hits the chunk size
|
760
|
-
|
761
|
-
next_states = past_states
|
762
|
-
store_seq_cache_len = cache_store_seq.shape[-2]
|
763
|
-
|
764
|
-
if not exists(updates):
|
765
|
-
updates = weights.clone().zero_()
|
766
|
-
updates = repeat_dict_values(updates, '... -> b 1 ...', b = batch)
|
767
|
-
else:
|
768
|
-
updates = updates.apply(lambda t: t[:, -1:])
|
769
|
-
|
770
|
-
if store_seq_cache_len == self.chunk_size:
|
771
|
-
|
772
|
-
next_updates, store_state = self.store_memories(
|
773
|
-
cache_store_seq,
|
774
|
-
weights,
|
775
|
-
past_state = past_states,
|
776
|
-
)
|
777
|
-
|
778
|
-
updates = next_updates
|
779
|
-
cache_store_seq = None
|
780
|
-
next_states = store_state.states
|
781
|
-
|
782
|
-
# retrieve
|
783
|
-
|
784
|
-
retrieved = self.retrieve_memories(token, updates, chunk_size = 1)
|
785
|
-
|
786
|
-
# next state tuple
|
787
|
-
|
788
|
-
next_store_state = NeuralMemCache(curr_seq_len, weights, cache_store_seq, next_states, updates)
|
789
|
-
|
790
|
-
return retrieved, next_store_state
|
728
|
+
return values[:, :seq_len]
|
791
729
|
|
792
730
|
def forward(
|
793
731
|
self,
|
@@ -795,17 +733,25 @@ class NeuralMemory(Module):
|
|
795
733
|
store_seq = None,
|
796
734
|
state: NeuralMemCache | None = None,
|
797
735
|
):
|
736
|
+
if seq.ndim == 2:
|
737
|
+
seq = rearrange(seq, 'b d -> b 1 d')
|
738
|
+
|
739
|
+
is_single_token = seq.shape[1] == 1
|
740
|
+
|
798
741
|
if not exists(state):
|
799
742
|
state = (0, None, None, None, None)
|
800
743
|
|
801
744
|
seq_index, weights, cache_store_seq, past_state, updates = state
|
802
745
|
|
803
|
-
assert not exists(cache_store_seq) or is_empty_tensor(cache_store_seq)
|
804
|
-
|
805
746
|
# store
|
806
747
|
|
807
748
|
store_seq = default(store_seq, seq)
|
808
749
|
|
750
|
+
# take care of cache
|
751
|
+
|
752
|
+
if exists(cache_store_seq):
|
753
|
+
store_seq = safe_cat((cache_store_seq, store_seq))
|
754
|
+
|
809
755
|
# functions
|
810
756
|
|
811
757
|
# compute split sizes of sequence
|
@@ -883,9 +829,21 @@ class NeuralMemory(Module):
|
|
883
829
|
|
884
830
|
# retrieve
|
885
831
|
|
832
|
+
need_pad = True
|
833
|
+
retrieve_chunk_size = None
|
834
|
+
|
835
|
+
if is_single_token:
|
836
|
+
retrieve_chunk_size = 1
|
837
|
+
need_pad = False
|
838
|
+
|
839
|
+
last_update, _ = past_state
|
840
|
+
updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
|
841
|
+
|
886
842
|
retrieved = self.retrieve_memories(
|
887
843
|
seq,
|
888
|
-
updates
|
844
|
+
updates,
|
845
|
+
chunk_size = retrieve_chunk_size,
|
846
|
+
need_pad = need_pad,
|
889
847
|
)
|
890
848
|
|
891
849
|
return retrieved, next_neural_mem_state
|
@@ -0,0 +1,9 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=Udu-9mtPy9sDeDyXKo95YMel3ELv5quJXINW-JG-hdk,24357
|
4
|
+
titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
|
5
|
+
titans_pytorch/neural_memory.py,sha256=iu9lnRrqWmtFw3QYyJlS7mOP2zI2HJFuhs3TyfkKV3o,25482
|
6
|
+
titans_pytorch-0.2.15.dist-info/METADATA,sha256=vOb0Tt6-egnqtNXMfrJVibHwm8VuWQMlPw3C7Y_L4Wg,6812
|
7
|
+
titans_pytorch-0.2.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.2.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.2.15.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=RfJ1SvQH5_4PmlB7g-13wPAqYtCCUJxfmtaL0oBrRCU,24563
|
4
|
-
titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
|
5
|
-
titans_pytorch/neural_memory.py,sha256=BGAzdxxZoKo8SmuPpSx1h1Yexf0oRje5qsCjl5r4FSA,26833
|
6
|
-
titans_pytorch-0.2.14.dist-info/METADATA,sha256=SE27_ln7ludwaO2J1uX7TJ99MbnCl4GLorYpbBIK920,6812
|
7
|
-
titans_pytorch-0.2.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
-
titans_pytorch-0.2.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
-
titans_pytorch-0.2.14.dist-info/RECORD,,
|
File without changes
|
File without changes
|