titans-pytorch 0.1.21__tar.gz → 0.1.23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.21
+Version: 0.1.23
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -78,8 +78,7 @@ from titans_pytorch import NeuralMemory
 
 mem = NeuralMemory(
     dim = 384,
-    chunk_size = 64,
-    pre_rmsnorm = True
+    chunk_size = 64 # set to smaller chunk size for better perf on smaller sequence lengths (but more memory usage)
 ).cuda()
 
 seq = torch.randn(2, 1024, 384).cuda()
@@ -24,8 +24,7 @@ from titans_pytorch import NeuralMemory
 
 mem = NeuralMemory(
     dim = 384,
-    chunk_size = 64,
-    pre_rmsnorm = True
+    chunk_size = 64 # set to smaller chunk size for better perf on smaller sequence lengths (but more memory usage)
 ).cuda()
 
 seq = torch.randn(2, 1024, 384).cuda()
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.21"
+version = "0.1.23"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -8,6 +8,9 @@ from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention, M
 def exists(v):
     return v is not None
 
+def diff(x, y):
+    return (x - y).abs().amax()
+
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
 @pytest.mark.parametrize('silu', (False, True))
 @pytest.mark.parametrize('learned_mem_model_weights', (False, True))
@@ -133,6 +136,39 @@ def test_mac_sampling(sliding):
 
     assert torch.allclose(sampled, sampled_with_cache)
 
+@pytest.mark.parametrize('seq_len', (2, 64))
+def test_neural_mem_inference(
+    seq_len
+):
+    mem = NeuralMemory(
+        dim = 384,
+        chunk_size = 64,
+    )
+
+    seq = torch.randn(2, seq_len, 384)
+    parallel_retrieved = mem(seq)
+
+    assert seq.shape == parallel_retrieved.shape
+
+    mem_model_state = None
+    cache_store_seq = None
+    sequential_retrieved = []
+
+    for ind, token in enumerate(seq.unbind(dim = 1)):
+
+        one_retrieved, cache_store_seq, mem_model_state = mem.forward_inference(
+            token,
+            seq_index = ind,
+            cache_store_seq = cache_store_seq,
+            mem_model_state = mem_model_state
+        )
+
+        sequential_retrieved.append(one_retrieved)
+
+    sequential_retrieved = torch.cat(sequential_retrieved, dim = -2)
+
+    assert torch.allclose(parallel_retrieved, sequential_retrieved, atol = 1e-5)
+
 @pytest.mark.parametrize('seq_len', (1023, 17))
 @pytest.mark.parametrize('sliding', (True, False))
 def test_flex(
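
The test above also doubles as a usage reference for the new NeuralMemory.forward_inference path. A minimal decoding-style sketch mirroring the test follows; the loop, tensor shapes, and variable names are illustrative, not part of the package:

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)

cache_store_seq = None   # tokens not yet written to memory (at most one chunk)
mem_model_state = None   # (weights, momentum) state of the memory model

for ind in range(256):
    token = torch.randn(2, 384)  # one token per step: (batch, dim)

    retrieved, cache_store_seq, mem_model_state = mem.forward_inference(
        token,
        seq_index = ind,
        cache_store_seq = cache_store_seq,
        mem_model_state = mem_model_state
    )

    # per the test, `retrieved` is (batch, 1, dim) and matches the parallel mem(seq) output within a small tolerance
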
@@ -157,3 +193,23 @@ def test_flex(
     out_non_flex, _ = attn(seq, disable_flex_attn = True)
 
     assert torch.allclose(out_flex, out_non_flex, atol = 1e-5)
+
+def test_assoc_scan():
+    from titans_pytorch.titans import AssocScan
+    import torch.nn.functional as F
+
+    scan = AssocScan()
+
+    gates = torch.randn(2, 1024, 512).sigmoid()
+    inputs = torch.randn(2, 1024, 512)
+
+    output = scan(gates, inputs)
+
+    gates1, gates2 = gates[:, :512], gates[:, 512:]
+    inputs1, inputs2 = inputs[:, :512], inputs[:, 512:]
+
+    first_half = scan(gates1, inputs1)
+
+    second_half = scan(gates2, inputs2, prev = first_half[:, -1])
+
+    assert torch.allclose(output[:, -1], second_half[:, -1], atol = 1e-5)
@@ -41,6 +41,19 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def xnor(x, y):
+    return not (x ^ y)
+
+def safe_cat(inputs, dim = -2):
+    inputs = tuple(filter(exists, inputs))
+
+    if len(inputs) == 0:
+        return None
+    elif len(inputs) == 1:
+        return inputs[0]
+
+    return cat(inputs, dim = dim)
+
 def identity(t):
     return t
 
@@ -311,7 +324,11 @@ class AssocScan(Module):
         super().__init__()
         self.use_accelerated = use_accelerated
 
-    def forward(self, gates, inputs):
+    def forward(self, gates, inputs, prev = None):
+
+        if exists(prev):
+            inputs, _ = pack([prev, inputs], 'b * d')
+            gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
 
         if not self.use_accelerated:
             _, outputs = associative_scan(binary_operator, (gates, inputs))
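
For context on the new `prev` argument: packing `prev` in front of `inputs` and padding `gates` with a 1 seeds the scan with a previous state, so a long scan can be split into chunks. A naive reference loop illustrating the intended recurrence, assuming `binary_operator` realizes the usual first-order update h_t = gate_t * h_{t-1} + input_t (this helper is only an illustration, not package code):

import torch

def reference_scan(gates, inputs, prev = None):
    # step-by-step recurrence: h_t = gates_t * h_{t-1} + inputs_t
    h = prev if prev is not None else torch.zeros_like(inputs[:, 0])
    outs = []
    for g, x in zip(gates.unbind(dim = 1), inputs.unbind(dim = 1)):
        h = g * h + x
        outs.append(h)
    return torch.stack(outs, dim = 1)

gates = torch.randn(2, 8, 4).sigmoid()
inputs = torch.randn(2, 8, 4)

full = reference_scan(gates, inputs)
first = reference_scan(gates[:, :4], inputs[:, :4])
second = reference_scan(gates[:, 4:], inputs[:, 4:], prev = first[:, -1])

assert torch.allclose(full[:, -1], second[:, -1])  # seeding with the first half's last state continues the scan
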
@@ -366,6 +383,7 @@ class NeuralMemory(Module):
         pre_rmsnorm = True,
         post_rmsnorm = True,
         qk_rmsnorm = False,
+        accept_value_residual = False,
         learned_mem_model_weights = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
@@ -399,7 +417,7 @@ class NeuralMemory(Module):
 
         self.heads = heads
 
-        self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
+        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
@@ -448,6 +466,14 @@ class NeuralMemory(Module):
         self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
         self.store_memory_loss_fn = store_memory_loss_fn
 
+        # value residual learning
+
+        self.learned_value_residual = Sequential(
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if accept_value_residual else None
+
         # empty memory embed
 
         self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
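
The `learned_value_residual` module added above produces a per-head sigmoid gate from the input sequence; in `store_memories` (further down in this diff) that gate interpolates the freshly projected values toward an incoming `value_residual` via `values.lerp(value_residual, mix)`. A standalone sketch of just that mixing step, with illustrative names and shapes rather than the package API:

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

b, n, dim, heads, dim_head = 2, 64, 384, 4, 96

to_mix = nn.Sequential(
    nn.Linear(dim, heads, bias = False),
    Rearrange('b n h -> b h n 1'),
    nn.Sigmoid()
)

seq = torch.randn(b, n, dim)
values = torch.randn(b, heads, n, dim_head)          # values projected by the current layer
value_residual = torch.randn(b, heads, n, dim_head)  # values handed down from an earlier layer

mix = to_mix(seq)                          # (b, heads, n, 1), gate in [0, 1]
values = values.lerp(value_residual, mix)  # 0 keeps the new values, 1 takes the residual
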
@@ -529,8 +555,11 @@ class NeuralMemory(Module):
         seq,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
         return_aux_kv_loss = False,
-        chunk_size = None
+        chunk_size = None,
+        value_residual = None
     ):
+        assert xnor(exists(value_residual), exists(self.learned_value_residual))
+
         seq_len, chunk_size = seq.shape[-2], default(chunk_size, self.store_chunk_size)
 
         # handle edge case
@@ -585,9 +614,17 @@ class NeuralMemory(Module):
 
         keys = self.k_norm(keys)
 
+        # maybe value residual learning
+
+        orig_values = values
+
+        if exists(self.learned_value_residual):
+            mix = self.learned_value_residual(seq)
+            values = values.lerp(value_residual, mix)
+
         # take care of chunking
 
-        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = chunk_size) for t in (keys, values))
+        keys, values = tuple(rearrange(t, 'b h (n c) d -> (b h n) c d', c = chunk_size) for t in (keys, values))
 
         adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
 
@@ -645,15 +682,17 @@ class NeuralMemory(Module):
 
         last_update = updates.apply(lambda t: t[:, -1])
 
+        output = (updates, orig_values)
+
         if not return_aux_kv_loss:
-            return updates
+            return output
 
-        return updates, aux_kv_recon_loss.mean()
+        return output, aux_kv_recon_loss.mean()
 
     def retrieve_memories(
         self,
         seq,
-        past_weights: dict[str, Tensor] | None = None,
+        past_weights: dict[str, Tensor],
         chunk_size = None
     ):
         chunk_size = default(chunk_size, self.retrieve_chunk_size)
@@ -675,13 +714,7 @@ class NeuralMemory(Module):
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
 
-        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
-
-        if exists(past_weights):
-            past_weights = TensorDict(past_weights)
-            assert past_weights.keys() == curr_weights.keys()
-
-            curr_weights = curr_weights + past_weights
+        curr_weights = TensorDict(past_weights)
 
         # sequence Float['b n d'] to queries
 
@@ -698,7 +731,7 @@ class NeuralMemory(Module):
         # fetch values from memory model
 
         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
-        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
+        queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)
 
         # forward functional call
 
@@ -728,6 +761,56 @@ class NeuralMemory(Module):
 
         return values[:, :seq_len]
 
+    def forward_inference(
+        self,
+        token: Tensor,
+        seq_index = None, # the index of the token in the sequence, starts at 0
+        mem_model_state = None,
+        cache_store_seq = None
+    ):
+        seq_index = default(seq_index, 0)
+        curr_seq_len = seq_index + 1
+        batch = token.shape[0]
+
+        if token.ndim == 2:
+            token = rearrange(token, 'b d -> b 1 d')
+
+        # init memory model if needed
+
+        if not exists(mem_model_state):
+            mem_model_state = self.init_weights_and_momentum()
+
+        # increment the sequence cache which is at most the chunk size
+
+        cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)
+
+        # early return empty memory, when no memories are stored for steps < first chunk size
+
+        if curr_seq_len < self.chunk_size:
+            empty_mem = self.init_empty_memory_embed(batch, 1)
+
+            return empty_mem, cache_store_seq, mem_model_state
+
+        # store if storage sequence cache hits the chunk size
+
+        store_seq_cache_len = cache_store_seq.shape[-2]
+
+        if store_seq_cache_len == self.chunk_size:
+            updates, _ = self.store_memories(cache_store_seq, mem_model_state)
+
+            past_weights, past_momentum = mem_model_state
+            mem_model_state = (past_weights + updates, past_momentum)
+
+            cache_store_seq = None
+
+        # retrieve
+
+        past_weights, _ = mem_model_state
+
+        retrieved = self.retrieve_memories(token, past_weights, chunk_size = 1)
+
+        return retrieved, cache_store_seq, mem_model_state
+
     def forward(
         self,
         seq,
@@ -735,7 +818,8 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
         return_aux_kv_loss = False,
         chunk_size = None,
-        store_chunk_size = None
+        store_chunk_size = None,
+        return_values = False
     ):
         batch, seq_len = seq.shape[:2]
 
@@ -756,13 +840,18 @@ class NeuralMemory(Module):
         store_seq = default(store_seq, seq)
         store_chunk_size = default(store_chunk_size, chunk_size)
 
-        updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)
+        (updates, values), aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)
 
         past_weights, _ = past_state
 
         retrieved = self.retrieve_memories(seq, past_weights + updates, chunk_size = chunk_size)
 
+        output = retrieved
+
+        if return_values:
+            output = (retrieved, values)
+
         if not return_aux_kv_loss:
-            return retrieved
+            return output
 
-        return retrieved, aux_kv_recon_loss
+        return output, aux_kv_recon_loss
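
With `return_values = True`, the forward pass also hands back the pre-mix head values it stored, which is presumably what a downstream memory built with `accept_value_residual = True` would consume as `value_residual` in `store_memories`. A minimal sketch of the new return signature (illustrative only):

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)
seq = torch.randn(2, 1024, 384)

retrieved, values = mem(seq, return_values = True)  # `values` are the per-head values before any residual mixing
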