titans-pytorch 0.0.55__py3-none-any.whl → 0.0.56__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -445,8 +445,7 @@ class MemoryAsContextTransformer(Module):
445
445
  flex_attn_fn = None
446
446
 
447
447
  if use_flex_attn:
448
- block_mask = create_mac_block_mask(seq_len, self.segment_len, self.num_persist_mem_tokens)
449
-
448
+ block_mask = create_mac_block_mask(seq_len_with_mem, self.segment_len, self.num_persist_mem_tokens)
450
449
  flex_attn_fn = partial(flex_attention, block_mask = block_mask)
451
450
 
452
451
  # value residual
@@ -467,7 +466,12 @@ class MemoryAsContextTransformer(Module):
467
466
  x, aux_kv_loss = maybe_neural_mem(x, return_aux_kv_loss = True)
468
467
  kv_recon_losses = kv_recon_losses + aux_kv_loss
469
468
 
470
- x, values = attn(x, value_residual = value_residual, disable_flex_attn = disable_flex_attn, flex_attn_fn = flex_attn_fn)
469
+ x, values = attn(
470
+ x,
471
+ value_residual = value_residual,
472
+ disable_flex_attn = disable_flex_attn,
473
+ flex_attn_fn = flex_attn_fn
474
+ )
471
475
 
472
476
  value_residual = default(value_residual, values)
473
477
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.55
3
+ Version: 0.0.56
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=_Vsco5YuR6uxouWcjFj-s-zPhrBcaapIzqoyi7qqY0Q,14245
4
+ titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
5
+ titans_pytorch-0.0.56.dist-info/METADATA,sha256=QlCmHqajHiaZTps0W9gKXIHE6dShZER3PqPoYi2zRe4,4457
6
+ titans_pytorch-0.0.56.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.0.56.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.0.56.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=Q0MQA3RS8vqzs-KzSGZkvLR7afQ6ZW9uMOq1MeNuFoY,170
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=-VN8bURUaqHXH_96UqGYDhWcfgCaFdHGdM6faVuYDgQ,14159
4
- titans_pytorch/titans.py,sha256=ZKm-LnVKh1Cxs2tSxr4CcY37KroOOmYtTFM2F3Zb8Xg,16122
5
- titans_pytorch-0.0.55.dist-info/METADATA,sha256=VYP1B5d9tejIXr7u6ML4cSjvgIDlWYyp5KTyydlUqV8,4457
6
- titans_pytorch-0.0.55.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.0.55.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.0.55.dist-info/RECORD,,