titans-pytorch 0.1.32__tar.gz → 0.1.33__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.32
+Version: 0.1.33
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.32"
+version = "0.1.33"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

@@ -85,6 +85,34 @@ def test_retrieve_store_diff_seq():
 
     assert retrieve_seq.shape == retrieved.shape
 
+def test_weight_tied_mlp_neural_mem():
+    mem = NeuralMemory(
+        dim = 384,
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 2
+    )
+
+    mem2 = NeuralMemory(
+        dim = 384,
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 2
+    )
+
+    mem3 = NeuralMemory(
+        dim = 384,
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 2
+    )
+
+    seq = torch.randn(2, 128, 384)
+
+    seq, cache = mem(seq)
+    seq, cache2 = mem2(seq, prev_layer_updates = cache.updates)
+    seq, cache3 = mem3(seq, prev_layer_updates = cache2.updates)
+
 def test_overriding_chunk_size():
     mem = NeuralMemory(
         dim = 384,

@@ -3,6 +3,7 @@ from typing import Callable
 
 import math
 from functools import partial
+from collections import namedtuple
 
 import torch
 from torch import nn, cat, Tensor

@@ -33,6 +34,8 @@ w - num memory network weight parameters
 
 LinearNoBias = partial(Linear, bias = False)
 
+NeuralMemCache = namedtuple('NeuralMemCache', ['seq', 'cache_store_segment', 'states', 'updates'])
+
 # functions
 
 def exists(v):

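The cache returned alongside the retrieved sequence is now this named tuple rather than an anonymous one, which is what lets the new test reach into cache.updates by name. A minimal sketch of the two access styles (the field values below are placeholders, not real cache contents):

    from collections import namedtuple

    NeuralMemCache = namedtuple('NeuralMemCache', ['seq', 'cache_store_segment', 'states', 'updates'])

    # placeholder contents, purely to illustrate access patterns
    cache = NeuralMemCache(seq = 128, cache_store_segment = None, states = None, updates = None)

    cache.updates                                 # named access, as used for prev_layer_updates
    seq_index, segment, states, updates = cache   # positional unpacking still works
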
@@ -605,7 +608,7 @@ class NeuralMemory(Module):
         # improvise (or perhaps correcting to) a solution
 
         if exists(prev_layer_updates):
-            prev_layer_updates = TensorDict(weights)
+            prev_layer_updates = TensorDict(prev_layer_updates)
 
             weights = weights + prev_layer_updates
 
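The removed line wrapped weights instead of the argument, silently discarding whatever the caller passed as prev_layer_updates; the fix wraps the actual updates before folding them into the starting weights. A sketch of what weights = weights + prev_layer_updates computes, with plain dicts standing in for TensorDicts (the key name is illustrative only):

    import torch

    weights = {'memory_model.w0': torch.zeros(64, 64)}
    prev_layer_updates = {'memory_model.w0': torch.randn(64, 64)}

    # key-wise addition: this layer's starting weights plus the previous layer's updates
    effective = {k: weights[k] + prev_layer_updates[k] for k in weights}
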
@@ -657,6 +660,11 @@ class NeuralMemory(Module):
 
         adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
 
+        # flatten batch and time if surprise depends on previous layer memory model
+
+        if exists(prev_layer_updates):
+            weights = weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
         grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)

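With prev_layer_updates folded in, each weight tensor carries leading batch and per-chunk time dimensions, and the new branch flattens the two so the per-sample gradient call sees one "sample" per (batch, chunk) pair, matching the (b n) c layout adaptive_lr was just reshaped into. A standalone einops sketch of that flatten (shapes invented for illustration):

    import torch
    from einops import rearrange

    # a hypothetical per-layer weight: batch 2, 4 chunk steps, one 64 x 64 matrix each
    w = torch.randn(2, 4, 64, 64)

    flat = rearrange(w, 'b n ... -> (b n) ...')
    assert flat.shape == (8, 64, 64)
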
@@ -737,7 +745,8 @@ class NeuralMemory(Module):
         self,
         seq,
         past_weights: dict[str, Tensor],
-        chunk_size = None
+        chunk_size = None,
+        prev_layer_updates: dict[str, Tensor] | None = None
     ):
         chunk_size = default(chunk_size, self.retrieve_chunk_size)
         batch, seq_len = seq.shape[:2]

@@ -760,6 +769,9 @@ class NeuralMemory(Module):
 
         curr_weights = TensorDict(past_weights)
 
+        if exists(prev_layer_updates):
+            curr_weights = curr_weights + TensorDict(prev_layer_updates)
+
         # sequence Float['b n d'] to queries
 
         queries = self.to_queries(seq)

@@ -838,7 +850,7 @@ class NeuralMemory(Module):
         if curr_seq_len < self.chunk_size:
             empty_mem = self.init_empty_memory_embed(batch, 1)
 
-            return empty_mem, (curr_seq_len, cache_store_seq, past_states, updates)
+            return empty_mem, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)
 
         # store if storage sequence cache hits the chunk size
 
@@ -848,6 +860,8 @@ class NeuralMemory(Module):
         if not exists(updates):
             updates = weights.clone().zero_()
             updates = updates.apply(lambda t: repeat(t, '... -> b 1 ...', b = batch))
+        else:
+            updates = updates.apply(lambda t: t[:, -1:])
 
         if store_seq_cache_len == self.chunk_size:
 
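The new else branch keeps only the most recent update step when a cache is carried across inference calls. The slice t[:, -1:] trims the time dimension to its last entry while preserving the axis, so later code can keep indexing it. A tiny sketch (shapes invented):

    import torch

    t = torch.randn(2, 4, 64, 64)   # batch, update steps, a 64 x 64 weight each
    last = t[:, -1:]                # keep only the latest update step
    assert last.shape == (2, 1, 64, 64)
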
@@ -866,7 +880,7 @@ class NeuralMemory(Module):
 
         # next state tuple
 
-        next_state = (curr_seq_len, cache_store_seq, next_states, updates)
+        next_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)
 
         return retrieved, next_state
 
@@ -880,7 +894,8 @@ class NeuralMemory(Module):
         chunk_size = None,
         store_chunk_size = None,
         return_values = False,
-        return_next_state = False
+        return_next_state = False,
+        prev_layer_updates: dict[str, Tensor] | None = None
     ):
         batch, seq_len = seq.shape[:2]
 
@@ -899,15 +914,30 @@ class NeuralMemory(Module):
         if not exists(mem_model_weights):
             mem_model_weights = self.init_weights()
 
+        # store
+
         store_seq = default(store_seq, seq)
 
         store_seq_len = store_seq.shape[-2]
         store_chunk_size = default(store_chunk_size, chunk_size, self.store_chunk_size)
         remainder = store_seq_len % store_chunk_size
 
-        (updates, next_state, values), aux_kv_recon_loss = self.store_memories(store_seq, mem_model_weights, chunk_size = store_chunk_size, return_aux_kv_loss = True)
+        (updates, next_state, values), aux_kv_recon_loss = self.store_memories(
+            store_seq,
+            mem_model_weights,
+            chunk_size = store_chunk_size,
+            prev_layer_updates = prev_layer_updates,
+            return_aux_kv_loss = True
+        )
 
-        retrieved = self.retrieve_memories(seq, mem_model_weights + updates, chunk_size = chunk_size)
+        # retrieve
+
+        retrieved = self.retrieve_memories(
+            seq,
+            mem_model_weights + updates,
+            chunk_size = chunk_size,
+            prev_layer_updates = prev_layer_updates
+        )
 
         # determine state for the storing of memories
         # for transformer-xl like training with neural memory as well as inferencing with initial prompt

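With both the store and retrieve paths accepting prev_layer_updates, stacked memories can be chained by feeding one layer's cached updates into the next, exactly as the new test does. A usage sketch mirroring that test:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 384, dim_head = 64, heads = 2, chunk_size = 2)
    mem2 = NeuralMemory(dim = 384, dim_head = 64, heads = 2, chunk_size = 2)

    seq = torch.randn(2, 128, 384)

    seq, cache = mem(seq)                                         # cache is a NeuralMemCache
    seq, cache2 = mem2(seq, prev_layer_updates = cache.updates)   # thread updates forward
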
@@ -917,9 +947,7 @@ class NeuralMemory(Module):
         if remainder > 0:
             cache_store_seq = store_seq[:, -remainder:]
 
-            updates = updates.apply(lambda t: t[:, -1:])
-
-        next_store_state = (seq_len, cache_store_seq, next_state, updates)
+        next_store_state = NeuralMemCache(seq_len, cache_store_seq, next_state, updates)
 
         output = (retrieved, next_store_state)
 