titans-pytorch 0.1.20__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/titans.py CHANGED
@@ -365,6 +365,7 @@ class NeuralMemory(Module):
365
365
  momentum = True,
366
366
  pre_rmsnorm = True,
367
367
  post_rmsnorm = True,
368
+ qk_rmsnorm = False,
368
369
  learned_mem_model_weights = True,
369
370
  max_grad_norm: float | None = None,
370
371
  use_accelerated_scan = False,
@@ -389,6 +390,9 @@ class NeuralMemory(Module):
389
390
 
390
391
  self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()
391
392
 
393
+ self.q_norm = MultiheadRMSNorm(dim_head, heads) if qk_rmsnorm else nn.Identity()
394
+ self.k_norm = MultiheadRMSNorm(dim_head, heads) if qk_rmsnorm else nn.Identity()
395
+
392
396
  # maybe multi-headed
393
397
 
394
398
  dim_inner = dim_head * heads
@@ -577,6 +581,10 @@ class NeuralMemory(Module):
577
581
 
578
582
  batch = keys.shape[0]
579
583
 
584
+ # maybe qk rmsnorm
585
+
586
+ keys = self.k_norm(keys)
587
+
580
588
  # take care of chunking
581
589
 
582
590
  keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = chunk_size) for t in (keys, values))
@@ -683,6 +691,10 @@ class NeuralMemory(Module):
683
691
 
684
692
  queries = self.split_heads(queries)
685
693
 
694
+ # maybe qk rmsnorm
695
+
696
+ queries = self.q_norm(queries)
697
+
686
698
  # fetch values from memory model
687
699
 
688
700
  curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.1.20
3
+ Version: 0.1.21
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -196,3 +196,15 @@ $ python train_mac.py
196
196
  year = {2024}
197
197
  }
198
198
  ```
199
+
200
+ ```bibtex
201
+ @misc{wang2025testtimeregressionunifyingframework,
202
+ title = {Test-time regression: a unifying framework for designing sequence models with associative memory},
203
+ author = {Ke Alexander Wang and Jiaxin Shi and Emily B. Fox},
204
+ year = {2025},
205
+ eprint = {2501.12352},
206
+ archivePrefix = {arXiv},
207
+ primaryClass = {cs.LG},
208
+ url = {https://arxiv.org/abs/2501.12352},
209
+ }
210
+ ```
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=Ejq1r3GQQnlT1Fo4McaOOie19t1HjwVlYbD90GLQCYI,22859
4
+ titans_pytorch/titans.py,sha256=YYt6O5EiBVvyxWM4R1JuLLJH3bGm1V-74aB7VhbsWQ0,22577
5
+ titans_pytorch-0.1.21.dist-info/METADATA,sha256=ixbJisycB0MgSIcOvDRM1PIMs3l1TM_AmQ88aWZYEsM,6742
6
+ titans_pytorch-0.1.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.1.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.1.21.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=Ejq1r3GQQnlT1Fo4McaOOie19t1HjwVlYbD90GLQCYI,22859
4
- titans_pytorch/titans.py,sha256=R0e25ly2uTHkHSZEb-9Eqb0DqtFq8wFBB8iH1T6bYVg,22240
5
- titans_pytorch-0.1.20.dist-info/METADATA,sha256=Y0TmkfpKQ4LAyhr6SmAGeLHs3H4ZiZ4lg-gevvUDmjI,6340
6
- titans_pytorch-0.1.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.1.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.1.20.dist-info/RECORD,,