titans-pytorch 0.1.20__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/titans.py +12 -0
- {titans_pytorch-0.1.20.dist-info → titans_pytorch-0.1.21.dist-info}/METADATA +13 -1
- titans_pytorch-0.1.21.dist-info/RECORD +8 -0
- titans_pytorch-0.1.20.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.20.dist-info → titans_pytorch-0.1.21.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.20.dist-info → titans_pytorch-0.1.21.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
@@ -365,6 +365,7 @@ class NeuralMemory(Module):
|
|
365
365
|
momentum = True,
|
366
366
|
pre_rmsnorm = True,
|
367
367
|
post_rmsnorm = True,
|
368
|
+
qk_rmsnorm = False,
|
368
369
|
learned_mem_model_weights = True,
|
369
370
|
max_grad_norm: float | None = None,
|
370
371
|
use_accelerated_scan = False,
|
@@ -389,6 +390,9 @@ class NeuralMemory(Module):
|
|
389
390
|
|
390
391
|
self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()
|
391
392
|
|
393
|
+
self.q_norm = MultiheadRMSNorm(dim_head, heads) if qk_rmsnorm else nn.Identity()
|
394
|
+
self.k_norm = MultiheadRMSNorm(dim_head, heads) if qk_rmsnorm else nn.Identity()
|
395
|
+
|
392
396
|
# maybe multi-headed
|
393
397
|
|
394
398
|
dim_inner = dim_head * heads
|
@@ -577,6 +581,10 @@ class NeuralMemory(Module):
|
|
577
581
|
|
578
582
|
batch = keys.shape[0]
|
579
583
|
|
584
|
+
# maybe qk rmsnorm
|
585
|
+
|
586
|
+
keys = self.k_norm(keys)
|
587
|
+
|
580
588
|
# take care of chunking
|
581
589
|
|
582
590
|
keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = chunk_size) for t in (keys, values))
|
@@ -683,6 +691,10 @@ class NeuralMemory(Module):
|
|
683
691
|
|
684
692
|
queries = self.split_heads(queries)
|
685
693
|
|
694
|
+
# maybe qk rmsnorm
|
695
|
+
|
696
|
+
queries = self.q_norm(queries)
|
697
|
+
|
686
698
|
# fetch values from memory model
|
687
699
|
|
688
700
|
curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: titans-pytorch
|
3
|
-
Version: 0.1.20
|
3
|
+
Version: 0.1.21
|
4
4
|
Summary: Titans
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
|
@@ -196,3 +196,15 @@ $ python train_mac.py
|
|
196
196
|
year = {2024}
|
197
197
|
}
|
198
198
|
```
|
199
|
+
|
200
|
+
```bibtex
|
201
|
+
@misc{wang2025testtimeregressionunifyingframework,
|
202
|
+
title = {Test-time regression: a unifying framework for designing sequence models with associative memory},
|
203
|
+
author = {Ke Alexander Wang and Jiaxin Shi and Emily B. Fox},
|
204
|
+
year = {2025},
|
205
|
+
eprint = {2501.12352},
|
206
|
+
archivePrefix = {arXiv},
|
207
|
+
primaryClass = {cs.LG},
|
208
|
+
url = {https://arxiv.org/abs/2501.12352},
|
209
|
+
}
|
210
|
+
```
|
@@ -0,0 +1,8 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=Ejq1r3GQQnlT1Fo4McaOOie19t1HjwVlYbD90GLQCYI,22859
|
4
|
+
titans_pytorch/titans.py,sha256=YYt6O5EiBVvyxWM4R1JuLLJH3bGm1V-74aB7VhbsWQ0,22577
|
5
|
+
titans_pytorch-0.1.21.dist-info/METADATA,sha256=ixbJisycB0MgSIcOvDRM1PIMs3l1TM_AmQ88aWZYEsM,6742
|
6
|
+
titans_pytorch-0.1.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
+
titans_pytorch-0.1.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
+
titans_pytorch-0.1.21.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=Ejq1r3GQQnlT1Fo4McaOOie19t1HjwVlYbD90GLQCYI,22859
|
4
|
-
titans_pytorch/titans.py,sha256=R0e25ly2uTHkHSZEb-9Eqb0DqtFq8wFBB8iH1T6bYVg,22240
|
5
|
-
titans_pytorch-0.1.20.dist-info/METADATA,sha256=Y0TmkfpKQ4LAyhr6SmAGeLHs3H4ZiZ4lg-gevvUDmjI,6340
|
6
|
-
titans_pytorch-0.1.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
-
titans_pytorch-0.1.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
-
titans_pytorch-0.1.20.dist-info/RECORD,,
|
File without changes
|
File without changes
|