titans-pytorch 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/titans.py +53 -10
- {titans_pytorch-0.1.20.dist-info → titans_pytorch-0.1.22.dist-info}/METADATA +14 -3
- titans_pytorch-0.1.22.dist-info/RECORD +8 -0
- titans_pytorch-0.1.20.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.20.dist-info → titans_pytorch-0.1.22.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.20.dist-info → titans_pytorch-0.1.22.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
```diff
@@ -41,6 +41,9 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def xnor(x, y):
+    return not (x ^ y)
+
 def identity(t):
     return t
 
```
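A quick note on the new helper: `xnor(x, y)` is `True` exactly when both booleans agree, and it backs the assert added to `store_memories` below, which enforces that `value_residual` is passed if and only if the module was built with `accept_value_residual = True`. A minimal self-contained sketch of the behavior:

```python
def exists(v):
    return v is not None

def xnor(x, y):
    return not (x ^ y)

# both present, or both absent: passes
assert xnor(exists('residual'), exists('gate'))
assert xnor(exists(None), exists(None))

# a mismatch is exactly what the new assert in store_memories rejects
assert not xnor(exists('residual'), exists(None))
```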
```diff
@@ -365,6 +368,8 @@ class NeuralMemory(Module):
         momentum = True,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        qk_rmsnorm = False,
+        accept_value_residual = False,
         learned_mem_model_weights = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
```
```diff
@@ -389,13 +394,16 @@ class NeuralMemory(Module):
 
         self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()
 
+        self.q_norm = MultiheadRMSNorm(dim_head, heads) if qk_rmsnorm else nn.Identity()
+        self.k_norm = MultiheadRMSNorm(dim_head, heads) if qk_rmsnorm else nn.Identity()
+
         # maybe multi-headed
 
         dim_inner = dim_head * heads
 
         self.heads = heads
 
-        self.split_heads = Rearrange('b n (h d) ->
+        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
```
```diff
@@ -444,6 +452,14 @@ class NeuralMemory(Module):
         self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
         self.store_memory_loss_fn = store_memory_loss_fn
 
+        # value residual learning
+
+        self.learned_value_residual = Sequential(
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if accept_value_residual else None
+
         # empty memory embed
 
         self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
```
```diff
@@ -525,8 +541,11 @@ class NeuralMemory(Module):
         seq,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
         return_aux_kv_loss = False,
-        chunk_size = None
+        chunk_size = None,
+        value_residual = None
     ):
+        assert xnor(exists(value_residual), exists(self.learned_value_residual))
+
         seq_len, chunk_size = seq.shape[-2], default(chunk_size, self.store_chunk_size)
 
         # handle edge case
```
```diff
@@ -577,9 +596,21 @@ class NeuralMemory(Module):
 
         batch = keys.shape[0]
 
+        # maybe qk rmsnorm
+
+        keys = self.k_norm(keys)
+
+        # maybe value residual learning
+
+        orig_values = values
+
+        if exists(self.learned_value_residual):
+            mix = self.learned_value_residual(seq)
+            values = values.lerp(value_residual, mix)
+
         # take care of chunking
 
-        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = chunk_size) for t in (keys, values))
+        keys, values = tuple(rearrange(t, 'b h (n c) d -> (b h n) c d', c = chunk_size) for t in (keys, values))
 
         adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
 
```
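The mixing step above leans on `Tensor.lerp`: the learned gate `mix` is a per-head sigmoid output of shape `(b, h, n, 1)` (via the `Rearrange('b n h -> b h n 1')` in the constructor), so it broadcasts over `dim_head` and interpolates between the layer's own values and the incoming residual. A small shape-only sketch, with illustrative sizes that are not from the diff:

```python
import torch

# illustrative sizes: batch 2, heads 4, seq 8, dim_head 16
values         = torch.randn(2, 4, 8, 16)
value_residual = torch.randn(2, 4, 8, 16)
mix            = torch.rand(2, 4, 8, 1)  # stands in for the sigmoid gate

mixed = values.lerp(value_residual, mix)

# lerp(a, b, w) = a + w * (b - a): w = 0 keeps values, w = 1 takes the residual
assert torch.allclose(mixed, values + mix * (value_residual - values), atol = 1e-6)
```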
```diff
@@ -637,10 +668,12 @@ class NeuralMemory(Module):
 
         last_update = updates.apply(lambda t: t[:, -1])
 
+        output = (updates, orig_values)
+
         if not return_aux_kv_loss:
-            return updates
+            return output
 
-        return updates, aux_kv_recon_loss.mean()
+        return output, aux_kv_recon_loss.mean()
 
     def retrieve_memories(
         self,
```
```diff
@@ -683,10 +716,14 @@ class NeuralMemory(Module):
 
         queries = self.split_heads(queries)
 
+        # maybe qk rmsnorm
+
+        queries = self.q_norm(queries)
+
         # fetch values from memory model
 
         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
-        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
+        queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)
 
         # forward functional call
 
```
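With `qk_rmsnorm = True`, keys are RMS-normalized per head at store time (`self.k_norm` in `store_memories`) and queries at retrieval time (`self.q_norm` here); with the default `False`, both stay `nn.Identity()`. A minimal sketch of turning the flag on, following the README's usage pattern:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    qk_rmsnorm = True  # new in this release, defaults to False
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)
```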
```diff
@@ -723,7 +760,8 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
         return_aux_kv_loss = False,
         chunk_size = None,
-        store_chunk_size = None
+        store_chunk_size = None,
+        return_values = False
     ):
         batch, seq_len = seq.shape[:2]
 
```
```diff
@@ -744,13 +782,18 @@ class NeuralMemory(Module):
         store_seq = default(store_seq, seq)
         store_chunk_size = default(store_chunk_size, chunk_size)
 
-        updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)
+        (updates, values), aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)
 
         past_weights, _ = past_state
 
         retrieved = self.retrieve_memories(seq, past_weights + updates, chunk_size = chunk_size)
 
+        output = retrieved
+
+        if return_values:
+            output = (retrieved, values)
+
         if not return_aux_kv_loss:
-            return retrieved
+            return output
 
-        return retrieved, aux_kv_recon_loss
+        return output, aux_kv_recon_loss
```
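Taken together, the last hunks expose the stored (pre-mix) values to the caller: `forward(..., return_values = True)` hands them back alongside the retrieved output, so a downstream memory built with `accept_value_residual = True` can consume them through the new `value_residual` argument of `store_memories` (in this version `forward` itself does not thread the residual through). A hedged sketch of the new return path:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)

seq = torch.randn(2, 1024, 384)

# new in this release: also hand back the values that were stored,
# intended as the value residual for a downstream memory
retrieved, values = mem(seq, return_values = True)

print(retrieved.shape)  # expected: (2, 1024, 384)
print(values.shape)     # per-head values, e.g. (batch, heads, seq, dim_head)
```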
{titans_pytorch-0.1.20.dist-info → titans_pytorch-0.1.22.dist-info}/METADATA
CHANGED

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.20
+Version: 0.1.22
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
```
```diff
@@ -78,8 +78,7 @@ from titans_pytorch import NeuralMemory
 
 mem = NeuralMemory(
     dim = 384,
-    chunk_size = 64
-    pre_rmsnorm = True
+    chunk_size = 64
 ).cuda()
 
 seq = torch.randn(2, 1024, 384).cuda()
```
````diff
@@ -196,3 +195,15 @@ $ python train_mac.py
     year = {2024}
 }
 ```
+
+```bibtex
+@misc{wang2025testtimeregressionunifyingframework,
+    title = {Test-time regression: a unifying framework for designing sequence models with associative memory},
+    author = {Ke Alexander Wang and Jiaxin Shi and Emily B. Fox},
+    year = {2025},
+    eprint = {2501.12352},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2501.12352},
+}
+```
````
titans_pytorch-0.1.22.dist-info/RECORD
ADDED

```diff
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=Ejq1r3GQQnlT1Fo4McaOOie19t1HjwVlYbD90GLQCYI,22859
+titans_pytorch/titans.py,sha256=7PGnZxdKq6T6e51RL7-QV43wp-46YmrytTZLt0McMco,23407
+titans_pytorch-0.1.22.dist-info/METADATA,sha256=HCOAqLK605c8R2mvgQ80kwE9jayZ2CwJqHLsJtFx7Vs,6718
+titans_pytorch-0.1.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.22.dist-info/RECORD,,
```
titans_pytorch-0.1.20.dist-info/RECORD
DELETED

```diff
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=Ejq1r3GQQnlT1Fo4McaOOie19t1HjwVlYbD90GLQCYI,22859
-titans_pytorch/titans.py,sha256=R0e25ly2uTHkHSZEb-9Eqb0DqtFq8wFBB8iH1T6bYVg,22240
-titans_pytorch-0.1.20.dist-info/METADATA,sha256=Y0TmkfpKQ4LAyhr6SmAGeLHs3H4ZiZ4lg-gevvUDmjI,6340
-titans_pytorch-0.1.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.20.dist-info/RECORD,,
```
File without changes
|
File without changes
|