titans-pytorch 0.1.30__py3-none-any.whl → 0.1.31__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between the package versions as they appear in their respective public registries.
@@ -582,13 +582,8 @@ class MemoryAsContextTransformer(Module):
582
582
  self,
583
583
  seq_index
584
584
  ):
585
- total_segment_len = self.attn_window_size
586
-
587
- seq = seq_index + 1
588
- seq -= int((seq % total_segment_len) == 0)
589
- last_segment_len = round_down_multiple(seq, total_segment_len)
590
- segment_seq = seq - last_segment_len
591
- return (segment_seq - self.segment_len) > 0
585
+ total_segment_len, segment_len = self.attn_window_size, self.segment_len
586
+ return ((seq_index % total_segment_len + 1) - segment_len) > 0
592
587
 
593
588
  def seq_len_with_longterm_mem(
594
589
  self,
@@ -597,7 +592,7 @@ class MemoryAsContextTransformer(Module):
597
592
  assert seq_len > 0
598
593
 
599
594
  segment_len, num_mem = self.segment_len, self.num_longterm_mem_tokens
600
- return ceil(seq_len / segment_len) * num_mem + seq_len
595
+ return ((seq_len - 1) // segment_len) * num_mem + seq_len
601
596
 
602
597
  @torch.no_grad()
603
598
  def sample(
@@ -723,9 +718,9 @@ class MemoryAsContextTransformer(Module):
723
718
  is_inferencing = exists(cache)
724
719
 
725
720
  if not exists(cache):
726
- cache = (None, None)
721
+ cache = (seq_len_with_mem - 1, None, None)
727
722
 
728
- kv_caches, neural_mem_caches = cache
723
+ inference_seq_index, kv_caches, neural_mem_caches = cache
729
724
 
730
725
  kv_caches = iter(default(kv_caches, []))
731
726
  neural_mem_caches = iter(default(neural_mem_caches, []))
@@ -744,7 +739,8 @@ class MemoryAsContextTransformer(Module):
744
739
  # when inferencing, only do one token at a time
745
740
 
746
741
  if is_inferencing:
747
- x = x[:, -1:]
742
+ ind = inference_seq_index
743
+ x = x[:, ind:(ind + 1)]
748
744
 
749
745
  # expand and reduce streams for hyper connections
750
746
 
@@ -817,6 +813,17 @@ class MemoryAsContextTransformer(Module):
817
813
  if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
818
814
  next_kv_caches = next_kv_caches[..., 0:0, :]
819
815
 
816
+ next_cache = (
817
+ inference_seq_index + 1,
818
+ next_kv_caches,
819
+ next_neural_mem_caches
820
+ )
821
+
822
+ is_longterm_mem = self.seq_index_is_longterm(inference_seq_index)
823
+
824
+ if is_inferencing and is_longterm_mem:
825
+ return None, next_cache
826
+
820
827
  # hyper connection reducing of streams
821
828
 
822
829
  x = self.reduce_streams(x)
@@ -829,7 +836,7 @@ class MemoryAsContextTransformer(Module):
829
836
 
830
837
  x, _ = inverse_pack_mems(x)
831
838
 
832
- x = inverse_segment(x)
839
+ x = inverse_segment(x, remove_pad = False)
833
840
 
834
841
  x = x[:, :seq_len]
835
842
 
@@ -843,7 +850,7 @@ class MemoryAsContextTransformer(Module):
843
850
  if not return_cache:
844
851
  return logits
845
852
 
846
- return logits, (next_kv_caches, next_neural_mem_caches)
853
+ return logits, next_cache
847
854
 
848
855
  ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
849
856
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.1.30
3
+ Version: 0.1.31
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=pKFRL_ISoHEKUyfssKwfBfwFO2eQN9objJmxLrNsYrU,24838
4
+ titans_pytorch/titans.py,sha256=6B8ioP26RTja5kVFMsorAnM9CcxIUySJS9RZBlDPI2s,25825
5
+ titans_pytorch-0.1.31.dist-info/METADATA,sha256=9ejOFuH2B2-yCRFK4x_C1DONPxecW8VcjEUeRh9OzXg,6815
6
+ titans_pytorch-0.1.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.1.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.1.31.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=RRLdVa8z-2IWbhhmRGfoNBycwaL32aMbpqutzmSQqpc,24575
4
- titans_pytorch/titans.py,sha256=6B8ioP26RTja5kVFMsorAnM9CcxIUySJS9RZBlDPI2s,25825
5
- titans_pytorch-0.1.30.dist-info/METADATA,sha256=o5flkZ0hNhZE06bSKVEFpbrkhuWB9putcaL_MZ0sJHA,6815
6
- titans_pytorch-0.1.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.1.30.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.1.30.dist-info/RECORD,,