titans-pytorch 0.1.30__py3-none-any.whl → 0.1.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +20 -13
- {titans_pytorch-0.1.30.dist-info → titans_pytorch-0.1.31.dist-info}/METADATA +1 -1
- titans_pytorch-0.1.31.dist-info/RECORD +8 -0
- titans_pytorch-0.1.30.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.30.dist-info → titans_pytorch-0.1.31.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.30.dist-info → titans_pytorch-0.1.31.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py

```diff
@@ -582,13 +582,8 @@ class MemoryAsContextTransformer(Module):
         self,
         seq_index
     ):
-        total_segment_len = self.attn_window_size
-
-        seq = seq_index + 1
-        seq -= int((seq % total_segment_len) == 0)
-        last_segment_len = round_down_multiple(seq, total_segment_len)
-        segment_seq = seq - last_segment_len
-        return (segment_seq - self.segment_len) > 0
+        total_segment_len, segment_len = self.attn_window_size, self.segment_len
+        return ((seq_index % total_segment_len + 1) - segment_len) > 0
 
     def seq_len_with_longterm_mem(
         self,
```
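The old predicate derived the answer through an explicit round-down-to-window computation; the replacement reads it straight off the token's offset within its attention window. Below is a minimal sketch of what the new one-liner decides, under an assumed toy configuration (the standalone function and the sizes are illustrative, not taken from the package):

```python
# Assumed toy configuration: each attention window holds segment_len ordinary
# tokens followed by num_longterm_mem_tokens long-term memory slots.
segment_len = 4
num_longterm_mem_tokens = 2
attn_window_size = segment_len + num_longterm_mem_tokens  # 6

def seq_index_is_longterm(seq_index: int) -> bool:
    # same arithmetic as the new one-liner in the diff
    return ((seq_index % attn_window_size + 1) - segment_len) > 0

print([seq_index_is_longterm(i) for i in range(12)])
# [False, False, False, False, True, True, False, False, False, False, True, True]
```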
```diff
@@ -597,7 +592,7 @@ class MemoryAsContextTransformer(Module):
         assert seq_len > 0
 
         segment_len, num_mem = self.segment_len, self.num_longterm_mem_tokens
-        return
+        return ((seq_len - 1) // segment_len) * num_mem + seq_len
 
     @torch.no_grad()
     def sample(
```
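The new return spells out the memory-augmented length: every fully completed block of `segment_len` real tokens contributes `num_mem` extra positions. A quick spot-check with made-up sizes (standalone function, not the package's method):

```python
segment_len, num_mem = 4, 2  # made-up sizes for illustration

def seq_len_with_longterm_mem(seq_len: int) -> int:
    assert seq_len > 0
    # each segment completed before the final token adds num_mem positions
    return ((seq_len - 1) // segment_len) * num_mem + seq_len

for n in (1, 4, 5, 9):
    print(n, seq_len_with_longterm_mem(n))
# 1 -> 1    no completed segment yet
# 4 -> 4    the 4th token is still inside the first segment
# 5 -> 7    one completed segment adds num_mem = 2 positions
# 9 -> 13   two completed segments add 4 positions
```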
```diff
@@ -723,9 +718,9 @@ class MemoryAsContextTransformer(Module):
         is_inferencing = exists(cache)
 
         if not exists(cache):
-            cache = (None, None)
+            cache = (seq_len_with_mem - 1, None, None)
 
-        kv_caches, neural_mem_caches = cache
+        inference_seq_index, kv_caches, neural_mem_caches = cache
 
         kv_caches = iter(default(kv_caches, []))
         neural_mem_caches = iter(default(neural_mem_caches, []))
```
```diff
@@ -744,7 +739,8 @@ class MemoryAsContextTransformer(Module):
         # when inferencing, only do one token at a time
 
         if is_inferencing:
-
+            ind = inference_seq_index
+            x = x[:, ind:(ind + 1)]
 
         # expand and reduce streams for hyper connections
 
```
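With the cache widened to a 3-tuple, the forward pass no longer assumes the position to decode is simply the last one: it slices at the absolute index stored in the cache, which is seeded to `seq_len_with_mem - 1` on the first cached step. A rough illustration of that bookkeeping with toy tensors (shapes are made up, no attention or memory involved):

```python
import torch

# Toy stand-ins: batch of 2, 7 memory-augmented positions, 16-dim embeddings.
x = torch.randn(2, 7, 16)
seq_len_with_mem = x.shape[1]

# First cached step: no kv / neural-memory state yet, index points at the end.
cache = (seq_len_with_mem - 1, None, None)
inference_seq_index, kv_caches, neural_mem_caches = cache

# Mirrors the new slicing: keep exactly one position, the one the cache names.
ind = inference_seq_index
x_step = x[:, ind:(ind + 1)]
print(x_step.shape)  # torch.Size([2, 1, 16])
```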
```diff
@@ -817,6 +813,17 @@ class MemoryAsContextTransformer(Module):
         if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
             next_kv_caches = next_kv_caches[..., 0:0, :]
 
+        next_cache = (
+            inference_seq_index + 1,
+            next_kv_caches,
+            next_neural_mem_caches
+        )
+
+        is_longterm_mem = self.seq_index_is_longterm(inference_seq_index)
+
+        if is_inferencing and is_longterm_mem:
+            return None, next_cache
+
         # hyper connection reducing of streams
 
         x = self.reduce_streams(x)
```
```diff
@@ -829,7 +836,7 @@ class MemoryAsContextTransformer(Module):
 
         x, _ = inverse_pack_mems(x)
 
-        x = inverse_segment(x)
+        x = inverse_segment(x, remove_pad = False)
 
         x = x[:, :seq_len]
 
```
```diff
@@ -843,7 +850,7 @@ class MemoryAsContextTransformer(Module):
             if not return_cache:
                 return logits
 
-            return logits,
+            return logits, next_cache
 
         ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
 
```
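Taken together, a caller that sets `return_cache` now gets back `(logits, next_cache)`, and gets `(None, next_cache)` whenever the cached position lands on an interspersed long-term memory slot, in which case it should advance the cache without sampling. A hedged sketch of what such a decoding loop might look like (the exact `model(...)` call pattern is assumed here, not quoted from the package):

```python
import torch

@torch.no_grad()
def greedy_decode(model, prompt_ids, num_tokens):
    # prompt_ids: (batch, prompt_len) integer token ids
    out = prompt_ids
    cache = None

    while out.shape[-1] < prompt_ids.shape[-1] + num_tokens:
        logits, cache = model(out, cache = cache, return_cache = True)

        if logits is None:
            # cached position was a long-term memory slot: only the cache
            # advances, no token is emitted on this step
            continue

        next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
        out = torch.cat((out, next_token), dim = -1)

    return out[:, prompt_ids.shape[-1]:]
```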
titans_pytorch-0.1.31.dist-info/RECORD

```diff
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=pKFRL_ISoHEKUyfssKwfBfwFO2eQN9objJmxLrNsYrU,24838
+titans_pytorch/titans.py,sha256=6B8ioP26RTja5kVFMsorAnM9CcxIUySJS9RZBlDPI2s,25825
+titans_pytorch-0.1.31.dist-info/METADATA,sha256=9ejOFuH2B2-yCRFK4x_C1DONPxecW8VcjEUeRh9OzXg,6815
+titans_pytorch-0.1.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.31.dist-info/RECORD,,
```
titans_pytorch-0.1.30.dist-info/RECORD

```diff
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=RRLdVa8z-2IWbhhmRGfoNBycwaL32aMbpqutzmSQqpc,24575
-titans_pytorch/titans.py,sha256=6B8ioP26RTja5kVFMsorAnM9CcxIUySJS9RZBlDPI2s,25825
-titans_pytorch-0.1.30.dist-info/METADATA,sha256=o5flkZ0hNhZE06bSKVEFpbrkhuWB9putcaL_MZ0sJHA,6815
-titans_pytorch-0.1.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.30.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.30.dist-info/RECORD,,
```
{titans_pytorch-0.1.30.dist-info → titans_pytorch-0.1.31.dist-info}/WHEEL: file without changes
{titans_pytorch-0.1.30.dist-info → titans_pytorch-0.1.31.dist-info}/licenses/LICENSE: file without changes