titans-pytorch 0.1.26__py3-none-any.whl → 0.1.28__py3-none-any.whl

This diff shows the changes between these two publicly released versions of the package, as they appear in their public registry; it is provided for informational purposes only.
titans_pytorch/mac_transformer.py CHANGED
@@ -90,6 +90,9 @@ def divisible_by(num, den):
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult
 
+def round_down_multiple(seq, mult):
+    return seq // mult * mult
+
 def pack_with_inverse(t, pattern):
     packed, packed_shape = pack(t, pattern)
 
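The new round_down_multiple helper mirrors the existing round_up_multiple. A tiny self-contained sketch of the pair (the numeric values are illustrative only, not taken from the package):

from math import ceil

def round_up_multiple(seq, mult):
    # smallest multiple of mult that is >= seq
    return ceil(seq / mult) * mult

def round_down_multiple(seq, mult):
    # largest multiple of mult that is <= seq
    return seq // mult * mult

assert round_up_multiple(300, 128) == 384
assert round_down_multiple(300, 128) == 256
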
@@ -116,11 +119,11 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
     if fold_into_batch:
         seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
 
-    def inverse(out):
+    def inverse(out, remove_pad = True):
         if fold_into_batch:
             out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)
 
-        if needs_pad:
+        if needs_pad and remove_pad:
             out = out[..., :-padding, :]
 
         return out
@@ -312,7 +315,7 @@ class SegmentedAttention(Module):
 
         # caching
 
-        next_cache = tuple(map(inverse_segment, (k, v)))
+        next_cache = (k, v)
 
         # take care of persistent memory key / values
 
@@ -575,6 +578,27 @@ class MemoryAsContextTransformer(Module):
 
         self.num_persist_mem_tokens = num_persist_mem_tokens
 
+    def seq_index_is_longterm(
+        self,
+        seq_index
+    ):
+        total_segment_len = self.attn_window_size
+
+        seq = seq_index + 1
+        seq -= int((seq % total_segment_len) == 0)
+        last_segment_len = round_down_multiple(seq, total_segment_len)
+        segment_seq = seq - last_segment_len
+        return (segment_seq - self.segment_len) > 0
+
+    def seq_len_with_longterm_mem(
+        self,
+        seq_len
+    ):
+        assert seq_len > 0
+
+        segment_len, num_mem = self.segment_len, self.num_longterm_mem_tokens
+        return ceil(seq_len / segment_len) * num_mem + seq_len
+
     @torch.no_grad()
     def sample(
         self,
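
A self-contained sketch of the arithmetic behind the new seq_len_with_longterm_mem method: every (possibly partial) segment of the input picks up one block of num_longterm_mem_tokens interspersed memory tokens, so the effective sequence length grows accordingly. The hyperparameter values below are illustrative, not the package defaults:

from math import ceil

def seq_len_with_longterm_mem(seq_len, segment_len, num_longterm_mem_tokens):
    # each (possibly partial) segment contributes one block of memory tokens
    assert seq_len > 0
    return ceil(seq_len / segment_len) * num_longterm_mem_tokens + seq_len

# e.g. segment_len = 128 with 4 memory tokens per segment:
# a 300 token prompt spans 3 segments and picks up 12 memory positions
assert seq_len_with_longterm_mem(300, 128, 4) == 312
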
@@ -594,8 +618,6 @@ class MemoryAsContextTransformer(Module):
         prompt_seq_len, out = prompt.shape[-1], prompt.clone()
         sample_num_times = max(0, seq_len - prompt_seq_len)
 
-        iter_wrap = tqdm.tqdm if show_progress else identity
-
         # cache for axial pos, attention, and neural memory
 
         cache = None
@@ -604,9 +626,7 @@ class MemoryAsContextTransformer(Module):
         # precompute factorized pos emb
 
         if use_cache:
-            round_up_seq_len = round_up_multiple(seq_len, self.segment_len)
-            longterm_mem_lens = (round_up_seq_len // self.segment_len) * self.num_longterm_mem_tokens
-            seq_len_with_mem = round_up_seq_len + longterm_mem_lens
+            seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)
 
             axial_dims = self.axial_pos_emb.maybe_derive_outer_dim(seq_len_with_mem, (self.neural_memory_segment_len,))
 
@@ -614,25 +634,31 @@ class MemoryAsContextTransformer(Module):
 
         # sample
 
-        for _ in iter_wrap(range(sample_num_times)):
+        with tqdm.tqdm(total = sample_num_times, disable = not show_progress) as pbar:
 
-            logits, next_cache = self.forward(
-                out,
-                disable_flex_attn = True,
-                cache = cache,
-                return_cache = True,
-                factorized_pos_emb = factorized_pos_emb
-            )
+            while out.shape[-1] < seq_len:
+
+                logits, next_cache = self.forward(
+                    out,
+                    disable_flex_attn = True,
+                    cache = cache,
+                    return_cache = True,
+                    factorized_pos_emb = factorized_pos_emb
+                )
 
-            if use_cache:
-                cache = next_cache
+                if use_cache:
+                    cache = next_cache
 
-            logits = logits[:, -1]
+                if not exists(logits):
+                    continue
 
-            logits = filter_fn(logits, **filter_kwargs)
-            sample = gumbel_sample(logits, temperature = temperature)
+                logits = logits[:, -1]
 
-            out = torch.cat((out, sample), dim = -1)
+                logits = filter_fn(logits, **filter_kwargs)
+                sample = gumbel_sample(logits, temperature = temperature)
+
+                out = torch.cat((out, sample), dim = -1)
+                pbar.update(1)
 
         self.train(was_training)
 
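Sampling is now driven by a while loop over the output length rather than a fixed iteration count, since during cached inference some forward calls only consume interspersed long-term memory positions and return no logits. A rough sketch of that control flow, assuming a hypothetical step callback that stands in for the model's forward pass:

import tqdm

def sample_until(seq_len, prompt, step, show_progress = False):
    # step(out) is a stand-in for the model call: it returns None when the
    # current position is a long-term memory token, otherwise the sampled id
    out = list(prompt)

    with tqdm.tqdm(total = seq_len - len(prompt), disable = not show_progress) as pbar:
        while len(out) < seq_len:
            token = step(out)

            if token is None:
                continue        # memory position, nothing was sampled

            out.append(token)
            pbar.update(1)      # progress only advances on real tokens

    return out
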
@@ -656,6 +682,8 @@ class MemoryAsContextTransformer(Module):
 
         batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size
 
+        seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)
+
         # token embedding
 
         x = self.token_emb(x)
@@ -667,9 +695,11 @@ class MemoryAsContextTransformer(Module):
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
         x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')
 
-        x = inverse_segment(x)
+        x = inverse_segment(x, remove_pad = False)
+
+        # splice out unneeded tokens from padding for longterm mems
 
-        seq_len_with_mem = x.shape[-2]
+        x = x[:, :seq_len_with_mem]
 
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
@@ -685,13 +715,12 @@ class MemoryAsContextTransformer(Module):
 
         flex_attn_fn = None
 
         if use_flex_attn:
-            block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens, self.sliding_window_attn)
+            block_mask = create_mac_block_mask(seq_len_with_mem, self.attn_window_size, self.num_persist_mem_tokens, self.sliding_window_attn)
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)
 
         # kv caching
 
         is_inferencing = exists(cache)
-        assert not (is_inferencing and self.num_longterm_mem_tokens > 0)
 
         if not exists(cache):
             cache = (None, None)
@@ -741,10 +770,10 @@ class MemoryAsContextTransformer(Module):
 
                     kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
 
+                    next_neural_mem_cache = (seq_len, None, None, None)
                 else:
                     retrieved, next_neural_mem_cache = mem.forward_inference(
                         mem_input,
-                        seq_index = seq_len - 1,
                         state = next(neural_mem_caches, None)
                     )
 
@@ -775,15 +804,34 @@ class MemoryAsContextTransformer(Module):
 
             x = ff(x)
 
+        # taking care of cache first
+        # for early return when processing long term mem tokens during inference
+
+        if return_cache:
+            next_kv_caches = stack([stack(kv_cache) for kv_cache in next_kv_caches])
+
+            # handle kv cache length depending on local attention type
+
+            next_kv_caches = next_kv_caches[..., -attn_window_size:, :]
+
+            if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
+                next_kv_caches = next_kv_caches[..., 0:0, :]
+
+        # hyper connection reducing of streams
+
         x = self.reduce_streams(x)
 
         # excise out the memories
 
-        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len + num_longterm_mem_tokens)
+        if not is_inferencing:
+
+            x, inverse_segment = pad_and_segment_with_inverse(x, attn_window_size)
 
-        x, _ = inverse_pack_mems(x)
+            x, _ = inverse_pack_mems(x)
 
-        x = inverse_segment(x)
+            x = inverse_segment(x)
+
+            x = x[:, :seq_len]
 
         # to logits
 
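The cache preparation now happens before reduce_streams so inference can return early while long-term memory tokens are being processed. A minimal sketch of the trimming logic itself, applied to a dummy tensor whose shapes are purely illustrative:

import torch

attn_window_size = 8
sliding_window_attn = False

# dummy (batch, heads, seq, dim) key/value cache
next_kv_cache = torch.randn(1, 4, 19, 16)
seq_len_with_mem = next_kv_cache.shape[-2]

# keep at most one attention window of cached positions
next_kv_cache = next_kv_cache[..., -attn_window_size:, :]

# with windowed (non-sliding) attention, a freshly completed window means the
# next token starts a new window and cannot attend to the cached one
if not sliding_window_attn and (seq_len_with_mem % attn_window_size) == 0:
    next_kv_cache = next_kv_cache[..., 0:0, :]

print(next_kv_cache.shape)  # torch.Size([1, 4, 8, 16]) here, since 19 % 8 != 0
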
@@ -795,15 +843,6 @@ class MemoryAsContextTransformer(Module):
             if not return_cache:
                 return logits
 
-            next_kv_caches = stack([stack(kv_cache) for kv_cache in next_kv_caches])
-
-            # handle kv cache length depending on local attention type
-
-            next_kv_caches = next_kv_caches[..., -attn_window_size:, :]
-
-            if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
-                next_kv_caches = next_kv_caches[..., 0:0, :]
-
             return logits, (next_kv_caches, next_neural_mem_caches)
 
         ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
titans_pytorch/titans.py CHANGED
@@ -783,18 +783,16 @@ class NeuralMemory(Module):
     def forward_inference(
         self,
         token: Tensor,
-        seq_index = None, # the index of the token in the sequence, starts at 0
         state = None,
     ):
 
         # unpack previous state
 
         if not exists(state):
-            state = (None, None, None)
+            state = (0, None, None, None)
 
-        cache_store_seq, past_states, updates = state
+        seq_index, cache_store_seq, past_states, updates = state
 
-        seq_index = default(seq_index, 0)
         curr_seq_len = seq_index + 1
         batch = token.shape[0]
 
@@ -814,7 +812,7 @@ class NeuralMemory(Module):
         if curr_seq_len < self.chunk_size:
             empty_mem = self.init_empty_memory_embed(batch, 1)
 
-            return empty_mem, (cache_store_seq, past_states, updates)
+            return empty_mem, (curr_seq_len, cache_store_seq, past_states, updates)
 
         # store if storage sequence cache hits the chunk size
 
@@ -842,7 +840,7 @@
 
         # next state tuple
 
-        next_state = (cache_store_seq, next_states, updates)
+        next_state = (curr_seq_len, cache_store_seq, next_states, updates)
 
         return retrieved, next_state
 
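With this change, NeuralMemory.forward_inference carries the sequence index inside its state tuple instead of receiving it as an argument, so a caller only threads the opaque state between calls. A toy stand-in that mimics just that threading contract (nothing here is the package's real memory logic):

class ToyMemory:
    # state layout after this change: (seq_index, cache_store_seq, past_states, updates)

    def forward_inference(self, token, state = None):
        if state is None:
            state = (0, None, None, None)

        seq_index, cache_store_seq, past_states, updates = state
        curr_seq_len = seq_index + 1

        retrieved = f'memory read at step {curr_seq_len}'
        return retrieved, (curr_seq_len, cache_store_seq, past_states, updates)

mem, state = ToyMemory(), None

for token in ['a', 'b', 'c']:
    retrieved, state = mem.forward_inference(token, state = state)

print(state[0])  # 3 -- the sequence index now advances inside the state itself
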
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.26
+Version: 0.1.28
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,7 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
-Requires-Dist: axial-positional-embedding>=0.3.9
+Requires-Dist: axial-positional-embedding>=0.3.10
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.9
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=RRLdVa8z-2IWbhhmRGfoNBycwaL32aMbpqutzmSQqpc,24575
+titans_pytorch/titans.py,sha256=gjoDcTsvw5X2d1I2xq4cM45YJIBqtLFuws8_jVylW_4,25746
+titans_pytorch-0.1.28.dist-info/METADATA,sha256=8AJX9oaut11GeFcyBmVsmbnY7oWhsal13yv75DtPeno,6815
+titans_pytorch-0.1.28.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.28.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.28.dist-info/RECORD,,
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=RkEGmVlQyK1opqylqt1VEFEc_Gd_pbAArcwfhphotXI,23564
-titans_pytorch/titans.py,sha256=a-BXTG6DdNXWhby6E4W2fdhwipuMQ12tSqSL10iLvfY,25826
-titans_pytorch-0.1.26.dist-info/METADATA,sha256=zogTDD7iLlxkPDzIeCap9GCgz2VNFUWjVF_K6K8H9kg,6814
-titans_pytorch-0.1.26.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.26.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.26.dist-info/RECORD,,