titans-pytorch 0.1.21__py3-none-any.whl → 0.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
titans_pytorch/titans.py CHANGED
@@ -41,6 +41,9 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d

+def xnor(x, y):
+    return not (x ^ y)
+
 def identity(t):
     return t
 
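Note: the new xnor helper returns True exactly when its two arguments share the same truthiness. store_memories (below) uses it to assert that a value residual is supplied if, and only if, the module was built to accept one. A standalone sketch of the behavior:

    def xnor(x, y):
        return not (x ^ y)  # True when x and y agree

    assert xnor(True, True)
    assert xnor(False, False)
    assert not xnor(True, False)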
@@ -366,6 +369,7 @@ class NeuralMemory(Module):
         pre_rmsnorm = True,
         post_rmsnorm = True,
         qk_rmsnorm = False,
+        accept_value_residual = False,
         learned_mem_model_weights = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
@@ -399,7 +403,7 @@ class NeuralMemory(Module):

         self.heads = heads

-        self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
+        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
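Note: the split pattern changes from '(b h) n d' (heads folded into the batch axis) to 'b h n d' (heads kept as a separate axis), which lets per-head tensors such as the 'b h n 1' value-residual gate added below broadcast cleanly. A self-contained einops sketch with illustrative sizes:

    import torch
    from einops.layers.torch import Rearrange

    heads, dim_head = 4, 16
    split_heads = Rearrange('b n (h d) -> b h n d', h = heads)

    x = torch.randn(2, 8, heads * dim_head)
    assert split_heads(x).shape == (2, heads, 8, dim_head)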
@@ -448,6 +452,14 @@ class NeuralMemory(Module):
         self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
         self.store_memory_loss_fn = store_memory_loss_fn

+        # value residual learning
+
+        self.learned_value_residual = Sequential(
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if accept_value_residual else None
+
         # empty memory embed

         self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
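Note: the gate projects each token to one logit per head and squashes it through a sigmoid, producing a mixing coefficient of shape (batch, heads, seq, 1) that broadcasts over the per-head feature dimension. Sequential and LinearNoBias are helpers local to this package; the sketch below substitutes the plain torch equivalents (nn.Sequential and nn.Linear with bias = False) and illustrative sizes:

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    dim, heads = 384, 4

    gate = nn.Sequential(
        nn.Linear(dim, heads, bias = False),
        Rearrange('b n h -> b h n 1'),
        nn.Sigmoid()
    )

    seq = torch.randn(2, 8, dim)
    mix = gate(seq)
    assert mix.shape == (2, heads, 8, 1)  # one coefficient in (0, 1) per head and position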
@@ -529,8 +541,11 @@ class NeuralMemory(Module):
         seq,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
         return_aux_kv_loss = False,
-        chunk_size = None
+        chunk_size = None,
+        value_residual = None
     ):
+        assert xnor(exists(value_residual), exists(self.learned_value_residual))
+
         seq_len, chunk_size = seq.shape[-2], default(chunk_size, self.store_chunk_size)

         # handle edge case
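Note: the assert enforces the contract from the constructor: value_residual must be passed exactly when the module was built with accept_value_residual = True (so that learned_value_residual exists); a one-sided combination fails fast instead of silently ignoring or requiring the residual.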
@@ -585,9 +600,17 @@ class NeuralMemory(Module):

         keys = self.k_norm(keys)

+        # maybe value residual learning
+
+        orig_values = values
+
+        if exists(self.learned_value_residual):
+            mix = self.learned_value_residual(seq)
+            values = values.lerp(value_residual, mix)
+
         # take care of chunking

-        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = chunk_size) for t in (keys, values))
+        keys, values = tuple(rearrange(t, 'b h (n c) d -> (b h n) c d', c = chunk_size) for t in (keys, values))

         adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
 
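Note: Tensor.lerp(end, weight) computes self + weight * (end - self), so the mixed values are the convex combination (1 - mix) * values + mix * value_residual, gated per head and position by the sigmoid output. orig_values is captured beforehand so the caller receives the un-mixed values to hand on to subsequent layers. A small numeric check:

    import torch

    values   = torch.zeros(2, 4, 8, 16)
    residual = torch.ones(2, 4, 8, 16)
    mix      = torch.full((2, 4, 8, 1), 0.25)  # broadcasts over the feature dim

    mixed = values.lerp(residual, mix)
    assert torch.allclose(mixed, (1 - mix) * values + mix * residual)  # every element is 0.25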
@@ -645,10 +668,12 @@ class NeuralMemory(Module):

         last_update = updates.apply(lambda t: t[:, -1])

+        output = (updates, orig_values)
+
         if not return_aux_kv_loss:
-            return updates
+            return output

-        return updates, aux_kv_recon_loss.mean()
+        return output, aux_kv_recon_loss.mean()

     def retrieve_memories(
         self,
@@ -698,7 +723,7 @@ class NeuralMemory(Module):
         # fetch values from memory model

         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
-        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
+        queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)

         # forward functional call
 
@@ -735,7 +760,8 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
         return_aux_kv_loss = False,
         chunk_size = None,
-        store_chunk_size = None
+        store_chunk_size = None,
+        return_values = False
     ):
         batch, seq_len = seq.shape[:2]
 
@@ -756,13 +782,18 @@ class NeuralMemory(Module):
         store_seq = default(store_seq, seq)
         store_chunk_size = default(store_chunk_size, chunk_size)

-        updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)
+        (updates, values), aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)

         past_weights, _ = past_state

         retrieved = self.retrieve_memories(seq, past_weights + updates, chunk_size = chunk_size)

+        output = retrieved
+
+        if return_values:
+            output = (retrieved, values)
+
         if not return_aux_kv_loss:
-            return retrieved
+            return output

-        return retrieved, aux_kv_recon_loss
+        return output, aux_kv_recon_loss
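Note: with return_values = True, forward also hands back the pre-mix values captured in store_memories, intended as the value_residual for a later NeuralMemory built with accept_value_residual = True; this release's forward does not itself pass value_residual into store_memories, so that wiring is left to the caller. A minimal sketch of the new return shape, mirroring the README usage below and assuming default settings:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 384, chunk_size = 64)

    seq = torch.randn(2, 1024, 384)

    retrieved, values = mem(seq, return_values = True)
    assert retrieved.shape == seq.shape  # values holds the per-head, pre-mix values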
titans_pytorch-0.1.21.dist-info/METADATA → titans_pytorch-0.1.22.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.21
+Version: 0.1.22
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -78,8 +78,7 @@ from titans_pytorch import NeuralMemory

 mem = NeuralMemory(
     dim = 384,
-    chunk_size = 64,
-    pre_rmsnorm = True
+    chunk_size = 64
 ).cuda()

 seq = torch.randn(2, 1024, 384).cuda()
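Note: the README example drops pre_rmsnorm = True because, as the constructor hunk above shows, pre_rmsnorm already defaults to True; the shortened call constructs an identical module.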
titans_pytorch-0.1.22.dist-info/RECORD ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=Ejq1r3GQQnlT1Fo4McaOOie19t1HjwVlYbD90GLQCYI,22859
+titans_pytorch/titans.py,sha256=7PGnZxdKq6T6e51RL7-QV43wp-46YmrytTZLt0McMco,23407
+titans_pytorch-0.1.22.dist-info/METADATA,sha256=HCOAqLK605c8R2mvgQ80kwE9jayZ2CwJqHLsJtFx7Vs,6718
+titans_pytorch-0.1.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.22.dist-info/RECORD,,
titans_pytorch-0.1.21.dist-info/RECORD DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=Ejq1r3GQQnlT1Fo4McaOOie19t1HjwVlYbD90GLQCYI,22859
-titans_pytorch/titans.py,sha256=YYt6O5EiBVvyxWM4R1JuLLJH3bGm1V-74aB7VhbsWQ0,22577
-titans_pytorch-0.1.21.dist-info/METADATA,sha256=ixbJisycB0MgSIcOvDRM1PIMs3l1TM_AmQ88aWZYEsM,6742
-titans_pytorch-0.1.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.21.dist-info/RECORD,,