titans-pytorch 0.1.33__tar.gz → 0.1.35__tar.gz

This diff covers the content of two publicly released versions of the package, as published to the public registry, and is provided for informational purposes only. It reflects the changes between the package versions as they appear in that registry.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.33
+Version: 0.1.35
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.33"
+version = "0.1.35"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -86,25 +86,32 @@ def test_retrieve_store_diff_seq():
     assert retrieve_seq.shape == retrieved.shape

 def test_weight_tied_mlp_neural_mem():
+    from titans_pytorch import MemoryMLP
+
+    mlp = MemoryMLP(64, depth = 2)
+
     mem = NeuralMemory(
         dim = 384,
         dim_head = 64,
         heads = 2,
-        chunk_size = 2
+        chunk_size = 2,
+        model = mlp
     )

     mem2 = NeuralMemory(
         dim = 384,
         dim_head = 64,
         heads = 2,
-        chunk_size = 2
+        chunk_size = 2,
+        model = mlp
     )

     mem3 = NeuralMemory(
         dim = 384,
         dim_head = 64,
         heads = 2,
-        chunk_size = 2
+        chunk_size = 2,
+        model = mlp
     )

     seq = torch.randn(2, 128, 384)
@@ -113,6 +120,31 @@ def test_weight_tied_mlp_neural_mem():
     seq, cache2 = mem2(seq, prev_layer_updates = cache.updates)
     seq, cache3 = mem3(seq, prev_layer_updates = cache2.updates)

+def test_mac_with_weight_tied_neural_mem():
+    from titans_pytorch import MemoryMLP, MemoryAsContextTransformer
+
+    transformer = MemoryAsContextTransformer(
+        num_tokens = 256,
+        dim = 256,
+        depth = 2,
+        segment_len = 2,
+        num_persist_mem_tokens = 0,
+        num_longterm_mem_tokens = 2,
+        neural_memory_segment_len = 2,
+        sliding_window_attn = True,
+        neural_memory_layers = (1, 2),
+        neural_memory_model = MemoryMLP(256, depth = 1),
+        num_residual_streams = 4,
+        weight_tie_memory_model = True,
+        neural_mem_gate_attn_output = True,
+    )
+
+
+    ids = torch.randint(0, 256, (1, 1023))
+    logits = transformer(ids)
+
+    assert logits.shape == (1, 1023, 256)
+
 def test_overriding_chunk_size():
     mem = NeuralMemory(
         dim = 384,
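Taken together, the two tests above sketch the intended usage of the new `model` argument: a single `MemoryMLP` instance is shared across several `NeuralMemory` layers, and each subsequent layer consumes the previous layer's weight updates through `prev_layer_updates`. Below is a minimal standalone sketch, using only the constructors and call signatures exercised by these tests; it is illustrative, not part of the package.

import torch
from titans_pytorch import NeuralMemory, MemoryMLP

# one memory MLP whose parameters are shared (weight tied) by both NeuralMemory layers
mlp = MemoryMLP(64, depth = 2)

mem1 = NeuralMemory(dim = 384, dim_head = 64, heads = 2, chunk_size = 2, model = mlp)
mem2 = NeuralMemory(dim = 384, dim_head = 64, heads = 2, chunk_size = 2, model = mlp)

seq = torch.randn(2, 128, 384)

# the first layer stores into the shared memory model,
# the second continues from the updates returned in the first layer's cache
seq, cache1 = mem1(seq)
seq, cache2 = mem2(seq, prev_layer_updates = cache1.updates)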
@@ -2,6 +2,7 @@ from __future__ import annotations
 from typing import Callable

 from math import ceil
+from copy import deepcopy
 from functools import partial
 from collections import namedtuple

@@ -478,18 +479,21 @@ class MemoryAsContextTransformer(Module):
         depth,
         segment_len,
         neural_memory_segment_len = None,
-        neural_mem_gate_attn_output = True,
+        neural_mem_gate_attn_output = False,
+        neural_memory_add_value_residual = False,
         num_longterm_mem_tokens = 0,
         num_persist_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
         num_residual_streams = 4,
+        neural_memory_model: Module | None = None,
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
         aux_kv_recon_loss_weight = 0.,
         use_flex_attn = False,
-        sliding_window_attn = False
+        sliding_window_attn = False,
+        weight_tie_memory_model = False
     ):
         super().__init__()

@@ -523,6 +527,20 @@ class MemoryAsContextTransformer(Module):

         neural_memory_layers = default(neural_memory_layers, layers)

+        # weight tying neural memory model
+
+        maybe_copy = deepcopy if not weight_tie_memory_model else identity
+
+        if weight_tie_memory_model:
+            assert exists(neural_memory_model), '`neural_memory_model` must be explicitly set'
+
+        self.weight_tie_memory_model = weight_tie_memory_model
+
+        # value residual learning for neural memory
+
+        is_first_mem = True
+        self.mem_add_value_residual = neural_memory_add_value_residual
+
         # mem, attn, and feedforward layers

         for layer in layers:
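The `maybe_copy` indirection above is what switches between tied and untied memories: without weight tying, each memory layer gets a `deepcopy` of the supplied model (independent parameters), while with `weight_tie_memory_model = True` every layer receives the very same Module instance. The following is an illustrative sketch of that pattern in isolation; the helper names are hypothetical and not library code.

from copy import deepcopy
import torch.nn as nn

def build_memory_models(template: nn.Module, num_layers: int, weight_tie: bool) -> list[nn.Module]:
    # mirrors `maybe_copy = deepcopy if not weight_tie_memory_model else identity`
    maybe_copy = (lambda m: m) if weight_tie else deepcopy
    return [maybe_copy(template) for _ in range(num_layers)]

template = nn.Linear(64, 64)

tied = build_memory_models(template, 3, weight_tie = True)
untied = build_memory_models(template, 3, weight_tie = False)

assert all(m is template for m in tied)        # same instance everywhere, parameters shared
assert all(m is not template for m in untied)  # independent copies with their own parameters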
@@ -551,9 +569,12 @@ class MemoryAsContextTransformer(Module):
                 mem = NeuralMemory(
                     dim = dim,
                     chunk_size = self.neural_memory_segment_len,
+                    model = maybe_copy(neural_memory_model),
+                    accept_value_residual = not is_first_mem and neural_memory_add_value_residual,
                     **neural_memory_kwargs
                 )

+                is_first_mem = False

             ff = FeedForward(dim = dim, mult = ff_mult)

@@ -683,7 +704,7 @@ class MemoryAsContextTransformer(Module):

         # math

-        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size
+        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size, weight_tie_memory_model = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size, self.weight_tie_memory_model

         seq_len_with_mem = self.seq_len_with_longterm_mem(seq_len)

@@ -736,10 +757,16 @@ class MemoryAsContextTransformer(Module):
         next_kv_caches = []
         next_neural_mem_caches = []

+        # weight tied neural memory
+
+        neural_memory_updates = None
+
         # value residual

         value_residual = None

+        mem_value_residual = None
+
         # aux losses

         kv_recon_losses = self.zero
@@ -767,19 +794,31 @@ class MemoryAsContextTransformer(Module):
                 mem_input, add_residual = mem_hyper_conn(x)

                 if not is_inferencing:
-                    (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
+                    (retrieved, next_neural_mem_cache, next_mem_value_residual), mem_kv_aux_loss = mem(
                         mem_input,
-                        return_aux_kv_loss = True
+                        return_aux_kv_loss = True,
+                        return_values = True,
+                        value_residual = mem_value_residual,
+                        prev_layer_updates = neural_memory_updates
                     )

                     kv_recon_losses = kv_recon_losses + mem_kv_aux_loss

                 else:
-                    retrieved, next_neural_mem_cache = mem.forward_inference(
+                    (retrieved, next_neural_mem_cache, next_mem_value_residual) = mem.forward_inference(
                         mem_input,
-                        state = next(neural_mem_caches, None)
+                        state = next(neural_mem_caches, None),
+                        return_values = True,
+                        value_residual = mem_value_residual,
+                        prev_layer_updates = neural_memory_updates
                     )

+                if self.mem_add_value_residual:
+                    mem_value_residual = next_mem_value_residual
+
+                if weight_tie_memory_model:
+                    neural_memory_updates = next_neural_mem_cache.updates
+
                 if self.gate_attn_output:
                     attn_out_gates = retrieved.sigmoid()
                 else:
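The forward-pass changes above thread two pieces of state through the per-layer loop: the values returned by each neural memory, reused as `value_residual` by the next memory when `neural_memory_add_value_residual` is enabled, and the weight updates in its cache, passed on as `prev_layer_updates` when the memory model is weight tied. Stripped of attention, gating, hyper connections and caching, the pattern looks roughly like the sketch below; `neural_memories` and `tokens` are hypothetical stand-ins, not names from the library.

def run_neural_memories(neural_memories, tokens, add_value_residual, weight_tie_memory_model):
    # `neural_memories` is a hypothetical list of per-layer NeuralMemory modules,
    # `tokens` a (batch, seq, dim) residual-stream input
    mem_value_residual = None      # values offered to the next memory as a residual
    neural_memory_updates = None   # previous memory's weight updates, reused when weight tied

    for mem in neural_memories:
        (retrieved, cache, values), aux_loss = mem(
            tokens,
            return_aux_kv_loss = True,
            return_values = True,
            value_residual = mem_value_residual,
            prev_layer_updates = neural_memory_updates
        )
        # (the real module gates or adds `retrieved` into the attention output here)

        if add_value_residual:          # self.mem_add_value_residual in the module
            mem_value_residual = values

        if weight_tie_memory_model:     # self.weight_tie_memory_model in the module
            neural_memory_updates = cache.updates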
@@ -822,6 +822,9 @@ class NeuralMemory(Module):
         self,
         token: Tensor,
         state = None,
+        prev_layer_updates: dict[str, Tensor] | None = None,
+        return_values = False,
+        value_residual = None,
     ):

         # unpack previous state
@@ -863,12 +866,18 @@ class NeuralMemory(Module):
         else:
             updates = updates.apply(lambda t: t[:, -1:])

+        if exists(prev_layer_updates):
+            prev_layer_updates = TensorDict(prev_layer_updates)
+            prev_layer_updates = prev_layer_updates.apply(lambda t: t[:, -1:])
+
         if store_seq_cache_len == self.chunk_size:

-            next_updates, next_states, _ = self.store_memories(
+            next_updates, next_states, values = self.store_memories(
                 cache_store_seq,
                 weights,
-                past_state = past_states
+                past_state = past_states,
+                prev_layer_updates = prev_layer_updates,
+                value_residual = value_residual
             )

             updates = next_updates
@@ -882,7 +891,12 @@ class NeuralMemory(Module):

         next_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)

-        return retrieved, next_state
+        output = (retrieved, next_state)
+
+        if return_values:
+            output = (*output, values)
+
+        return output

     def forward(
         self,
@@ -894,6 +908,7 @@ class NeuralMemory(Module):
         chunk_size = None,
         store_chunk_size = None,
         return_values = False,
+        value_residual = None,
         return_next_state = False,
         prev_layer_updates: dict[str, Tensor] | None = None
     ):
@@ -927,6 +942,7 @@ class NeuralMemory(Module):
             mem_model_weights,
             chunk_size = store_chunk_size,
             prev_layer_updates = prev_layer_updates,
+            value_residual = value_residual,
             return_aux_kv_loss = True
         )

@@ -9,7 +9,8 @@ from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset

 from adam_atan2_pytorch import AdoptAtan2
-from titans_pytorch import MemoryAsContextTransformer
+
+from titans_pytorch import MemoryAsContextTransformer, MemoryMLP

 # constants

@@ -29,13 +30,16 @@ SEQ_LEN = 512
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
-NEURAL_MEM_LAYERS = (2, 4)
-NEURAL_MEM_GATE_ATTN_OUTPUT = True
+NEURAL_MEM_LAYERS = (2, 4, 6)               # layers 2, 4, 6 have neural memory, can add more
+NEURAL_MEM_GATE_ATTN_OUTPUT = False
 NEURAL_MEM_MOMENTUM = True
+NEURAL_MEM_ADD_VALUE_RESIDUAL = True
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2   # set smaller for more granularity for learning rate / momentum etc
 SLIDING_WINDOWS = True
+WEIGHT_TIE_MEMORY_MODEL = True              # set to have memory MLP shared across layers
 STORE_ATTN_POOL_CHUNKS = True               # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
+MEMORY_MODEL_PER_LAYER_LEARNED_LR = True
 KV_RECON_LOSS_WEIGHT = 0.

 # experiment related
@@ -84,15 +88,19 @@ model = MemoryAsContextTransformer(
     aux_kv_recon_loss_weight = KV_RECON_LOSS_WEIGHT,
     use_flex_attn = USE_FLEX_ATTN,
     sliding_window_attn = SLIDING_WINDOWS,
+    weight_tie_memory_model = WEIGHT_TIE_MEMORY_MODEL,
+    neural_memory_add_value_residual = NEURAL_MEM_ADD_VALUE_RESIDUAL,
+    neural_memory_model = MemoryMLP(
+        dim = 64,
+        depth = NEURAL_MEMORY_DEPTH
+    ),
     neural_memory_kwargs = dict(
         dim_head = 64,
         heads = 4,
         attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
         momentum = NEURAL_MEM_MOMENTUM,
         use_accelerated_scan = USE_ACCELERATED_SCAN,
-        default_model_kwargs = dict(
-            depth = NEURAL_MEMORY_DEPTH,
-        )
+        per_parameter_lr_modulation = MEMORY_MODEL_PER_LAYER_LEARNED_LR
     )
 ).cuda()

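One practical migration implied by the training-script change above: where 0.1.33 configured the memory model indirectly through `default_model_kwargs` inside `neural_memory_kwargs`, the 0.1.35 script builds the `MemoryMLP` explicitly and hands it to the transformer. A hedged before/after sketch of just that piece of the configuration, with values taken from the script in this diff:

from titans_pytorch import MemoryMLP

NEURAL_MEMORY_DEPTH = 2

# 0.1.33 style: memory model appears to be built internally from kwargs
# neural_memory_kwargs = dict(..., default_model_kwargs = dict(depth = NEURAL_MEMORY_DEPTH))

# 0.1.35 style: construct the memory MLP explicitly and pass it at the transformer level
neural_memory_model = MemoryMLP(
    dim = 64,                    # matches the dim_head = 64 used in neural_memory_kwargs
    depth = NEURAL_MEMORY_DEPTH
)
# ... then: MemoryAsContextTransformer(..., neural_memory_model = neural_memory_model, ...)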