titans-pytorch 0.1.30__tar.gz → 0.1.32__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,5 @@
  train_local.py
+ .DS_Store

  # Byte-compiled / optimized / DLL files
  __pycache__/
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.1.30
+ Version: 0.1.32
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -82,7 +82,7 @@ mem = NeuralMemory(
  ).cuda()

  seq = torch.randn(2, 1024, 384).cuda()
- retrieved = mem(seq)
+ retrieved, mem_state = mem(seq)

  assert seq.shape == retrieved.shape
  ```
@@ -28,7 +28,7 @@ mem = NeuralMemory(
  ).cuda()

  seq = torch.randn(2, 1024, 384).cuda()
- retrieved = mem(seq)
+ retrieved, mem_state = mem(seq)

  assert seq.shape == retrieved.shape
  ```
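Both README hunks above document the same API change: `NeuralMemory.forward` now returns the retrieved sequence together with a memory state instead of the sequence alone. A minimal sketch of the new calling convention (constructor arguments beyond `dim` and `chunk_size` are taken from the tests in this diff, not from the README):

```python
import torch
from titans_pytorch import NeuralMemory

# dim / chunk_size values mirror the tests elsewhere in this diff
mem = NeuralMemory(dim = 384, chunk_size = 64)

seq = torch.randn(2, 1024, 384)

# previously: retrieved = mem(seq)
retrieved, mem_state = mem(seq)

assert retrieved.shape == seq.shape
```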
@@ -1,6 +1,6 @@
  [project]
  name = "titans-pytorch"
- version = "0.1.30"
+ version = "0.1.32"
  description = "Titans"
  authors = [
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -52,12 +52,12 @@ def test_titans(
  )

  seq = torch.randn(2, seq_len, 384)
- retrieved = mem(seq)
+ retrieved, _ = mem(seq)

  assert seq.shape == retrieved.shape

  def test_titans_attn_memory():
- from titans_pytorch.titans import MemoryAttention
+ from titans_pytorch.neural_memory import MemoryAttention

  mem = NeuralMemory(
  dim = 384,
@@ -68,7 +68,7 @@ def test_titans_attn_memory():
  )

  seq = torch.randn(2, 1024, 384)
- retrieved = mem(seq)
+ retrieved, _ = mem(seq)

  assert seq.shape == retrieved.shape

@@ -81,7 +81,7 @@ def test_retrieve_store_diff_seq():
  retrieve_seq = torch.randn(2, 64 * 64, 384)
  store_seq = torch.randn(2, 64 * 32, 384)

- retrieved = mem(retrieve_seq, store_seq = store_seq)
+ retrieved, _ = mem(retrieve_seq, store_seq = store_seq)

  assert retrieve_seq.shape == retrieved.shape

@@ -94,7 +94,7 @@ def test_overriding_chunk_size():
  seq = torch.randn(2, 128 * 16, 384)
  store_seq = torch.randn(2, 128 * 8, 384)

- retrieved = mem(seq, store_seq, chunk_size = 16, store_chunk_size = 8)
+ retrieved, _ = mem(seq, store_seq, chunk_size = 16, store_chunk_size = 8)

  assert seq.shape == retrieved.shape

@@ -124,18 +124,22 @@ def test_mac(
  assert logits.shape == (1, seq_len, 256)

  @pytest.mark.parametrize('sliding', (False, True))
- @pytest.mark.parametrize('mem_layers', ((), None, (4,)))
+ @pytest.mark.parametrize('mem_layers', (()))
+ @pytest.mark.parametrize('longterm_mems', (0, 4, 16))
+ @pytest.mark.parametrize('prompt_len', (0, 4, 16))
  def test_mac_sampling(
  sliding,
- mem_layers
+ mem_layers,
+ longterm_mems,
+ prompt_len
  ):
  transformer = MemoryAsContextTransformer(
  num_tokens = 256,
  dim = 256,
- depth = 2,
+ depth = 4,
  segment_len = 32,
  num_persist_mem_tokens = 4,
- num_longterm_mem_tokens = 0,
+ num_longterm_mem_tokens = longterm_mems,
  sliding_window_attn = sliding,
  neural_memory_layers = mem_layers,
  neural_mem_gate_attn_output = False
@@ -145,29 +149,46 @@ def test_mac_sampling(

  # after much training

- sampled = transformer.sample(ids[:, :4], 53, use_cache = False, temperature = 0.)
- sampled_with_cache = transformer.sample(ids[:, :4], 53, use_cache = True, temperature = 0.)
+ prompt = ids[:, :prompt_len]
+
+ sampled = transformer.sample(prompt, 53, use_cache = False, temperature = 0.)
+ sampled_with_cache = transformer.sample(prompt, 53, use_cache = True, temperature = 0.)

  assert torch.allclose(sampled, sampled_with_cache)

  @pytest.mark.parametrize('seq_len', (2, 64, 256))
+ @pytest.mark.parametrize('prompt_len', (0, 65))
+ @pytest.mark.parametrize('mem_chunk_size', (2, 32, 64))
  @torch_default_dtype(torch.float64)
  def test_neural_mem_inference(
- seq_len
+ seq_len,
+ prompt_len,
+ mem_chunk_size
  ):
  mem = NeuralMemory(
  dim = 384,
- chunk_size = 64,
+ chunk_size = mem_chunk_size,
  )

  seq = torch.randn(2, seq_len, 384)
- parallel_retrieved = mem(seq)
+ parallel_retrieved, _ = mem(seq)

  assert seq.shape == parallel_retrieved.shape

  state = None
  sequential_retrieved = []

+ # test initial parallel prompt
+
+ test_parallel_prompt = prompt_len > 0 and prompt_len < seq_len
+
+ if test_parallel_prompt:
+ prompt, seq = seq[:, :prompt_len], seq[:, prompt_len:]
+ retrieved_prompt, state = mem(prompt)
+ sequential_retrieved.append(retrieved_prompt)
+
+ # sequential inference
+
  for token in seq.unbind(dim = 1):

  one_retrieved, state = mem.forward_inference(
@@ -208,7 +229,7 @@ def test_flex(

  @torch_default_dtype(torch.float64)
  def test_assoc_scan():
- from titans_pytorch.titans import AssocScan
+ from titans_pytorch.neural_memory import AssocScan
  torch.set_default_dtype(torch.float64)

  scan = AssocScan()
@@ -1,4 +1,4 @@
- from titans_pytorch.titans import (
+ from titans_pytorch.neural_memory import (
  NeuralMemory,
  MemoryMLP,
  MemoryAttention,
@@ -65,7 +65,7 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions

  # proposed neural memory

- from titans_pytorch.titans import NeuralMemory
+ from titans_pytorch.neural_memory import NeuralMemory

  # constants

@@ -106,7 +106,11 @@ def pad_at_dim(t, pad, dim = -1, value = 0.):
  zeros = ((0, 0) * dims_from_right)
  return F.pad(t, (*zeros, *pad), value = value)

- def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
+ def pad_and_segment_with_inverse(
+ seq,
+ segment_len,
+ fold_into_batch = True,
+ ):
  batch, seq_len = seq.shape[:2]
  next_seq_len_mult = round_up_multiple(seq_len, segment_len)

@@ -119,11 +123,15 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
  if fold_into_batch:
  seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)

- def inverse(out, remove_pad = True):
+ shape = seq.shape
+
+ def inverse(out):
+ unchanged_shape = out.shape == shape
+
  if fold_into_batch:
  out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)

- if needs_pad and remove_pad:
+ if needs_pad and unchanged_shape:
  out = out[..., :-padding, :]

  return out
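The `inverse` closure above no longer takes a `remove_pad` flag; it records the segmented shape at creation time and strips the padding only when the tensor passed back still has that shape, which is why the `x = inverse_segment(x, remove_pad = False)` call site later in this diff loses its argument. A simplified standalone sketch of the same idea (not the package function, which also handles extra leading dims and the `fold_into_batch` flag):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def pad_and_segment_with_inverse(seq, segment_len):
    # pad up to a multiple of segment_len, fold windows into the batch dim,
    # and return an inverse that only strips the pad if the shape is unchanged
    batch, seq_len = seq.shape[:2]
    padding = -seq_len % segment_len
    needs_pad = padding > 0

    if needs_pad:
        seq = F.pad(seq, (0, 0, 0, padding))

    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
    shape = seq.shape

    def inverse(out):
        unchanged_shape = out.shape == shape
        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)

        if needs_pad and unchanged_shape:
            out = out[:, :-padding]

        return out

    return seq, inverse

seq = torch.randn(2, 50, 8)
segments, inverse = pad_and_segment_with_inverse(seq, 32)
assert inverse(segments).shape == seq.shape
```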
@@ -582,13 +590,8 @@ class MemoryAsContextTransformer(Module):
  self,
  seq_index
  ):
- total_segment_len = self.attn_window_size
-
- seq = seq_index + 1
- seq -= int((seq % total_segment_len) == 0)
- last_segment_len = round_down_multiple(seq, total_segment_len)
- segment_seq = seq - last_segment_len
- return (segment_seq - self.segment_len) > 0
+ total_segment_len, segment_len = self.attn_window_size, self.segment_len
+ return ((seq_index % total_segment_len + 1) - segment_len) > 0

  def seq_len_with_longterm_mem(
  self,
@@ -597,7 +600,7 @@ class MemoryAsContextTransformer(Module):
  assert seq_len > 0

  segment_len, num_mem = self.segment_len, self.num_longterm_mem_tokens
- return ceil(seq_len / segment_len) * num_mem + seq_len
+ return ((seq_len - 1) // segment_len) * num_mem + seq_len

  @torch.no_grad()
  def sample(
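A quick arithmetic check of the two rewritten helpers, using the segment_len = 32 and num_longterm_mem_tokens = 4 values that appear in the tests earlier in this diff (the helper names below are illustrative, not package code). Since ceil(n / k) equals (n - 1) // k + 1 for n >= 1, the old length formula always counted exactly one extra block of long-term memory tokens; the rewritten `seq_index_is_longterm` then simply asks whether a position's index within the attention window falls past the first `segment_len` real tokens, assuming `attn_window_size = segment_len + num_longterm_mem_tokens`:

```python
from math import ceil

segment_len, num_mem = 32, 4               # values used by the tests in this diff
attn_window_size = segment_len + num_mem   # assumption about how the window is sized

def old_len(seq_len):
    return ceil(seq_len / segment_len) * num_mem + seq_len

def new_len(seq_len):
    return ((seq_len - 1) // segment_len) * num_mem + seq_len

# the old formula over-counts by exactly one block of memory tokens at every length
for seq_len in (1, 32, 33, 64, 256):
    assert old_len(seq_len) == new_len(seq_len) + num_mem

def seq_index_is_longterm(seq_index):
    return ((seq_index % attn_window_size + 1) - segment_len) > 0

# within each 36-token window, indices 0..31 are real tokens, 32..35 are memory slots
assert not seq_index_is_longterm(31)
assert seq_index_is_longterm(32)
assert not seq_index_is_longterm(36)
```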
@@ -695,7 +698,7 @@ class MemoryAsContextTransformer(Module):
  mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
  x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')

- x = inverse_segment(x, remove_pad = False)
+ x = inverse_segment(x)

  # splice out unneeded tokens from padding for longterm mems

@@ -723,9 +726,9 @@ class MemoryAsContextTransformer(Module):
  is_inferencing = exists(cache)

  if not exists(cache):
- cache = (None, None)
+ cache = (seq_len_with_mem - 1, None, None)

- kv_caches, neural_mem_caches = cache
+ inference_seq_index, kv_caches, neural_mem_caches = cache

  kv_caches = iter(default(kv_caches, []))
  neural_mem_caches = iter(default(neural_mem_caches, []))
@@ -744,7 +747,8 @@ class MemoryAsContextTransformer(Module):
  # when inferencing, only do one token at a time

  if is_inferencing:
- x = x[:, -1:]
+ ind = inference_seq_index
+ x = x[:, ind:(ind + 1)]

  # expand and reduce streams for hyper connections

@@ -763,14 +767,13 @@ class MemoryAsContextTransformer(Module):
  mem_input, add_residual = mem_hyper_conn(x)

  if not is_inferencing:
- retrieved, mem_kv_aux_loss = mem(
+ (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
  mem_input,
  return_aux_kv_loss = True
  )

  kv_recon_losses = kv_recon_losses + mem_kv_aux_loss

- next_neural_mem_cache = (seq_len, None, None, None)
  else:
  retrieved, next_neural_mem_cache = mem.forward_inference(
  mem_input,
@@ -817,6 +820,17 @@ class MemoryAsContextTransformer(Module):
  if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
  next_kv_caches = next_kv_caches[..., 0:0, :]

+ next_cache = (
+ inference_seq_index + 1,
+ next_kv_caches,
+ next_neural_mem_caches
+ )
+
+ is_longterm_mem = self.seq_index_is_longterm(inference_seq_index)
+
+ if is_inferencing and is_longterm_mem:
+ return None, next_cache
+
  # hyper connection reducing of streams

  x = self.reduce_streams(x)
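The last few hunks thread an explicit sequence index through the inference cache: it starts at `seq_len_with_mem - 1`, advances by one per decoded position, selects the single token to process, and lets the forward return `None` logits when that position is a long-term memory slot. The reworked `test_mac_sampling` earlier in this diff exercises the whole path by comparing cached and uncached greedy sampling; a hedged reconstruction of that check (the import path and the length of `ids` are assumptions, everything else mirrors the test):

```python
import torch
from titans_pytorch import MemoryAsContextTransformer  # import path assumed

# configuration mirrors test_mac_sampling in this diff
transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 4,
    segment_len = 32,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 4,
    sliding_window_attn = False,
    neural_memory_layers = (),
    neural_mem_gate_attn_output = False
)

ids = torch.randint(0, 256, (1, 64))  # length chosen arbitrarily for this sketch

# greedy sampling must agree with and without the new index-carrying cache
sampled = transformer.sample(ids[:, :4], 53, use_cache = False, temperature = 0.)
sampled_with_cache = transformer.sample(ids[:, :4], 53, use_cache = True, temperature = 0.)

assert torch.allclose(sampled, sampled_with_cache)
```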
@@ -843,7 +857,7 @@ class MemoryAsContextTransformer(Module):
  if not return_cache:
  return logits

- return logits, (next_kv_caches, next_neural_mem_caches)
+ return logits, next_cache

  ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)

@@ -38,8 +38,11 @@ LinearNoBias = partial(Linear, bias = False)
  def exists(v):
  return v is not None

- def default(v, d):
- return v if exists(v) else d
+ def default(*args):
+ for arg in args:
+ if exists(arg):
+ return arg
+ return None

  def xnor(x, y):
  return not (x ^ y)
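`default` now accepts any number of fallbacks and returns the first non-None one, which is what the later `default(store_chunk_size, chunk_size, self.store_chunk_size)` call in this diff relies on. A quick illustration of the new behaviour:

```python
def exists(v):
    return v is not None

def default(*args):
    # return the first argument that is not None, otherwise None
    for arg in args:
        if exists(arg):
            return arg
    return None

assert default(None, None, 8) == 8   # falls through two missing values
assert default(4, 8) == 4            # first value still wins, as before
assert default(None, None) is None
```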
@@ -468,7 +471,12 @@ class NeuralMemory(Module):
  weighted_loss = loss * loss_weights
  return weighted_loss.sum(), weighted_loss.mean()

- self.per_sample_grad_fn = vmap(grad(forward_and_loss, has_aux = True), in_dims = (None, 0, 0, 0))
+ # two functions
+
+ grad_fn = grad(forward_and_loss, has_aux = True)
+
+ self.per_sample_grad_fn = vmap(grad_fn, in_dims = (None, 0, 0, 0))
+ self.per_sample_grad_fn_expanded_weights = vmap(grad_fn, in_dims = (0,) * 4)

  # queries for retrieving from the model

@@ -561,6 +569,7 @@ class NeuralMemory(Module):
  seq,
  weights: dict[str, Tensor],
  past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
+ prev_layer_updates: dict[str, Tensor] | None = None,
  return_aux_kv_loss = False,
  chunk_size = None,
  value_residual = None
@@ -583,10 +592,25 @@ class NeuralMemory(Module):

  seq = seq[:, :round_down_seq_len]

+ # per sample grad function
+
+ per_sample_grad_fn = self.per_sample_grad_fn
+
  # weights of the memory network

  weights = TensorDict(weights)

+ # allow for neural memory of a previous layer and the past to produce gradients that become the weights of the current one generating the surprise
+ # think this is necessary otherwise the memory model is static (unless if paper is misunderstood)
+ # improvise (or perhaps correcting to) a solution
+
+ if exists(prev_layer_updates):
+ prev_layer_updates = TensorDict(weights)
+
+ weights = weights + prev_layer_updates
+
+ per_sample_grad_fn = self.per_sample_grad_fn_expanded_weights # the weights will now have a batch * chunk dimension
+
  # derive learned hparams for optimization of memory network

  adaptive_lr = self.to_adaptive_step(seq)
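The hunk above is where the expanded-weights variant gets used: when `prev_layer_updates` is supplied, the memory weights gain a batch dimension and `per_sample_grad_fn_expanded_weights` (defined two hunks earlier with `in_dims = (0,) * 4`) is selected instead of the broadcast version. An illustrative sketch of that `in_dims` distinction, with a simplified toy loss rather than the package's `forward_and_loss`:

```python
import torch
from torch.func import grad, vmap

def loss(weight, x):
    # toy per-sample loss over a single linear "memory" weight
    return ((x @ weight) ** 2).sum()

grad_fn = grad(loss)

x = torch.randn(8, 4)                 # batch of 8 samples
shared_w = torch.randn(4, 4)          # one weight matrix shared across the batch
per_sample_w = torch.randn(8, 4, 4)   # a separate weight matrix per sample

# in_dims = (None, 0): the same weights are broadcast to every sample
shared_grads = vmap(grad_fn, in_dims = (None, 0))(shared_w, x)

# in_dims = (0, 0): the weights carry their own leading batch dimension
expanded_grads = vmap(grad_fn, in_dims = (0, 0))(per_sample_w, x)

assert shared_grads.shape == expanded_grads.shape == (8, 4, 4)
```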
@@ -635,7 +659,7 @@ class NeuralMemory(Module):

  # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

- grads, aux_kv_recon_loss = self.per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
+ grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)

  grads = TensorDict(grads)

@@ -781,6 +805,7 @@ class NeuralMemory(Module):

  return values[:, :seq_len]

+ @torch.no_grad()
  def forward_inference(
  self,
  token: Tensor,
@@ -854,13 +879,18 @@ class NeuralMemory(Module):
  return_aux_kv_loss = False,
  chunk_size = None,
  store_chunk_size = None,
- return_values = False
+ return_values = False,
+ return_next_state = False
  ):
  batch, seq_len = seq.shape[:2]

  if seq_len < self.retrieve_chunk_size:
  out = self.init_empty_memory_embed(batch, seq_len)

+ next_store_state = (seq_len, seq, None, None)
+
+ out = (out, next_store_state)
+
  if not return_aux_kv_loss:
  return out

@@ -870,16 +900,31 @@ class NeuralMemory(Module):
  mem_model_weights = self.init_weights()

  store_seq = default(store_seq, seq)
- store_chunk_size = default(store_chunk_size, chunk_size)
+
+ store_seq_len = store_seq.shape[-2]
+ store_chunk_size = default(store_chunk_size, chunk_size, self.store_chunk_size)
+ remainder = store_seq_len % store_chunk_size

  (updates, next_state, values), aux_kv_recon_loss = self.store_memories(store_seq, mem_model_weights, chunk_size = store_chunk_size, return_aux_kv_loss = True)

  retrieved = self.retrieve_memories(seq, mem_model_weights + updates, chunk_size = chunk_size)

- output = retrieved
+ # determine state for the storing of memories
+ # for transformer-xl like training with neural memory as well as inferencing with initial prompt
+
+ cache_store_seq = None
+
+ if remainder > 0:
+ cache_store_seq = store_seq[:, -remainder:]
+
+ updates = updates.apply(lambda t: t[:, -1:])
+
+ next_store_state = (seq_len, cache_store_seq, next_state, updates)
+
+ output = (retrieved, next_store_state)

  if return_values:
- output = (retrieved, values)
+ output = (*output, values)

  if not return_aux_kv_loss:
  return output
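Taken together, the forward hunks mean `NeuralMemory.forward` now always returns the next store state alongside the retrieved sequence, with the auxiliary KV-reconstruction loss, when requested, wrapping that pair one level deeper, exactly as the MemoryAsContextTransformer hunk earlier in this diff unpacks it. A short sketch of the resulting return structure (constructor arguments as in the tests):

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)
seq = torch.randn(2, 1024, 384)

# plain call: retrieved sequence plus the state needed to keep storing,
# where the state is (seq_len, cache_store_seq, past_state, updates) per the hunk above
retrieved, mem_state = mem(seq)
assert retrieved.shape == seq.shape

# with the auxiliary loss, the (retrieved, state) pair is nested one level,
# matching `(retrieved, cache), aux_loss = mem(..., return_aux_kv_loss = True)`
# as unpacked in the MAC transformer hunk earlier in this diff
(retrieved, mem_state), aux_kv_loss = mem(seq, return_aux_kv_loss = True)
```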
File without changes