titans-pytorch 0.1.22__tar.gz → 0.1.23__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.22
+Version: 0.1.23
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -78,7 +78,7 @@ from titans_pytorch import NeuralMemory
 
 mem = NeuralMemory(
     dim = 384,
-    chunk_size = 64
+    chunk_size = 64 # set to smaller chunk size for better perf on smaller sequence lengths (but more memory usage)
 ).cuda()
 
 seq = torch.randn(2, 1024, 384).cuda()
@@ -24,7 +24,7 @@ from titans_pytorch import NeuralMemory
 
 mem = NeuralMemory(
     dim = 384,
-    chunk_size = 64
+    chunk_size = 64 # set to smaller chunk size for better perf on smaller sequence lengths (but more memory usage)
 ).cuda()
 
 seq = torch.randn(2, 1024, 384).cuda()
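
The comment added to the README snippets above documents a trade-off in the chunked memory: a smaller chunk_size commits tokens to the neural memory more often, which helps on shorter sequences but costs more memory. A minimal sketch of how that knob might be exercised, using only the NeuralMemory constructor arguments and forward call shown in the snippet (the chunk size values below are illustrative assumptions, not recommendations):

import torch
from titans_pytorch import NeuralMemory

short_seq = torch.randn(2, 128, 384)

# smaller chunks -> memory is updated more frequently within the short sequence
mem_small_chunks = NeuralMemory(dim = 384, chunk_size = 16)

# larger chunks -> fewer memory updates, lower memory usage
mem_large_chunks = NeuralMemory(dim = 384, chunk_size = 64)

retrieved_small = mem_small_chunks(short_seq)
retrieved_large = mem_large_chunks(short_seq)

assert retrieved_small.shape == short_seq.shape == retrieved_large.shape
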
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.22"
+version = "0.1.23"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -8,6 +8,9 @@ from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention, M
 def exists(v):
     return v is not None
 
+def diff(x, y):
+    return (x - y).abs().amax()
+
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
 @pytest.mark.parametrize('silu', (False, True))
 @pytest.mark.parametrize('learned_mem_model_weights', (False, True))
@@ -133,6 +136,39 @@ def test_mac_sampling(sliding):
 
     assert torch.allclose(sampled, sampled_with_cache)
 
+@pytest.mark.parametrize('seq_len', (2, 64))
+def test_neural_mem_inference(
+    seq_len
+):
+    mem = NeuralMemory(
+        dim = 384,
+        chunk_size = 64,
+    )
+
+    seq = torch.randn(2, seq_len, 384)
+    parallel_retrieved = mem(seq)
+
+    assert seq.shape == parallel_retrieved.shape
+
+    mem_model_state = None
+    cache_store_seq = None
+    sequential_retrieved = []
+
+    for ind, token in enumerate(seq.unbind(dim = 1)):
+
+        one_retrieved, cache_store_seq, mem_model_state = mem.forward_inference(
+            token,
+            seq_index = ind,
+            cache_store_seq = cache_store_seq,
+            mem_model_state = mem_model_state
+        )
+
+        sequential_retrieved.append(one_retrieved)
+
+    sequential_retrieved = torch.cat(sequential_retrieved, dim = -2)
+
+    assert torch.allclose(parallel_retrieved, sequential_retrieved, atol = 1e-5)
+
 @pytest.mark.parametrize('seq_len', (1023, 17))
 @pytest.mark.parametrize('sliding', (True, False))
 def test_flex(
@@ -157,3 +193,23 @@ def test_flex(
     out_non_flex, _ = attn(seq, disable_flex_attn = True)
 
     assert torch.allclose(out_flex, out_non_flex, atol = 1e-5)
+
+def test_assoc_scan():
+    from titans_pytorch.titans import AssocScan
+    import torch.nn.functional as F
+
+    scan = AssocScan()
+
+    gates = torch.randn(2, 1024, 512).sigmoid()
+    inputs = torch.randn(2, 1024, 512)
+
+    output = scan(gates, inputs)
+
+    gates1, gates2 = gates[:, :512], gates[:, 512:]
+    inputs1, inputs2 = inputs[:, :512], inputs[:, 512:]
+
+    first_half = scan(gates1, inputs1)
+
+    second_half = scan(gates2, inputs2, prev = inputs2[:, -1])
+
+    assert torch.allclose(output[:, -1], second_half[:, -1], atol = 1e-5)
@@ -44,6 +44,16 @@ def default(v, d):
 def xnor(x, y):
     return not (x ^ y)
 
+def safe_cat(inputs, dim = -2):
+    inputs = tuple(filter(exists, inputs))
+
+    if len(inputs) == 0:
+        return None
+    elif len(inputs) == 1:
+        return inputs[0]
+
+    return cat(inputs, dim = dim)
+
 def identity(t):
     return t
 
@@ -314,7 +324,11 @@ class AssocScan(Module):
         super().__init__()
         self.use_accelerated = use_accelerated
 
-    def forward(self, gates, inputs):
+    def forward(self, gates, inputs, prev = None):
+
+        if exists(prev):
+            inputs, _ = pack([prev, inputs], 'b * d')
+            gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
 
         if not self.use_accelerated:
             _, outputs = associative_scan(binary_operator, (gates, inputs))
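
The new prev argument to AssocScan.forward lets a scan resume from a carried state: prev is packed in front of inputs and its gate position is padded with a value of 1, so the carried value is passed through unchanged before the remaining positions are scanned. As a rough reference, assuming binary_operator implements the usual first-order recurrence out_t = gate_t * out_(t-1) + input_t, a naive loop equivalent would look like this (a sketch only, not the package's implementation):

import torch

def reference_scan(gates, inputs, prev = None):
    # out_0 defaults to zero, which is equivalent to a scan with no carried state
    out = prev if prev is not None else torch.zeros_like(inputs[:, 0])
    outputs = []

    for t in range(inputs.shape[1]):
        out = gates[:, t] * out + inputs[:, t]
        outputs.append(out)

    return torch.stack(outputs, dim = 1)

Splitting a sequence and resuming the second segment from a carried prev state is the continuation pattern the new test_assoc_scan above exercises.
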
@@ -678,7 +692,7 @@ class NeuralMemory(Module):
     def retrieve_memories(
         self,
         seq,
-        past_weights: dict[str, Tensor] | None = None,
+        past_weights: dict[str, Tensor],
         chunk_size = None
     ):
         chunk_size = default(chunk_size, self.retrieve_chunk_size)
@@ -700,13 +714,7 @@ class NeuralMemory(Module):
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
 
-        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
-
-        if exists(past_weights):
-            past_weights = TensorDict(past_weights)
-            assert past_weights.keys() == curr_weights.keys()
-
-            curr_weights = curr_weights + past_weights
+        curr_weights = TensorDict(past_weights)
 
         # sequence Float['b n d'] to queries
 
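
The comment retained in this hunk points at the linear attention view of the memory: when the memory model is a single weight matrix accumulated from key/value outer products, retrieval is just q @ (kv), which by associativity equals attending over all stored keys and values. A small illustrative sketch of that identity (standalone toy code, not taken from the package; names and shapes are assumptions for the example):

import torch

d = 8
keys   = torch.randn(16, d)   # 16 stored tokens
values = torch.randn(16, d)
query  = torch.randn(d)

# fast weight memory: accumulate key/value outer products into a single matrix
kv = keys.t() @ values        # (d, d)

# fetching from the fast weight memory vs. unnormalized linear attention
from_fast_weights = query @ kv
from_linear_attn  = (query @ keys.t()) @ values

assert torch.allclose(from_fast_weights, from_linear_attn, atol = 1e-5)
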
 
@@ -753,6 +761,56 @@ class NeuralMemory(Module):
 
         return values[:, :seq_len]
 
+    def forward_inference(
+        self,
+        token: Tensor,
+        seq_index = None, # the index of the token in the sequence, starts at 0
+        mem_model_state = None,
+        cache_store_seq = None
+    ):
+        seq_index = default(seq_index, 0)
+        curr_seq_len = seq_index + 1
+        batch = token.shape[0]
+
+        if token.ndim == 2:
+            token = rearrange(token, 'b d -> b 1 d')
+
+        # init memory model if needed
+
+        if not exists(mem_model_state):
+            mem_model_state = self.init_weights_and_momentum()
+
+        # increment the sequence cache which is at most the chunk size
+
+        cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)
+
+        # early return empty memory, when no memories are stored for steps < first chunk size
+
+        if curr_seq_len < self.chunk_size:
+            empty_mem = self.init_empty_memory_embed(batch, 1)
+
+            return empty_mem, cache_store_seq, mem_model_state
+
+        # store if storage sequence cache hits the chunk size
+
+        store_seq_cache_len = cache_store_seq.shape[-2]
+
+        if store_seq_cache_len == self.chunk_size:
+            updates, _ = self.store_memories(cache_store_seq, mem_model_state)
+
+            past_weights, past_momentum = mem_model_state
+            mem_model_state = (past_weights + updates, past_momentum)
+
+            cache_store_seq = None
+
+        # retrieve
+
+        past_weights, _ = mem_model_state
+
+        retrieved = self.retrieve_memories(token, past_weights, chunk_size = 1)
+
+        return retrieved, cache_store_seq, mem_model_state
+
     def forward(
         self,
         seq,
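
The new forward_inference method above processes one token at a time: tokens accumulate in cache_store_seq, a full chunk is committed into the memory weights via store_memories, and retrieval runs against the current weights with chunk_size = 1 (returning the empty memory embedding until the first chunk has been seen). A hedged usage sketch that follows the same call pattern as the test_neural_mem_inference test added earlier in this diff (batch size, dimensions, and step count are arbitrary):

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)

mem_model_state = None
cache_store_seq = None

# feed tokens one at a time, e.g. while decoding
for ind in range(256):
    token = torch.randn(2, 384)   # (batch, dim) for a single step

    retrieved, cache_store_seq, mem_model_state = mem.forward_inference(
        token,
        seq_index = ind,
        cache_store_seq = cache_store_seq,
        mem_model_state = mem_model_state
    )

    # retrieved has shape (batch, 1, dim)
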