titans-pytorch 0.1.31__tar.gz → 0.1.32__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,5 @@
 train_local.py
+.DS_Store
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.31
+Version: 0.1.32
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -82,7 +82,7 @@ mem = NeuralMemory(
 ).cuda()
 
 seq = torch.randn(2, 1024, 384).cuda()
-retrieved = mem(seq)
+retrieved, mem_state = mem(seq)
 
 assert seq.shape == retrieved.shape
 ```
@@ -28,7 +28,7 @@ mem = NeuralMemory(
 ).cuda()
 
 seq = torch.randn(2, 1024, 384).cuda()
-retrieved = mem(seq)
+retrieved, mem_state = mem(seq)
 
 assert seq.shape == retrieved.shape
 ```
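For context on the API change above: as of 0.1.32 the `NeuralMemory` forward call returns a tuple rather than a bare tensor. A minimal sketch of the updated README usage, assuming constructor arguments like those used in the tests (the README hunk elides them):

```python
import torch
from titans_pytorch import NeuralMemory

# constructor arguments assumed from the test file; the hunk above elides them
mem = NeuralMemory(dim = 384, chunk_size = 64).cuda()

seq = torch.randn(2, 1024, 384).cuda()
retrieved, mem_state = mem(seq)   # 0.1.32: retrieved sequence plus a memory state

assert seq.shape == retrieved.shape
```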
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.31"
+version = "0.1.32"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -52,12 +52,12 @@ def test_titans(
     )
 
     seq = torch.randn(2, seq_len, 384)
-    retrieved = mem(seq)
+    retrieved, _ = mem(seq)
 
     assert seq.shape == retrieved.shape
 
 def test_titans_attn_memory():
-    from titans_pytorch.titans import MemoryAttention
+    from titans_pytorch.neural_memory import MemoryAttention
 
     mem = NeuralMemory(
         dim = 384,
@@ -68,7 +68,7 @@ def test_titans_attn_memory():
     )
 
     seq = torch.randn(2, 1024, 384)
-    retrieved = mem(seq)
+    retrieved, _ = mem(seq)
 
     assert seq.shape == retrieved.shape
 
@@ -81,7 +81,7 @@ def test_retrieve_store_diff_seq():
     retrieve_seq = torch.randn(2, 64 * 64, 384)
     store_seq = torch.randn(2, 64 * 32, 384)
 
-    retrieved = mem(retrieve_seq, store_seq = store_seq)
+    retrieved, _ = mem(retrieve_seq, store_seq = store_seq)
 
     assert retrieve_seq.shape == retrieved.shape
 
@@ -94,7 +94,7 @@ def test_overriding_chunk_size():
     seq = torch.randn(2, 128 * 16, 384)
     store_seq = torch.randn(2, 128 * 8, 384)
 
-    retrieved = mem(seq, store_seq, chunk_size = 16, store_chunk_size = 8)
+    retrieved, _ = mem(seq, store_seq, chunk_size = 16, store_chunk_size = 8)
 
     assert seq.shape == retrieved.shape
 
@@ -157,23 +157,38 @@ def test_mac_sampling(
     assert torch.allclose(sampled, sampled_with_cache)
 
 @pytest.mark.parametrize('seq_len', (2, 64, 256))
+@pytest.mark.parametrize('prompt_len', (0, 65))
+@pytest.mark.parametrize('mem_chunk_size', (2, 32, 64))
 @torch_default_dtype(torch.float64)
 def test_neural_mem_inference(
-    seq_len
+    seq_len,
+    prompt_len,
+    mem_chunk_size
 ):
     mem = NeuralMemory(
         dim = 384,
-        chunk_size = 64,
+        chunk_size = mem_chunk_size,
     )
 
     seq = torch.randn(2, seq_len, 384)
-    parallel_retrieved = mem(seq)
+    parallel_retrieved, _ = mem(seq)
 
     assert seq.shape == parallel_retrieved.shape
 
     state = None
     sequential_retrieved = []
 
+    # test initial parallel prompt
+
+    test_parallel_prompt = prompt_len > 0 and prompt_len < seq_len
+
+    if test_parallel_prompt:
+        prompt, seq = seq[:, :prompt_len], seq[:, prompt_len:]
+        retrieved_prompt, state = mem(prompt)
+        sequential_retrieved.append(retrieved_prompt)
+
+    # sequential inference
+
     for token in seq.unbind(dim = 1):
 
         one_retrieved, state = mem.forward_inference(
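The new parametrization above exercises a prompt-then-decode pattern: a prompt is first run through the memory in parallel, and the returned state then seeds token-by-token inference. A condensed sketch of that pattern, assuming the constructor arguments from the test and hedging on the exact `forward_inference` call signature (its arguments are elided in this hunk):

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 2)

seq = torch.randn(2, 64, 384)
prompt, rest = seq[:, :16], seq[:, 16:]

retrieved_prompt, state = mem(prompt)   # parallel pass over the prompt returns a memory state

pieces = [retrieved_prompt]
for token in rest.unbind(dim = 1):
    # the 'state' keyword is an assumption; the hunk truncates the call arguments
    one_retrieved, state = mem.forward_inference(token, state = state)
    pieces.append(one_retrieved)
```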
@@ -214,7 +229,7 @@ def test_flex(
 
 @torch_default_dtype(torch.float64)
 def test_assoc_scan():
-    from titans_pytorch.titans import AssocScan
+    from titans_pytorch.neural_memory import AssocScan
     torch.set_default_dtype(torch.float64)
 
     scan = AssocScan()
@@ -1,4 +1,4 @@
-from titans_pytorch.titans import (
+from titans_pytorch.neural_memory import (
     NeuralMemory,
     MemoryMLP,
     MemoryAttention,
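The hunk above reflects the module rename from `titans_pytorch.titans` to `titans_pytorch.neural_memory`. Downstream imports would follow suit, or go through the package root, which re-exports these names (a sketch):

```python
# old (0.1.31): from titans_pytorch.titans import NeuralMemory
# new (0.1.32):
from titans_pytorch.neural_memory import NeuralMemory

# or via the package root, which re-exports the public classes
from titans_pytorch import NeuralMemory, MemoryMLP, MemoryAttention
```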
@@ -65,7 +65,7 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # proposed neural memory
 
-from titans_pytorch.titans import NeuralMemory
+from titans_pytorch.neural_memory import NeuralMemory
 
 # constants
 
@@ -106,7 +106,11 @@ def pad_at_dim(t, pad, dim = -1, value = 0.):
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)
 
-def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
+def pad_and_segment_with_inverse(
+    seq,
+    segment_len,
+    fold_into_batch = True,
+):
     batch, seq_len = seq.shape[:2]
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)
 
@@ -119,11 +123,15 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
     if fold_into_batch:
         seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
 
-    def inverse(out, remove_pad = True):
+    shape = seq.shape
+
+    def inverse(out):
+        unchanged_shape = out.shape == shape
+
         if fold_into_batch:
             out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)
 
-        if needs_pad and remove_pad:
+        if needs_pad and unchanged_shape:
             out = out[..., :-padding, :]
 
         return out
@@ -690,7 +698,7 @@ class MemoryAsContextTransformer(Module):
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
         x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')
 
-        x = inverse_segment(x, remove_pad = False)
+        x = inverse_segment(x)
 
         # splice out unneeded tokens from padding for longterm mems
 
@@ -759,14 +767,13 @@ class MemoryAsContextTransformer(Module):
             mem_input, add_residual = mem_hyper_conn(x)
 
             if not is_inferencing:
-                retrieved, mem_kv_aux_loss = mem(
+                (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
                     mem_input,
                     return_aux_kv_loss = True
                 )
 
                 kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
 
-                next_neural_mem_cache = (seq_len, None, None, None)
             else:
                 retrieved, next_neural_mem_cache = mem.forward_inference(
                     mem_input,
@@ -836,7 +843,7 @@ class MemoryAsContextTransformer(Module):
 
         x, _ = inverse_pack_mems(x)
 
-        x = inverse_segment(x, remove_pad = False)
+        x = inverse_segment(x)
 
         x = x[:, :seq_len]
 
@@ -38,8 +38,11 @@ LinearNoBias = partial(Linear, bias = False)
 def exists(v):
     return v is not None
 
-def default(v, d):
-    return v if exists(v) else d
+def default(*args):
+    for arg in args:
+        if exists(arg):
+            return arg
+    return None
 
 def xnor(x, y):
     return not (x ^ y)
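The widened `default` helper now returns the first non-None argument rather than choosing between exactly two, which is what allows `default(store_chunk_size, chunk_size, self.store_chunk_size)` later in this diff to fall through several candidates. A quick illustrative check of the new behavior:

```python
# illustrative only: the first non-None argument wins, else None
assert default(None, 8) == 8
assert default(None, None, 16) == 16
assert default(4, 8) == 4
assert default(None, None) is None
```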
@@ -468,7 +471,12 @@ class NeuralMemory(Module):
             weighted_loss = loss * loss_weights
             return weighted_loss.sum(), weighted_loss.mean()
 
-        self.per_sample_grad_fn = vmap(grad(forward_and_loss, has_aux = True), in_dims = (None, 0, 0, 0))
+        # two functions
+
+        grad_fn = grad(forward_and_loss, has_aux = True)
+
+        self.per_sample_grad_fn = vmap(grad_fn, in_dims = (None, 0, 0, 0))
+        self.per_sample_grad_fn_expanded_weights = vmap(grad_fn, in_dims = (0,) * 4)
 
         # queries for retrieving from the model
 
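The second vmapped gradient function differs only in its `in_dims`: the original maps per-sample gradients with a single shared weight dict (`in_dims = (None, 0, 0, 0)`), while the expanded variant also maps over a leading batch dimension in the weights themselves. A toy, self-contained sketch of that pattern; the loss function and shapes here are illustrative, not the library's:

```python
import torch
from torch.func import grad, vmap

def loss_fn(weights, keys, lr, values):
    # toy per-chunk regression loss
    pred = keys @ weights['w']
    return (lr * (pred - values) ** 2).mean()

grad_fn = grad(loss_fn)

per_sample = vmap(grad_fn, in_dims = (None, 0, 0, 0))    # one weight set shared across the batch
per_sample_expanded = vmap(grad_fn, in_dims = (0,) * 4)  # a separate weight set per batch element

keys, values = torch.randn(4, 8, 16), torch.randn(4, 8, 16)
lr = torch.rand(4, 8, 1)

grads_shared = per_sample({'w': torch.randn(16, 16)}, keys, lr, values)
grads_batched = per_sample_expanded({'w': torch.randn(4, 16, 16)}, keys, lr, values)
```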
@@ -561,6 +569,7 @@ class NeuralMemory(Module):
         seq,
         weights: dict[str, Tensor],
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
+        prev_layer_updates: dict[str, Tensor] | None = None,
         return_aux_kv_loss = False,
         chunk_size = None,
         value_residual = None
@@ -583,10 +592,25 @@ class NeuralMemory(Module):
 
         seq = seq[:, :round_down_seq_len]
 
+        # per sample grad function
+
+        per_sample_grad_fn = self.per_sample_grad_fn
+
         # weights of the memory network
 
         weights = TensorDict(weights)
 
+        # allow for neural memory of a previous layer and the past to produce gradients that become the weights of the current one generating the surprise
+        # think this is necessary otherwise the memory model is static (unless if paper is misunderstood)
+        # improvise (or perhaps correcting to) a solution
+
+        if exists(prev_layer_updates):
+            prev_layer_updates = TensorDict(weights)
+
+            weights = weights + prev_layer_updates
+
+            per_sample_grad_fn = self.per_sample_grad_fn_expanded_weights # the weights will now have a batch * chunk dimension
+
         # derive learned hparams for optimization of memory network
 
         adaptive_lr = self.to_adaptive_step(seq)
@@ -635,7 +659,7 @@ class NeuralMemory(Module):
 
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads, aux_kv_recon_loss = self.per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
+        grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
 
         grads = TensorDict(grads)
 
@@ -781,6 +805,7 @@ class NeuralMemory(Module):
 
         return values[:, :seq_len]
 
+    @torch.no_grad()
     def forward_inference(
         self,
         token: Tensor,
@@ -854,13 +879,18 @@ class NeuralMemory(Module):
         return_aux_kv_loss = False,
         chunk_size = None,
         store_chunk_size = None,
-        return_values = False
+        return_values = False,
+        return_next_state = False
     ):
         batch, seq_len = seq.shape[:2]
 
         if seq_len < self.retrieve_chunk_size:
             out = self.init_empty_memory_embed(batch, seq_len)
 
+            next_store_state = (seq_len, seq, None, None)
+
+            out = (out, next_store_state)
+
             if not return_aux_kv_loss:
                 return out
 
@@ -870,16 +900,31 @@
         mem_model_weights = self.init_weights()
 
         store_seq = default(store_seq, seq)
-        store_chunk_size = default(store_chunk_size, chunk_size)
+
+        store_seq_len = store_seq.shape[-2]
+        store_chunk_size = default(store_chunk_size, chunk_size, self.store_chunk_size)
+        remainder = store_seq_len % store_chunk_size
 
         (updates, next_state, values), aux_kv_recon_loss = self.store_memories(store_seq, mem_model_weights, chunk_size = store_chunk_size, return_aux_kv_loss = True)
 
         retrieved = self.retrieve_memories(seq, mem_model_weights + updates, chunk_size = chunk_size)
 
-        output = retrieved
+        # determine state for the storing of memories
+        # for transformer-xl like training with neural memory as well as inferencing with initial prompt
+
+        cache_store_seq = None
+
+        if remainder > 0:
+            cache_store_seq = store_seq[:, -remainder:]
+
+        updates = updates.apply(lambda t: t[:, -1:])
+
+        next_store_state = (seq_len, cache_store_seq, next_state, updates)
+
+        output = (retrieved, next_store_state)
 
         if return_values:
-            output = (retrieved, values)
+            output = (*output, values)
 
         if not return_aux_kv_loss:
             return output
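The `remainder` bookkeeping at the top of this hunk determines whether a partial trailing chunk of the store sequence gets carried forward as `cache_store_seq`. A small worked example of just that arithmetic (values illustrative, not library code):

```python
# arithmetic sketch of the new cache bookkeeping
store_seq_len = 1000
store_chunk_size = 64
remainder = store_seq_len % store_chunk_size   # 15 full chunks of 64 leave a remainder of 40

# with remainder > 0, the trailing 40 tokens (store_seq[:, -40:]) would be
# carried as cache_store_seq inside next_store_state for the next call
assert remainder == 40
```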