titans-pytorch 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/titans.py CHANGED
@@ -41,6 +41,9 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def xnor(x, y):
+    return not (x ^ y)
+
 def identity(t):
     return t
 
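The new `xnor` helper returns `True` exactly when both arguments share the same truth value. It backs the assertion added to `store_memories` below: a `value_residual` must be passed if and only if the module was constructed with `accept_value_residual = True`. A minimal sketch of the behavior:

```python
def xnor(x, y):
    # true iff x and y have the same boolean value
    return not (x ^ y)

assert xnor(True, True)        # residual expected and provided -> ok
assert xnor(False, False)      # residual neither expected nor provided -> ok
assert not xnor(True, False)   # mismatch: the assert in store_memories fires
```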
@@ -365,6 +368,8 @@ class NeuralMemory(Module):
         momentum = True,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        qk_rmsnorm = False,
+        accept_value_residual = False,
         learned_mem_model_weights = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
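Assuming the constructor otherwise matches the README example further below, a module exercising both new options might be configured like this (a sketch; the dimension values are illustrative, not prescribed):

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    qk_rmsnorm = True,            # RMSNorm queries and keys before retrieval / storage
    accept_value_residual = True  # mix in values passed from an earlier memory layer
)
```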
@@ -389,13 +394,16 @@ class NeuralMemory(Module):
 
         self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()
 
+        self.q_norm = MultiheadRMSNorm(dim_head, heads) if qk_rmsnorm else nn.Identity()
+        self.k_norm = MultiheadRMSNorm(dim_head, heads) if qk_rmsnorm else nn.Identity()
+
         # maybe multi-headed
 
         dim_inner = dim_head * heads
 
         self.heads = heads
 
-        self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
+        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
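The `split_heads` change keeps heads as their own axis, `(b, h, n, d)`, instead of folding them into the batch, matching the existing `merge_heads` pattern and letting the per-head value-residual gate below broadcast cleanly. A quick shape check of the two patterns:

```python
import torch
from einops.layers.torch import Rearrange

heads, dim_head = 4, 64
x = torch.randn(2, 1024, heads * dim_head)

old_split = Rearrange('b n (h d) -> (b h) n d', h = heads)
new_split = Rearrange('b n (h d) -> b h n d', h = heads)

print(old_split(x).shape)  # torch.Size([8, 1024, 64])   - heads folded into batch
print(new_split(x).shape)  # torch.Size([2, 4, 1024, 64]) - heads kept as own axis
```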
@@ -444,6 +452,14 @@ class NeuralMemory(Module):
         self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
         self.store_memory_loss_fn = store_memory_loss_fn
 
+        # value residual learning
+
+        self.learned_value_residual = Sequential(
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if accept_value_residual else None
+
         # empty memory embed
 
         self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
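The gating stack maps each token to one sigmoid value per head, reshaped so it broadcasts over `dim_head`. A standalone sketch, substituting `nn.Linear(..., bias = False)` for the library's `LinearNoBias` alias:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

batch, seq_len, dim, heads = 2, 8, 384, 4

gate = nn.Sequential(
    nn.Linear(dim, heads, bias = False),  # stand-in for LinearNoBias
    Rearrange('b n h -> b h n 1'),
    nn.Sigmoid()
)

mix = gate(torch.randn(batch, seq_len, dim))
print(mix.shape)  # torch.Size([2, 4, 8, 1]) - one gate in [0, 1] per head per token
```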
@@ -525,8 +541,11 @@ class NeuralMemory(Module):
         seq,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
         return_aux_kv_loss = False,
-        chunk_size = None
+        chunk_size = None,
+        value_residual = None
     ):
+        assert xnor(exists(value_residual), exists(self.learned_value_residual))
+
         seq_len, chunk_size = seq.shape[-2], default(chunk_size, self.store_chunk_size)
 
         # handle edge case
@@ -577,9 +596,21 @@ class NeuralMemory(Module):
 
         batch = keys.shape[0]
 
+        # maybe qk rmsnorm
+
+        keys = self.k_norm(keys)
+
+        # maybe value residual learning
+
+        orig_values = values
+
+        if exists(self.learned_value_residual):
+            mix = self.learned_value_residual(seq)
+            values = values.lerp(value_residual, mix)
+
         # take care of chunking
 
-        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = chunk_size) for t in (keys, values))
+        keys, values = tuple(rearrange(t, 'b h (n c) d -> (b h n) c d', c = chunk_size) for t in (keys, values))
 
         adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
 
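`Tensor.lerp(end, weight)` computes `self + weight * (end - self)`, so `mix = 0` keeps the layer's own values and `mix = 1` takes the residual wholesale; the `(b, h, n, 1)` gate broadcasts over `dim_head`. A worked check of the equivalence:

```python
import torch

values = torch.zeros(2, 4, 8, 64)
value_residual = torch.ones(2, 4, 8, 64)
mix = torch.full((2, 4, 8, 1), 0.25)  # gate broadcasts over the last axis

mixed = values.lerp(value_residual, mix)
# identical to the convex combination (1 - mix) * values + mix * value_residual
assert torch.allclose(mixed, (1 - mix) * values + mix * value_residual)
```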
@@ -637,10 +668,12 @@ class NeuralMemory(Module):
 
         last_update = updates.apply(lambda t: t[:, -1])
 
+        output = (updates, orig_values)
+
         if not return_aux_kv_loss:
-            return updates
+            return output
 
-        return updates, aux_kv_recon_loss.mean()
+        return output, aux_kv_recon_loss.mean()
 
     def retrieve_memories(
         self,
@@ -683,10 +716,14 @@ class NeuralMemory(Module):
 
         queries = self.split_heads(queries)
 
+        # maybe qk rmsnorm
+
+        queries = self.q_norm(queries)
+
         # fetch values from memory model
 
         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
-        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
+        queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)
 
         # forward functional call
 
@@ -723,7 +760,8 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
         return_aux_kv_loss = False,
         chunk_size = None,
-        store_chunk_size = None
+        store_chunk_size = None,
+        return_values = False
     ):
         batch, seq_len = seq.shape[:2]
 
@@ -744,13 +782,18 @@ class NeuralMemory(Module):
         store_seq = default(store_seq, seq)
         store_chunk_size = default(store_chunk_size, chunk_size)
 
-        updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)
+        (updates, values), aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)
 
         past_weights, _ = past_state
 
         retrieved = self.retrieve_memories(seq, past_weights + updates, chunk_size = chunk_size)
 
+        output = retrieved
+
+        if return_values:
+            output = (retrieved, values)
+
         if not return_aux_kv_loss:
-            return retrieved
+            return output
 
-        return retrieved, aux_kv_recon_loss
+        return output, aux_kv_recon_loss
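Taken together, `return_values` and `value_residual` let a caller thread value residuals between stacked memory layers. A hypothetical sketch; note that in this version only `store_memories` accepts `value_residual`, so the owning module (e.g. a transformer wrapping both layers) would do the threading itself:

```python
import torch
from titans_pytorch import NeuralMemory

mem1 = NeuralMemory(dim = 384, chunk_size = 64)
mem2 = NeuralMemory(dim = 384, chunk_size = 64, accept_value_residual = True)

seq = torch.randn(2, 1024, 384)

# the first layer also hands back the (pre-mix) values it stored
retrieved, values = mem1(seq, return_values = True)

# a later layer would receive them via store_memories(..., value_residual = values);
# forward() itself does not yet take a value_residual argument
```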
{titans_pytorch-0.1.20.dist-info → titans_pytorch-0.1.22.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.20
+Version: 0.1.22
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -78,8 +78,7 @@ from titans_pytorch import NeuralMemory
 
 mem = NeuralMemory(
     dim = 384,
-    chunk_size = 64,
-    pre_rmsnorm = True
+    chunk_size = 64
 ).cuda()
 
 seq = torch.randn(2, 1024, 384).cuda()
@@ -196,3 +195,15 @@ $ python train_mac.py
     year = {2024}
 }
 ```
+
+```bibtex
+@misc{wang2025testtimeregressionunifyingframework,
+    title = {Test-time regression: a unifying framework for designing sequence models with associative memory},
+    author = {Ke Alexander Wang and Jiaxin Shi and Emily B. Fox},
+    year = {2025},
+    eprint = {2501.12352},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG},
+    url = {https://arxiv.org/abs/2501.12352},
+}
+```
titans_pytorch-0.1.22.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=Ejq1r3GQQnlT1Fo4McaOOie19t1HjwVlYbD90GLQCYI,22859
+titans_pytorch/titans.py,sha256=7PGnZxdKq6T6e51RL7-QV43wp-46YmrytTZLt0McMco,23407
+titans_pytorch-0.1.22.dist-info/METADATA,sha256=HCOAqLK605c8R2mvgQ80kwE9jayZ2CwJqHLsJtFx7Vs,6718
+titans_pytorch-0.1.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.22.dist-info/RECORD,,
titans_pytorch-0.1.20.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=Ejq1r3GQQnlT1Fo4McaOOie19t1HjwVlYbD90GLQCYI,22859
-titans_pytorch/titans.py,sha256=R0e25ly2uTHkHSZEb-9Eqb0DqtFq8wFBB8iH1T6bYVg,22240
-titans_pytorch-0.1.20.dist-info/METADATA,sha256=Y0TmkfpKQ4LAyhr6SmAGeLHs3H4ZiZ4lg-gevvUDmjI,6340
-titans_pytorch-0.1.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.20.dist-info/RECORD,,