titans-pytorch 0.1.33__tar.gz → 0.1.34__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.33
+Version: 0.1.34
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.33"
+version = "0.1.34"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -86,25 +86,32 @@ def test_retrieve_store_diff_seq():
     assert retrieve_seq.shape == retrieved.shape

 def test_weight_tied_mlp_neural_mem():
+    from titans_pytorch import MemoryMLP
+
+    mlp = MemoryMLP(64, depth = 2)
+
     mem = NeuralMemory(
         dim = 384,
         dim_head = 64,
         heads = 2,
-        chunk_size = 2
+        chunk_size = 2,
+        model = mlp
     )

     mem2 = NeuralMemory(
         dim = 384,
         dim_head = 64,
         heads = 2,
-        chunk_size = 2
+        chunk_size = 2,
+        model = mlp
     )

     mem3 = NeuralMemory(
         dim = 384,
         dim_head = 64,
         heads = 2,
-        chunk_size = 2
+        chunk_size = 2,
+        model = mlp
     )

     seq = torch.randn(2, 128, 384)
@@ -113,6 +120,31 @@ def test_weight_tied_mlp_neural_mem():
     seq, cache2 = mem2(seq, prev_layer_updates = cache.updates)
     seq, cache3 = mem3(seq, prev_layer_updates = cache2.updates)

+def test_mac_with_weight_tied_neural_mem():
+    from titans_pytorch import MemoryMLP, MemoryAsContextTransformer
+
+    transformer = MemoryAsContextTransformer(
+        num_tokens = 256,
+        dim = 256,
+        depth = 2,
+        segment_len = 2,
+        num_persist_mem_tokens = 0,
+        num_longterm_mem_tokens = 2,
+        neural_memory_segment_len = 2,
+        sliding_window_attn = True,
+        neural_memory_layers = (1, 2),
+        neural_memory_model = MemoryMLP(256, depth = 1),
+        num_residual_streams = 4,
+        weight_tie_memory_model = True,
+        neural_mem_gate_attn_output = True,
+    )
+
+
+    ids = torch.randint(0, 256, (1, 1023))
+    logits = transformer(ids)
+
+    assert logits.shape == (1, 1023, 256)
+
 def test_overriding_chunk_size():
     mem = NeuralMemory(
         dim = 384,
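Annotation (not part of the diff): the test changes above introduce weight tying of the memory model across NeuralMemory instances. The sketch below distills the pattern, one MemoryMLP shared by every layer, with each layer's cached updates fed to the next via prev_layer_updates. It assumes the call convention shown in the test (each call returns the processed sequence and a cache exposing .updates); the tensor sizes are illustrative.

import torch
from titans_pytorch import NeuralMemory, MemoryMLP

# a single memory MLP instance, shared (weight tied) by all memory layers
mlp = MemoryMLP(64, depth = 2)

mem1 = NeuralMemory(dim = 384, dim_head = 64, heads = 2, chunk_size = 2, model = mlp)
mem2 = NeuralMemory(dim = 384, dim_head = 64, heads = 2, chunk_size = 2, model = mlp)

seq = torch.randn(2, 128, 384)

# chain the layers, feeding each one the previous layer's memory updates
seq, cache1 = mem1(seq)
seq, cache2 = mem2(seq, prev_layer_updates = cache1.updates)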
@@ -2,6 +2,7 @@ from __future__ import annotations
 from typing import Callable

 from math import ceil
+from copy import deepcopy
 from functools import partial
 from collections import namedtuple

@@ -485,11 +486,13 @@ class MemoryAsContextTransformer(Module):
         heads = 8,
         ff_mult = 4,
         num_residual_streams = 4,
+        neural_memory_model: Module | None = None,
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
         aux_kv_recon_loss_weight = 0.,
         use_flex_attn = False,
-        sliding_window_attn = False
+        sliding_window_attn = False,
+        weight_tie_memory_model = False
     ):
         super().__init__()

@@ -523,6 +526,15 @@

         neural_memory_layers = default(neural_memory_layers, layers)

+        # weight tying neural memory model
+
+        maybe_copy = deepcopy if not weight_tie_memory_model else identity
+
+        if weight_tie_memory_model:
+            assert exists(neural_memory_model), '`neural_memory_model` must be explicitly set'
+
+        self.weight_tie_memory_model = weight_tie_memory_model
+
         # mem, attn, and feedforward layers

         for layer in layers:
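Annotation (not part of the diff): the maybe_copy helper above decides whether each transformer layer receives its own deepcopy of the provided memory model or the very same instance (it is applied per layer in the next hunk). The snippet below is a plain-PyTorch illustration of that distinction; the nn.Linear stand-in and the lambda used in place of the library's identity helper are assumptions for the example only.

# illustration only -- plain PyTorch, not titans-pytorch internals
from copy import deepcopy
import torch.nn as nn

memory_model = nn.Linear(64, 64)        # stand-in for a MemoryMLP
weight_tie_memory_model = True          # the new constructor flag

maybe_copy = deepcopy if not weight_tie_memory_model else (lambda m: m)

layer_models = [maybe_copy(memory_model) for _ in range(3)]

# weight tied: every layer holds (and trains) the exact same parameters
assert all(m is layer_models[0] for m in layer_models) == weight_tie_memory_model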
@@ -551,6 +563,7 @@
                 mem = NeuralMemory(
                     dim = dim,
                     chunk_size = self.neural_memory_segment_len,
+                    model = maybe_copy(neural_memory_model),
                     **neural_memory_kwargs
                 )

@@ -683,7 +696,7 @@

         # math

-        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size
+        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size, weight_tie_memory_model = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size, self.weight_tie_memory_model

         seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)

@@ -736,6 +749,10 @@
         next_kv_caches = []
         next_neural_mem_caches = []

+        # weight tied neural memory
+
+        neural_memory_updates = None
+
         # value residual

         value_residual = None
@@ -769,7 +786,8 @@
                 if not is_inferencing:
                     (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
                         mem_input,
-                        return_aux_kv_loss = True
+                        return_aux_kv_loss = True,
+                        prev_layer_updates = neural_memory_updates
                     )

                     kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
@@ -777,9 +795,13 @@
                 else:
                     retrieved, next_neural_mem_cache = mem.forward_inference(
                         mem_input,
-                        state = next(neural_mem_caches, None)
+                        state = next(neural_mem_caches, None),
+                        prev_layer_updates = neural_memory_updates
                     )

+                if weight_tie_memory_model:
+                    neural_memory_updates = next_neural_mem_cache.updates
+
                 if self.gate_attn_output:
                     attn_out_gates = retrieved.sigmoid()
                 else:
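Annotation (not part of the diff): the two hunks above thread memory updates through the layer loop of the forward pass. A simplified sketch of that flow, with variable names taken from the diff but the surrounding loop, cache handling, and auxiliary loss omitted:

# schematic only -- the real loop also handles caches, aux losses and the inference path
neural_memory_updates = None

for mem in memory_layers:  # hypothetical iterable of the per-layer NeuralMemory modules
    retrieved, next_neural_mem_cache = mem(
        mem_input,
        prev_layer_updates = neural_memory_updates
    )

    if weight_tie_memory_model:
        neural_memory_updates = next_neural_mem_cache.updates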
@@ -822,6 +822,7 @@ class NeuralMemory(Module):
         self,
         token: Tensor,
         state = None,
+        prev_layer_updates: dict[str, Tensor] | None = None
     ):

         # unpack previous state
@@ -863,12 +864,17 @@
         else:
             updates = updates.apply(lambda t: t[:, -1:])

+        if exists(prev_layer_updates):
+            prev_layer_updates = TensorDict(prev_layer_updates)
+            prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])
+
         if store_seq_cache_len == self.chunk_size:

             next_updates, next_states, _ = self.store_memories(
                 cache_store_seq,
                 weights,
-                past_state = past_states
+                past_state = past_states,
+                prev_layer_updates = prev_layer_updates,
             )

             updates = next_updates
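Annotation (not part of the diff): taken together, the changes in 0.1.34 add weight tying of the neural memory model across layers of MemoryAsContextTransformer. A minimal end-to-end sketch mirroring the new test_mac_with_weight_tied_neural_mem test, where a single MemoryMLP is passed as neural_memory_model and shared across layers 1 and 2 via weight_tie_memory_model = True:

import torch
from titans_pytorch import MemoryMLP, MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 2,
    num_persist_mem_tokens = 0,
    num_longterm_mem_tokens = 2,
    neural_memory_segment_len = 2,
    sliding_window_attn = True,
    neural_memory_layers = (1, 2),
    neural_memory_model = MemoryMLP(256, depth = 1),  # must be given explicitly when weight tying
    num_residual_streams = 4,
    weight_tie_memory_model = True,
    neural_mem_gate_attn_output = True,
)

ids = torch.randint(0, 256, (1, 1023))
logits = transformer(ids)  # shape (1, 1023, 256)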