titans-pytorch 0.3.12__py3-none-any.whl → 0.3.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -493,12 +493,20 @@ class MemoryAsContextTransformer(Module):
493
493
  use_flex_attn = False,
494
494
  sliding_window_attn = False,
495
495
  neural_mem_weight_residual = False,
496
+ token_emb: Module | None = None,
497
+ abs_pos_emb: Module | None = None
496
498
  ):
497
499
  super().__init__()
498
500
 
499
- self.token_emb = nn.Embedding(num_tokens, dim)
501
+ if not exists(token_emb):
502
+ token_emb = nn.Embedding(num_tokens, dim)
500
503
 
501
- self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
504
+ self.token_emb = token_emb
505
+
506
+ if not exists(abs_pos_emb):
507
+ abs_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
508
+
509
+ self.abs_pos_emb = abs_pos_emb
502
510
 
503
511
  # long term mem tokens
504
512
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.3.12
3
+ Version: 0.3.14
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,9 +1,9 @@
1
1
  titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
2
2
  titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
3
- titans_pytorch/mac_transformer.py,sha256=EyqA53HBqvAr4UNZUs37LR6IltyEfA7FKEV54YzVYlg,24945
3
+ titans_pytorch/mac_transformer.py,sha256=F04B88GaH0wHseUIWaX6VFhOSsk_3XDQ1E8e6pvqKgQ,25170
4
4
  titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
5
5
  titans_pytorch/neural_memory.py,sha256=VmUAS1xOM0ZfearWIzQrX_P7HI69viuwrg9M7BQByeE,29349
6
- titans_pytorch-0.3.12.dist-info/METADATA,sha256=02OsMYNITFLjnKJgis8eUHxwcdH2aVbA_D-QK24TYbg,6817
7
- titans_pytorch-0.3.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.3.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.3.12.dist-info/RECORD,,
6
+ titans_pytorch-0.3.14.dist-info/METADATA,sha256=1reoUZhvKaFPR6U0QXqJOziyss0HwHhwfJUf7oU7t-s,6817
7
+ titans_pytorch-0.3.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.3.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.3.14.dist-info/RECORD,,