titans-pytorch 0.1.38__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -111,6 +111,7 @@ def pad_and_segment_with_inverse(
111
111
  seq,
112
112
  segment_len,
113
113
  fold_into_batch = True,
114
+ inverse_remove_pad = True
114
115
  ):
115
116
  batch, seq_len = seq.shape[:2]
116
117
  next_seq_len_mult = round_up_multiple(seq_len, segment_len)
@@ -124,15 +125,12 @@ def pad_and_segment_with_inverse(
124
125
  if fold_into_batch:
125
126
  seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
126
127
 
127
- shape = seq.shape
128
-
129
128
  def inverse(out):
130
- unchanged_shape = out.shape == shape
131
129
 
132
130
  if fold_into_batch:
133
131
  out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)
134
132
 
135
- if needs_pad and unchanged_shape:
133
+ if needs_pad and inverse_remove_pad:
136
134
  out = out[..., :-padding, :]
137
135
 
138
136
  return out
@@ -714,7 +712,7 @@ class MemoryAsContextTransformer(Module):
714
712
 
715
713
  # intersperse longterm memory
716
714
 
717
- x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
715
+ x, inverse_segment = pad_and_segment_with_inverse(x, segment_len, inverse_remove_pad = False)
718
716
 
719
717
  mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
720
718
  x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')
@@ -856,7 +854,9 @@ class MemoryAsContextTransformer(Module):
856
854
 
857
855
  next_kv_caches = next_kv_caches[..., -attn_window_size:, :]
858
856
 
859
- if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
857
+ kv_cache_length = next_kv_caches.shape[-2]
858
+
859
+ if not self.sliding_window_attn and divisible_by(kv_cache_length, attn_window_size):
860
860
  next_kv_caches = next_kv_caches[..., 0:0, :]
861
861
 
862
862
  next_cache = (
@@ -878,7 +878,7 @@ class MemoryAsContextTransformer(Module):
878
878
 
879
879
  if not is_inferencing:
880
880
 
881
- x, inverse_segment = pad_and_segment_with_inverse(x, attn_window_size)
881
+ x, inverse_segment = pad_and_segment_with_inverse(x, attn_window_size, inverse_remove_pad = False)
882
882
 
883
883
  x, _ = inverse_pack_mems(x)
884
884
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.1.38
3
+ Version: 0.2.0
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -56,7 +56,7 @@ Description-Content-Type: text/markdown
56
56
 
57
57
  <img src="./fig1.png" width="400px"></img>
58
58
 
59
- ## Titans - Pytorch (wip)
59
+ ## Titans - Pytorch
60
60
 
61
61
  Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
62
62
 
@@ -0,0 +1,9 @@
1
+ titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=p5kt6Zl3iuOEE9oGfdBC-M9IYHq8RnDLGlk2FLcZRLQ,26539
4
+ titans_pytorch/memory_models.py,sha256=LI9T36XB6YXIvvGWRw0ZMDlGpRC6KIv03OPzME2VAaU,3772
5
+ titans_pytorch/neural_memory.py,sha256=vmKPOAlXBPXBnYPODrg_reWaIcr1xwtfQmuptGS6e5A,25559
6
+ titans_pytorch-0.2.0.dist-info/METADATA,sha256=4-P4F9exUHZNdyEEx1eqTuu3-73-zBVwTGuf0zE04-g,6819
7
+ titans_pytorch-0.2.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.2.0.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.2.0.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=KBwo-Fr_fDzVaAa7xg1ggEpNlE4vRUoGMEjB-I2ZWTU,26463
4
- titans_pytorch/memory_models.py,sha256=LI9T36XB6YXIvvGWRw0ZMDlGpRC6KIv03OPzME2VAaU,3772
5
- titans_pytorch/neural_memory.py,sha256=vmKPOAlXBPXBnYPODrg_reWaIcr1xwtfQmuptGS6e5A,25559
6
- titans_pytorch-0.1.38.dist-info/METADATA,sha256=8ZmlPJotNIMGAqW8nYWJiM06MvCXJ2SKTGVKarWeOAQ,6826
7
- titans_pytorch-0.1.38.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.1.38.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.1.38.dist-info/RECORD,,