titans-pytorch 0.3.2.tar.gz → 0.3.3.tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in the supported public registries. It is provided for informational purposes only.
--- titans_pytorch-0.3.2/PKG-INFO
+++ titans_pytorch-0.3.3/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.2
+Version: 0.3.3
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
--- titans_pytorch-0.3.2/pyproject.toml
+++ titans_pytorch-0.3.3/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.3.2"
+version = "0.3.3"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
--- titans_pytorch-0.3.2/tests/test_titans.py
+++ titans_pytorch-0.3.3/tests/test_titans.py
@@ -42,7 +42,7 @@ def test_titans(
     per_parameter_lr_modulation
 ):
     mem = NeuralMemory(
-        dim = 384,
+        dim = 16,
         chunk_size = chunk_size,
         activation = nn.SiLU() if silu else None,
         attn_pool_chunks = attn_pool_chunks,
@@ -52,7 +52,7 @@ def test_titans(
         per_parameter_lr_modulation = per_parameter_lr_modulation,
     )
 
-    seq = torch.randn(2, seq_len, 384)
+    seq = torch.randn(2, seq_len, 16)
     retrieved, _ = mem(seq)
 
     assert seq.shape == retrieved.shape
@@ -61,14 +61,14 @@ def test_titans_attn_memory():
     from titans_pytorch.memory_models import MemoryAttention
 
     mem = NeuralMemory(
-        dim = 384,
+        dim = 16,
         chunk_size = 64,
         model = MemoryAttention(
-            dim = 384
+            dim = 16
         )
     )
 
-    seq = torch.randn(2, 1024, 384)
+    seq = torch.randn(2, 1024, 16)
     retrieved, _ = mem(seq)
 
     assert seq.shape == retrieved.shape
@@ -78,14 +78,14 @@ def test_neural_mem_chaining_chunks(
     gated_transition
 ):
     mem = NeuralMemory(
-        dim = 384,
-        dim_head = 64,
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 16,
         gated_transition = gated_transition
     )
 
-    seq = torch.randn(2, 48, 384)
+    seq = torch.randn(2, 48, 16)
 
     parallel_retrieved, state = mem(seq)
 
@@ -99,21 +99,21 @@ def test_neural_mem_chaining_chunks(
 
 def test_neural_mem_chaining_with_weight_residual():
     mem = NeuralMemory(
-        dim = 384,
-        dim_head = 64,
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 64
     )
 
     mem2 = NeuralMemory(
-        dim = 384,
-        dim_head = 64,
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 64,
         accept_weight_residual = True
     )
 
-    seq = torch.randn(2, 256, 384)
+    seq = torch.randn(2, 256, 16)
 
     seq, state = mem(seq)
 
@@ -124,18 +124,18 @@ def test_neural_mem_chaining_with_weight_residual():
     first_retrieved, state1 = mem2(seq_first, prev_weights = state.updates)
     second_retrieved, state2 = mem2(seq_second, state = state1, prev_weights = state.updates)
 
-    assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved), dim = 1), atol = 1e-6)
+    assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved), dim = 1), atol = 1e-5)
 
 def test_neural_mem_chaining_with_batch_size():
     mem = NeuralMemory(
-        dim = 384,
-        dim_head = 64,
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 16,
         batch_size = 64
     )
 
-    seq = torch.randn(2, 112, 384)
+    seq = torch.randn(2, 112, 16)
 
     parallel_retrieved, state = mem(seq)
 
@@ -169,7 +169,7 @@ def test_mac(
 ):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
-        dim = 256,
+        dim = 16,
         depth = 2,
         num_persist_mem_tokens = num_persist_mem_tokens,
         num_longterm_mem_tokens = num_longterm_mem_tokens,
@@ -201,7 +201,7 @@ def test_mac_sampling(
 ):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
-        dim = 256,
+        dim = 16,
         depth = 4,
         segment_len = 32,
         num_persist_mem_tokens = 4,
@@ -235,12 +235,12 @@ def test_neural_mem_inference(
 ):
 
     mem = NeuralMemory(
-        dim = 384,
+        dim = 16,
         chunk_size = mem_chunk_size,
         gated_transition = gated_transition
     )
 
-    seq = torch.randn(2, seq_len, 384)
+    seq = torch.randn(2, seq_len, 16)
     parallel_retrieved, _ = mem(seq)
 
     assert seq.shape == parallel_retrieved.shape
@@ -282,7 +282,7 @@ def test_flex(
         pytest.skip()
 
     attn = SegmentedAttention(
-        dim = 512,
+        dim = 16,
         segment_len = 32,
         num_persist_mem_tokens = 1,
         num_longterm_mem_tokens = 1,
@@ -290,7 +290,7 @@ def test_flex(
         sliding = sliding
     ).cuda()
 
-    seq = torch.randn(1, seq_len, 512).cuda()
+    seq = torch.randn(1, seq_len, 16).cuda()
 
     out_flex, _ = attn(seq)
     out_non_flex, _ = attn(seq, disable_flex_attn = True)
@@ -307,8 +307,8 @@ def test_assoc_scan():
     seq_len = 128
     mid_point = seq_len // 2
 
-    gates = torch.randn(2, seq_len, 512).sigmoid()
-    inputs = torch.randn(2, seq_len, 512)
+    gates = torch.randn(2, seq_len, 16).sigmoid()
+    inputs = torch.randn(2, seq_len, 16)
 
     output = scan(gates, inputs)
 
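Note: `scan` in the hunk above is the package's associative scan. For reference, a naive sequential sketch, assuming it computes the usual first-order gated recurrence (that assumption and the helper name are mine, not from the diff):

    import torch

    def naive_scan(gates, inputs):
        # out_t = gates_t * out_{t-1} + inputs_t, computed token by token
        out = torch.zeros_like(inputs[:, 0])
        outs = []
        for t in range(inputs.shape[1]):
            out = gates[:, t] * out + inputs[:, t]
            outs.append(out)
        return torch.stack(outs, dim = 1)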
--- titans_pytorch-0.3.2/titans_pytorch/memory_models.py
+++ titans_pytorch-0.3.3/titans_pytorch/memory_models.py
@@ -37,7 +37,7 @@ class MemoryMLP(Module):
         self,
         dim,
         depth,
-        expansion_factor = 4.
+        expansion_factor = 2.
     ):
         super().__init__()
         dim_hidden = int(dim * expansion_factor)
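Note: since the hidden width is derived as `dim_hidden = int(dim * expansion_factor)` (last context line above), the new default halves `MemoryMLP`'s hidden dimension. A quick check of the arithmetic:

    dim = 384
    int(dim * 4.)   # 1536, old default
    int(dim * 2.)   # 768, new default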
--- titans_pytorch-0.3.2/titans_pytorch/neural_memory.py
+++ titans_pytorch-0.3.3/titans_pytorch/neural_memory.py
@@ -690,16 +690,27 @@ class NeuralMemory(Module):
     def retrieve_memories(
         self,
         seq,
-        past_weights: dict[str, Tensor],
-        chunk_size = None,
-        need_pad = True
+        weights: dict[str, Tensor],
     ):
-        chunk_size = default(chunk_size, self.retrieve_chunk_size)
+        chunk_size = self.retrieve_chunk_size
+
+        weights_have_expanded_shape = dict_get_shape(weights) != self.init_weight_shape
+
         batch, seq_len = seq.shape[:2]
 
-        seq = self.retrieve_norm(seq)
+        # auto infer single token decoding, if there are only 1 set of weights and 1 token
+
+        is_one_token = seq_len == 1
+        is_one_weight = (not weights_have_expanded_shape) or next(iter(weights.values())).shape[1] == 1
+
+        is_single_token_decode = is_one_token and is_one_weight
+
+        if is_single_token_decode:
+            chunk_size = 1
+
+        # padding related, for chunked processing
 
-        need_pad = need_pad or chunk_size > 1
+        need_pad = chunk_size > 1 or not is_one_weight
 
         if need_pad:
             seq = pad_at_dim(seq, (1, 0), dim = 1)
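Note: the two removed keyword arguments (`chunk_size`, `need_pad`) are now inferred inside `retrieve_memories`. A standalone sketch of the new decision logic (a hypothetical helper of my own, condensing the one-weight check into a single count):

    def infer_retrieve_settings(seq_len, num_weight_sets, retrieve_chunk_size):
        # single token decoding: one token and one set of weights
        is_one_token = seq_len == 1
        is_one_weight = num_weight_sets == 1
        chunk_size = 1 if (is_one_token and is_one_weight) else retrieve_chunk_size
        # left-pad only when processing in chunks or with per-chunk weights
        need_pad = chunk_size > 1 or not is_one_weight
        return chunk_size, need_pad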
@@ -714,7 +725,11 @@
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
 
-        curr_weights = TensorDict(past_weights)
+        weights = TensorDict(weights)
+
+        # pre norm
+
+        seq = self.retrieve_norm(seq)
 
         # sequence Float['b n d'] to queries
 
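Note: the fast-weight analogy in the comment above (fetching memories is `q @ (kv)`) can be checked in a few lines; an illustrative sketch of my own, not package code:

    import torch
    import torch.nn.functional as F

    d = 16
    k, v = torch.randn(8, d), torch.randn(8, d)
    q = torch.randn(3, d)

    kv = k.t() @ v   # memories stored as a single d x d fast weight matrix
    # fetching is equivalent to one linear layer whose weight is the stored kv
    assert torch.allclose(q @ kv, F.linear(q, kv.t()), atol = 1e-6)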
@@ -730,14 +745,14 @@
 
         # fetch values from memory model
 
-        if dict_get_shape(curr_weights) != self.init_weight_shape:
-            curr_weights = rearrange_dict_values(curr_weights, 'b n ... -> (b n) ...')
+        if weights_have_expanded_shape:
+            weights = rearrange_dict_values(weights, 'b n ... -> (b n) ...')
 
         queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)
 
         # forward functional call
 
-        values = functional_call(self.memory_model, dict(curr_weights), queries)
+        values = functional_call(self.memory_model, dict(weights), queries)
 
         # reconstitute batch dimension
 
@@ -885,22 +900,13 @@
 
         # retrieve
 
-        need_pad = True
-        retrieve_chunk_size = None
-
         if is_single_token:
-            retrieve_chunk_size = 1
-            need_pad = False
-
             last_update, _ = next_neural_mem_state.states
-
             updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
 
         retrieved = self.retrieve_memories(
             seq,
-            updates,
-            chunk_size = retrieve_chunk_size,
-            need_pad = need_pad,
+            updates
         )
 
         return retrieved, next_neural_mem_state
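Note: the net effect of these hunks is that callers no longer pass retrieval hints; single-token decoding is detected inside `retrieve_memories`. A minimal usage sketch, assuming the constructor arguments and `state` keyword exercised by the tests above (the import path is inferred from the package name):

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 16, chunk_size = 16)

    seq = torch.randn(2, 48, 16)
    retrieved, state = mem(seq)                       # parallel retrieval over a sequence

    token = torch.randn(2, 1, 16)
    retrieved_one, state = mem(token, state = state)  # single token decode, chunk size inferred internally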