titans-pytorch 0.1.29__py3-none-any.whl → 0.1.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +20 -13
- titans_pytorch/titans.py +3 -2
- {titans_pytorch-0.1.29.dist-info → titans_pytorch-0.1.31.dist-info}/METADATA +1 -1
- titans_pytorch-0.1.31.dist-info/RECORD +8 -0
- titans_pytorch-0.1.29.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.29.dist-info → titans_pytorch-0.1.31.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.29.dist-info → titans_pytorch-0.1.31.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py CHANGED

@@ -582,13 +582,8 @@ class MemoryAsContextTransformer(Module):
         self,
         seq_index
     ):
-        total_segment_len = self.attn_window_size
-
-        seq = seq_index + 1
-        seq -= int((seq % total_segment_len) == 0)
-        last_segment_len = round_down_multiple(seq, total_segment_len)
-        segment_seq = seq - last_segment_len
-        return (segment_seq - self.segment_len) > 0
+        total_segment_len, segment_len = self.attn_window_size, self.segment_len
+        return ((seq_index % total_segment_len + 1) - segment_len) > 0

     def seq_len_with_longterm_mem(
         self,
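For reference, a minimal sketch (not part of the package) of what the rewritten predicate selects, assuming attn_window_size = segment_len + num_longterm_mem_tokens as elsewhere in the file: it flags the trailing memory slots of each attention window.

# hedged sketch; segment_len and num_mem values are made up for illustration
segment_len, num_mem = 4, 2
total_segment_len = segment_len + num_mem  # assumed: attn_window_size = segment_len + num_mem

def seq_index_is_longterm(seq_index):
    return ((seq_index % total_segment_len + 1) - segment_len) > 0

flagged = [i for i in range(12) if seq_index_is_longterm(i)]
assert flagged == [4, 5, 10, 11]  # the trailing num_mem slots of each window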
@@ -597,7 +592,7 @@ class MemoryAsContextTransformer(Module):
         assert seq_len > 0

         segment_len, num_mem = self.segment_len, self.num_longterm_mem_tokens
-        return
+        return ((seq_len - 1) // segment_len) * num_mem + seq_len

     @torch.no_grad()
     def sample(
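A quick check of the restored return expression (values illustrative): every completed segment of segment_len tokens contributes num_longterm_mem_tokens extra positions.

# hedged sketch of the new formula
segment_len, num_mem = 4, 2

def seq_len_with_longterm_mem(seq_len):
    assert seq_len > 0
    return ((seq_len - 1) // segment_len) * num_mem + seq_len

assert seq_len_with_longterm_mem(4) == 4    # no segment boundary crossed yet
assert seq_len_with_longterm_mem(5) == 7    # one completed segment adds num_mem positions
assert seq_len_with_longterm_mem(9) == 13   # two completed segments add 2 * num_mem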
@@ -723,9 +718,9 @@ class MemoryAsContextTransformer(Module):
         is_inferencing = exists(cache)

         if not exists(cache):
-            cache = (None, None)
+            cache = (seq_len_with_mem - 1, None, None)

-        kv_caches, neural_mem_caches = cache
+        inference_seq_index, kv_caches, neural_mem_caches = cache

         kv_caches = iter(default(kv_caches, []))
         neural_mem_caches = iter(default(neural_mem_caches, []))
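The cache grows from a 2-tuple to a 3-tuple so the running position survives across decoding steps. A sketch of the new layout (seq_len_with_mem is the value computed earlier in forward; the number here is illustrative):

seq_len_with_mem = 13
cache = (seq_len_with_mem - 1, None, None)   # fresh cache: start at the prompt's last position
inference_seq_index, kv_caches, neural_mem_caches = cache
assert inference_seq_index == 12             # advanced by 1 on every subsequent step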
@@ -744,7 +739,8 @@ class MemoryAsContextTransformer(Module):
         # when inferencing, only do one token at a time

         if is_inferencing:
-
+            ind = inference_seq_index
+            x = x[:, ind:(ind + 1)]

         # expand and reduce streams for hyper connections

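With the index in hand, inference slices out exactly the tracked position rather than assuming the last token. A toy illustration of the shape logic (not the package's actual activations):

import torch

x = torch.randn(2, 13, 512)        # (batch, seq_len_with_mem, dim)
ind = 6                            # inference_seq_index carried in the cache
x_step = x[:, ind:(ind + 1)]       # one position per forward pass
assert x_step.shape == (2, 1, 512)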
@@ -817,6 +813,17 @@ class MemoryAsContextTransformer(Module):
         if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
             next_kv_caches = next_kv_caches[..., 0:0, :]

+        next_cache = (
+            inference_seq_index + 1,
+            next_kv_caches,
+            next_neural_mem_caches
+        )
+
+        is_longterm_mem = self.seq_index_is_longterm(inference_seq_index)
+
+        if is_inferencing and is_longterm_mem:
+            return None, next_cache
+
         # hyper connection reducing of streams

         x = self.reduce_streams(x)
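Because forward can now land on an interleaved long-term memory position, it returns None logits there and hands back the advanced cache. A hedged decode-loop sketch; model, prompt, and the keyword names are assumptions read off this diff, not a documented API:

import torch

# `model` is a trained MemoryAsContextTransformer, `prompt` is (batch, n) token ids; both assumed
cache = None
out = prompt
for _ in range(64):
    logits, cache = model(out, cache = cache, return_cache = True)

    if logits is None:                         # memory-token step: cache advanced, nothing to sample
        continue

    next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
    out = torch.cat((out, next_token), dim = -1)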
@@ -829,7 +836,7 @@ class MemoryAsContextTransformer(Module):

         x, _ = inverse_pack_mems(x)

-        x = inverse_segment(x)
+        x = inverse_segment(x, remove_pad = False)

         x = x[:, :seq_len]

@@ -843,7 +850,7 @@ class MemoryAsContextTransformer(Module):
         if not return_cache:
             return logits

-        return logits,
+        return logits, next_cache

         ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)

titans_pytorch/titans.py CHANGED
@@ -409,6 +409,7 @@ class NeuralMemory(Module):
     ):
         super().__init__()
         dim_head = default(dim_head, dim)
+        assert not (heads == 1 and dim_head != dim)

         self.retrieve_chunk_size, self.store_chunk_size = pair(chunk_size)

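The new assert rejects a previously silent configuration mismatch: with a single head, dim_head must equal dim. For instance:

dim, heads, dim_head = 384, 1, 384
assert not (heads == 1 and dim_head != dim)    # passes

dim, heads, dim_head = 384, 1, 64
# assert not (heads == 1 and dim_head != dim)  # would now raise at construction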
@@ -566,7 +567,7 @@ class NeuralMemory(Module):
     ):
         assert xnor(exists(value_residual), exists(self.learned_value_residual))

-        seq_len, chunk_size = seq.shape[-2], default(chunk_size, self.store_chunk_size)
+        seq_len, heads, chunk_size = seq.shape[-2], self.heads, default(chunk_size, self.store_chunk_size)

         # handle edge case

@@ -645,7 +646,7 @@ class NeuralMemory(Module):

         # restore batch and sequence dimension

-        grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
+        grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch * heads))

         # maybe per layer modulation

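This is the companion fix to the heads change above: judging by the diff, the per-sample gradients arrive with heads folded into the batch dimension, so restoring the batch axis must unflatten by batch * heads, not batch. A standalone sketch of the shape logic (tensor values are illustrative; grads.apply in the package maps this over a tree of parameters):

import torch
from einops import rearrange

batch, heads, n, d = 2, 4, 3, 8
t = torch.randn(batch * heads * n, d)      # heads already merged into batch upstream

# 0.1.29 used b = batch here, which mis-grouped the leading dimension
t = rearrange(t, '(b n) ... -> b n ...', b = batch * heads)
assert t.shape == (batch * heads, n, d)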
titans_pytorch-0.1.31.dist-info/RECORD ADDED

@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=pKFRL_ISoHEKUyfssKwfBfwFO2eQN9objJmxLrNsYrU,24838
+titans_pytorch/titans.py,sha256=6B8ioP26RTja5kVFMsorAnM9CcxIUySJS9RZBlDPI2s,25825
+titans_pytorch-0.1.31.dist-info/METADATA,sha256=9ejOFuH2B2-yCRFK4x_C1DONPxecW8VcjEUeRh9OzXg,6815
+titans_pytorch-0.1.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.31.dist-info/RECORD,,
titans_pytorch-0.1.29.dist-info/RECORD DELETED

@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=RRLdVa8z-2IWbhhmRGfoNBycwaL32aMbpqutzmSQqpc,24575
-titans_pytorch/titans.py,sha256=gjoDcTsvw5X2d1I2xq4cM45YJIBqtLFuws8_jVylW_4,25746
-titans_pytorch-0.1.29.dist-info/METADATA,sha256=9Na2UlBJ4mECXXY5GIyuokgN0oxs38rps24TIM6CNFY,6815
-titans_pytorch-0.1.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.29.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.29.dist-info/RECORD,,
{titans_pytorch-0.1.29.dist-info → titans_pytorch-0.1.31.dist-info}/WHEEL: file without changes
{titans_pytorch-0.1.29.dist-info → titans_pytorch-0.1.31.dist-info}/licenses/LICENSE: file without changes