titans-pytorch 0.3.12__py3-none-any.whl → 0.3.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -493,10 +493,16 @@ class MemoryAsContextTransformer(Module):
493
493
  use_flex_attn = False,
494
494
  sliding_window_attn = False,
495
495
  neural_mem_weight_residual = False,
496
+ token_emb: Module | None = None,
496
497
  ):
497
498
  super().__init__()
498
499
 
499
- self.token_emb = nn.Embedding(num_tokens, dim)
500
+ if not exists(token_emb):
501
+ token_emb = nn.Embedding(num_tokens, dim)
502
+
503
+ self.token_emb = token_emb
504
+
505
+ # absolute positions
500
506
 
501
507
  self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
502
508
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.3.12
3
+ Version: 0.3.15
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,9 +1,9 @@
1
1
  titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
2
2
  titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
3
- titans_pytorch/mac_transformer.py,sha256=EyqA53HBqvAr4UNZUs37LR6IltyEfA7FKEV54YzVYlg,24945
3
+ titans_pytorch/mac_transformer.py,sha256=HIB3S3JBA8Fe1EBITvDZSHXtn-1_fF1rwlw-MzqagKY,25085
4
4
  titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
5
5
  titans_pytorch/neural_memory.py,sha256=VmUAS1xOM0ZfearWIzQrX_P7HI69viuwrg9M7BQByeE,29349
6
- titans_pytorch-0.3.12.dist-info/METADATA,sha256=02OsMYNITFLjnKJgis8eUHxwcdH2aVbA_D-QK24TYbg,6817
7
- titans_pytorch-0.3.12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.3.12.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.3.12.dist-info/RECORD,,
6
+ titans_pytorch-0.3.15.dist-info/METADATA,sha256=RPw9JXenAI7cGpVP3hQZlj0OA5-xsvXvXHvxyhWdgpg,6817
7
+ titans_pytorch-0.3.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.3.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.3.15.dist-info/RECORD,,