titans-pytorch 0.0.63__py3-none-any.whl → 0.0.64__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -153,6 +153,7 @@ class SegmentedAttention(Module):
153
153
  self.num_longterm_mem_tokens = num_longterm_mem_tokens
154
154
 
155
155
  total_segment_len = segment_len + num_longterm_mem_tokens
156
+ self.total_segment_len = total_segment_len
156
157
 
157
158
  self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
158
159
  self.merge_heads = Rearrange('b h n d -> b n (h d)')
@@ -209,7 +210,7 @@ class SegmentedAttention(Module):
209
210
  # prep flex attention
210
211
 
211
212
  if not exists(flex_attn_fn):
212
- block_mask = create_mac_block_mask(seq_len, self.segment_len, self.num_persist_mem_tokens)
213
+ block_mask = create_mac_block_mask(seq_len, self.total_segment_len, self.num_persist_mem_tokens)
213
214
 
214
215
  flex_attn_fn = partial(flex_attention, block_mask = block_mask)
215
216
 
@@ -241,7 +242,6 @@ class SegmentedAttention(Module):
241
242
  batch, seq_len = seq.shape[:2]
242
243
 
243
244
  # auto pad to multiple
244
- # todo - get rid of logic with flex attention
245
245
 
246
246
  seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len)
247
247
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.63
3
+ Version: 0.0.64
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=7voYtbD_ErCD_JjvwhAiunUWtSIsIxGJAaf2aRB3c2s,15349
4
+ titans_pytorch/titans.py,sha256=y6lJRErIoM6T2aTVFlf1GxSB0cpsmBZdSIj1DCHUCQ8,17486
5
+ titans_pytorch-0.0.64.dist-info/METADATA,sha256=K63jobSfTdn-aFpEpZgolu4zSIvgUzF2rDuoCHGXkgE,4457
6
+ titans_pytorch-0.0.64.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.0.64.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.0.64.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=R9Xu-BjqAe9ZY60IGk4aNXBx_L8THsjJ4QrkbTnRNHo,15346
4
- titans_pytorch/titans.py,sha256=y6lJRErIoM6T2aTVFlf1GxSB0cpsmBZdSIj1DCHUCQ8,17486
5
- titans_pytorch-0.0.63.dist-info/METADATA,sha256=-CImQ-4hVNDFWczTb0V1dWL0QkHS-1c6XyntI1ULrms,4457
6
- titans_pytorch-0.0.63.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.0.63.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.0.63.dist-info/RECORD,,