titans-pytorch 0.1.27__py3-none-any.whl → 0.1.29__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- titans_pytorch/mac_transformer.py +77 -38
- {titans_pytorch-0.1.27.dist-info → titans_pytorch-0.1.29.dist-info}/METADATA +1 -1
- titans_pytorch-0.1.29.dist-info/RECORD +8 -0
- titans_pytorch-0.1.27.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.27.dist-info → titans_pytorch-0.1.29.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.27.dist-info → titans_pytorch-0.1.29.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py

@@ -90,6 +90,9 @@ def divisible_by(num, den):
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult

+def round_down_multiple(seq, mult):
+    return seq // mult * mult
+
 def pack_with_inverse(t, pattern):
     packed, packed_shape = pack(t, pattern)

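The new round_down_multiple mirrors the existing round_up_multiple and is used below by seq_index_is_longterm. A minimal sketch of how the pair behaves; the sample values are purely illustrative:

    from math import ceil

    def round_up_multiple(seq, mult):
        # round seq up to the nearest multiple of mult
        return ceil(seq / mult) * mult

    def round_down_multiple(seq, mult):
        # round seq down to the nearest multiple of mult
        return seq // mult * mult

    assert round_up_multiple(10, 4) == 12
    assert round_down_multiple(10, 4) == 8
    assert round_down_multiple(12, 4) == 12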
@@ -116,11 +119,11 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
     if fold_into_batch:
         seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)

-    def inverse(out):
+    def inverse(out, remove_pad = True):
         if fold_into_batch:
             out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)

-        if needs_pad:
+        if needs_pad and remove_pad:
             out = out[..., :-padding, :]

         return out
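The inverse closure gains a remove_pad flag so a caller can undo the segmenting while keeping the right padding; the padded positions are later spliced out together with the long term memory tokens in the forward hunk further down. A simplified, self-contained sketch of that idea, assuming fold_into_batch behavior and my own padding arithmetic; it is not the library's actual implementation:

    import torch
    import torch.nn.functional as F
    from math import ceil
    from einops import rearrange

    # simplified sketch of the helper this hunk modifies (fold_into_batch assumed True)
    def pad_and_segment_with_inverse(seq, segment_len):
        batch, seq_len = seq.shape[:2]
        padding = ceil(seq_len / segment_len) * segment_len - seq_len
        needs_pad = padding > 0

        if needs_pad:
            # pad the sequence dimension on the right up to a multiple of segment_len
            seq = F.pad(seq, (0, 0, 0, padding))

        seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)

        def inverse(out, remove_pad = True):
            out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)

            # the new flag lets the caller keep the padded tail, for when extra
            # tokens (for example long term memories) are spliced out later instead
            if needs_pad and remove_pad:
                out = out[:, :-padding]

            return out

        return seq, inverse

    # usage: segments of length 4 from a length-10 sequence, then undo
    x = torch.randn(2, 10, 8)
    segments, inverse = pad_and_segment_with_inverse(x, 4)
    assert segments.shape == (2 * 3, 4, 8)
    assert inverse(segments).shape == (2, 10, 8)
    assert inverse(segments, remove_pad = False).shape == (2, 12, 8)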
@@ -312,7 +315,7 @@ class SegmentedAttention(Module):

         # caching

-        next_cache =
+        next_cache = (k, v)

         # take care of persistent memory key / values

@@ -575,6 +578,27 @@ class MemoryAsContextTransformer(Module):

         self.num_persist_mem_tokens = num_persist_mem_tokens

+    def seq_index_is_longterm(
+        self,
+        seq_index
+    ):
+        total_segment_len = self.attn_window_size
+
+        seq = seq_index + 1
+        seq -= int((seq % total_segment_len) == 0)
+        last_segment_len = round_down_multiple(seq, total_segment_len)
+        segment_seq = seq - last_segment_len
+        return (segment_seq - self.segment_len) > 0
+
+    def seq_len_with_longterm_mem(
+        self,
+        seq_len
+    ):
+        assert seq_len > 0
+
+        segment_len, num_mem = self.segment_len, self.num_longterm_mem_tokens
+        return ceil(seq_len / segment_len) * num_mem + seq_len
+
     @torch.no_grad()
     def sample(
         self,
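The new seq_len_with_longterm_mem accounts for the block of num_longterm_mem_tokens inserted per (possibly partial) segment. A quick worked example with a hypothetical configuration of segment_len = 128 and num_longterm_mem_tokens = 4:

    from math import ceil

    # hypothetical configuration, for illustration only
    segment_len = 128
    num_longterm_mem_tokens = 4

    def seq_len_with_longterm_mem(seq_len):
        # one block of long term memory tokens is inserted per (possibly partial) segment
        return ceil(seq_len / segment_len) * num_longterm_mem_tokens + seq_len

    assert seq_len_with_longterm_mem(1) == 1 + 1 * 4        # 5
    assert seq_len_with_longterm_mem(128) == 128 + 1 * 4    # 132
    assert seq_len_with_longterm_mem(129) == 129 + 2 * 4    # 137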
@@ -594,8 +618,6 @@ class MemoryAsContextTransformer(Module):
         prompt_seq_len, out = prompt.shape[-1], prompt.clone()
         sample_num_times = max(0, seq_len - prompt_seq_len)

-        iter_wrap = tqdm.tqdm if show_progress else identity
-
         # cache for axial pos, attention, and neural memory

         cache = None
@@ -604,9 +626,7 @@ class MemoryAsContextTransformer(Module):
         # precompute factorized pos emb

         if use_cache:
-
-            longterm_mem_lens = (round_up_seq_len // self.segment_len) * self.num_longterm_mem_tokens
-            seq_len_with_mem = round_up_seq_len + longterm_mem_lens
+            seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)

             axial_dims = self.axial_pos_emb.maybe_derive_outer_dim(seq_len_with_mem, (self.neural_memory_segment_len,))

@@ -614,25 +634,31 @@ class MemoryAsContextTransformer(Module):

         # sample

-
+        with tqdm.tqdm(total = sample_num_times, disable = not show_progress) as pbar:

-
-
-
-
-
-
-
+            while out.shape[-1] < seq_len:
+
+                logits, next_cache = self.forward(
+                    out,
+                    disable_flex_attn = True,
+                    cache = cache,
+                    return_cache = True,
+                    factorized_pos_emb = factorized_pos_emb
+                )

-
-
+                if use_cache:
+                    cache = next_cache

-
+                if not exists(logits):
+                    continue

-
-            sample = gumbel_sample(logits, temperature = temperature)
+                logits = logits[:, -1]

-
+                logits = filter_fn(logits, **filter_kwargs)
+                sample = gumbel_sample(logits, temperature = temperature)
+
+                out = torch.cat((out, sample), dim = -1)
+                pbar.update(1)

         self.train(was_training)

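The rewritten loop threads the kv and neural memory cache through self.forward(..., return_cache = True), skips steps where forward returns no logits (the early return while long term memory tokens are being processed during inference), and reports progress through tqdm instead of the removed iter_wrap. A hedged usage sketch of sample(); the argument names are taken from this diff where visible, while the constructor values are illustrative and may not match the library's defaults:

    import torch
    from titans_pytorch import MemoryAsContextTransformer

    # illustrative hyperparameters; see the project README for real configurations
    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 32,
        num_persist_mem_tokens = 4,
        num_longterm_mem_tokens = 4,
    )

    prompt = torch.randint(0, 256, (1, 16))

    # sample() runs under torch.no_grad(), shows a tqdm bar when show_progress = True,
    # and reuses the kv / neural memory cache between steps when use_cache = True
    sampled = transformer.sample(prompt, 64, use_cache = True, show_progress = True)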
@@ -656,6 +682,8 @@ class MemoryAsContextTransformer(Module):

         batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size

+        seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)
+
         # token embedding

         x = self.token_emb(x)
@@ -667,9 +695,11 @@ class MemoryAsContextTransformer(Module):
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
         x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')

-        x = inverse_segment(x)
+        x = inverse_segment(x, remove_pad = False)
+
+        # splice out unneeded tokens from padding for longterm mems

-
+        x = x[:, :seq_len_with_mem]

         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
@@ -685,13 +715,12 @@ class MemoryAsContextTransformer(Module):
         flex_attn_fn = None

         if use_flex_attn:
-            block_mask = create_mac_block_mask(seq_len_with_mem,
+            block_mask = create_mac_block_mask(seq_len_with_mem, self.attn_window_size, self.num_persist_mem_tokens, self.sliding_window_attn)
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)

         # kv caching

         is_inferencing = exists(cache)
-        assert not (is_inferencing and self.num_longterm_mem_tokens > 0)

         if not exists(cache):
             cache = (None, None)
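The flex attention path now receives self.sliding_window_attn, so the block mask agrees with the non-flex local attention on windowing. As background only, a generic sliding-window causal mask built with PyTorch's flex attention utilities; this is not the library's create_mac_block_mask, which additionally accounts for persistent memory tokens (assumes PyTorch >= 2.5 with flex attention and a CUDA device):

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    window_size = 128
    seq_len, heads, dim_head = 512, 8, 64

    def sliding_window_causal(b, h, q_idx, kv_idx):
        # causal, and each query attends only to the previous window_size positions
        return (q_idx >= kv_idx) & (q_idx - kv_idx < window_size)

    block_mask = create_block_mask(sliding_window_causal, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len)

    q, k, v = (torch.randn(1, heads, seq_len, dim_head, device = 'cuda') for _ in range(3))

    out = flex_attention(q, k, v, block_mask = block_mask)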
@@ -775,15 +804,34 @@ class MemoryAsContextTransformer(Module):

             x = ff(x)

+        # taking care of cache first
+        # for early return when processing long term mem tokens during inference
+
+        if return_cache:
+            next_kv_caches = stack([stack(kv_cache) for kv_cache in next_kv_caches])
+
+            # handle kv cache length depending on local attention type
+
+            next_kv_caches = next_kv_caches[..., -attn_window_size:, :]
+
+            if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
+                next_kv_caches = next_kv_caches[..., 0:0, :]
+
+        # hyper connection reducing of streams
+
         x = self.reduce_streams(x)

         # excise out the memories

-
+        if not is_inferencing:
+
+            x, inverse_segment = pad_and_segment_with_inverse(x, attn_window_size)

-
+            x, _ = inverse_pack_mems(x)

-
+            x = inverse_segment(x)
+
+            x = x[:, :seq_len]

         # to logits

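The kv cache bookkeeping moves ahead of the stream reduction and memory excision so an inference step that only processes long term memory tokens can still return an updated cache; the trimming rule itself is unchanged from the block deleted in the next hunk. A standalone sketch of that rule on a dummy cache tensor, with names taken from the diff but an illustrative layout that is not the library's:

    import torch

    # dummy kv cache laid out as (layers, k/v, batch, heads, seq, dim_head): illustrative only
    attn_window_size = 4
    sliding_window_attn = False
    seq_len_with_mem = 8

    next_kv_caches = torch.randn(2, 2, 1, 2, seq_len_with_mem, 16)

    # keep at most one attention window worth of keys / values
    next_kv_caches = next_kv_caches[..., -attn_window_size:, :]

    # with windowed (non sliding) local attention, a fresh window begins exactly at the
    # window boundary, so the cache can be emptied there
    if not sliding_window_attn and (seq_len_with_mem % attn_window_size) == 0:
        next_kv_caches = next_kv_caches[..., 0:0, :]

    print(next_kv_caches.shape)  # torch.Size([2, 2, 1, 2, 0, 16])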
@@ -795,15 +843,6 @@ class MemoryAsContextTransformer(Module):
         if not return_cache:
             return logits

-        next_kv_caches = stack([stack(kv_cache) for kv_cache in next_kv_caches])
-
-        # handle kv cache length depending on local attention type
-
-        next_kv_caches = next_kv_caches[..., -attn_window_size:, :]
-
-        if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
-            next_kv_caches = next_kv_caches[..., 0:0, :]
-
         return logits, (next_kv_caches, next_neural_mem_caches)

         ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
titans_pytorch-0.1.29.dist-info/RECORD

@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=RRLdVa8z-2IWbhhmRGfoNBycwaL32aMbpqutzmSQqpc,24575
+titans_pytorch/titans.py,sha256=gjoDcTsvw5X2d1I2xq4cM45YJIBqtLFuws8_jVylW_4,25746
+titans_pytorch-0.1.29.dist-info/METADATA,sha256=9Na2UlBJ4mECXXY5GIyuokgN0oxs38rps24TIM6CNFY,6815
+titans_pytorch-0.1.29.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.29.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.29.dist-info/RECORD,,
titans_pytorch-0.1.27.dist-info/RECORD

@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=Staf9hRQ44QAL23bSGh4VSB8NeGtMri-JdiZdgirJiU,23587
-titans_pytorch/titans.py,sha256=gjoDcTsvw5X2d1I2xq4cM45YJIBqtLFuws8_jVylW_4,25746
-titans_pytorch-0.1.27.dist-info/METADATA,sha256=AZ5-_d9o_khm6jaky1zoKyXB1hDQNifbS061v_b4McQ,6815
-titans_pytorch-0.1.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.27.dist-info/RECORD,,
{titans_pytorch-0.1.27.dist-info → titans_pytorch-0.1.29.dist-info}/WHEEL: file without changes
{titans_pytorch-0.1.27.dist-info → titans_pytorch-0.1.29.dist-info}/licenses/LICENSE: file without changes