titans-pytorch 0.3.1.tar.gz → 0.3.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- titans_pytorch-0.3.1/PKG-INFO
+++ titans_pytorch-0.3.3/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.1
+Version: 0.3.3
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

--- titans_pytorch-0.3.1/pyproject.toml
+++ titans_pytorch-0.3.3/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.3.1"
+version = "0.3.3"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

--- titans_pytorch-0.3.1/tests/test_titans.py
+++ titans_pytorch-0.3.3/tests/test_titans.py
@@ -42,7 +42,7 @@ def test_titans(
     per_parameter_lr_modulation
 ):
     mem = NeuralMemory(
-        dim = 384,
+        dim = 16,
         chunk_size = chunk_size,
         activation = nn.SiLU() if silu else None,
         attn_pool_chunks = attn_pool_chunks,
@@ -52,7 +52,7 @@ def test_titans(
         per_parameter_lr_modulation = per_parameter_lr_modulation,
     )
 
-    seq = torch.randn(2, seq_len, 384)
+    seq = torch.randn(2, seq_len, 16)
     retrieved, _ = mem(seq)
 
     assert seq.shape == retrieved.shape
@@ -61,14 +61,14 @@ def test_titans_attn_memory():
     from titans_pytorch.memory_models import MemoryAttention
 
     mem = NeuralMemory(
-        dim = 384,
+        dim = 16,
         chunk_size = 64,
         model = MemoryAttention(
-            dim = 384
+            dim = 16
         )
     )
 
-    seq = torch.randn(2, 1024, 384)
+    seq = torch.randn(2, 1024, 16)
     retrieved, _ = mem(seq)
 
     assert seq.shape == retrieved.shape
@@ -78,14 +78,14 @@ def test_neural_mem_chaining_chunks(
     gated_transition
 ):
     mem = NeuralMemory(
-        dim = 384,
-        dim_head = 64,
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 16,
         gated_transition = gated_transition
     )
 
-    seq = torch.randn(2, 48, 384)
+    seq = torch.randn(2, 48, 16)
 
     parallel_retrieved, state = mem(seq)
 
@@ -99,21 +99,21 @@ def test_neural_mem_chaining_chunks(
 
 def test_neural_mem_chaining_with_weight_residual():
     mem = NeuralMemory(
-        dim = 384,
-        dim_head = 64,
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 64
     )
 
     mem2 = NeuralMemory(
-        dim = 384,
-        dim_head = 64,
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 64,
         accept_weight_residual = True
     )
 
-    seq = torch.randn(2, 256, 384)
+    seq = torch.randn(2, 256, 16)
 
     seq, state = mem(seq)
 
@@ -124,18 +124,18 @@ def test_neural_mem_chaining_with_weight_residual():
     first_retrieved, state1 = mem2(seq_first, prev_weights = state.updates)
     second_retrieved, state2 = mem2(seq_second, state = state1, prev_weights = state.updates)
 
-    assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved), dim = 1), atol = 1e-6)
+    assert torch.allclose(parallel_retrieved, torch.cat((first_retrieved, second_retrieved), dim = 1), atol = 1e-5)
 
 def test_neural_mem_chaining_with_batch_size():
     mem = NeuralMemory(
-        dim = 384,
-        dim_head = 64,
+        dim = 16,
+        dim_head = 16,
         heads = 2,
         chunk_size = 16,
         batch_size = 64
     )
 
-    seq = torch.randn(2, 112, 384)
+    seq = torch.randn(2, 112, 16)
 
     parallel_retrieved, state = mem(seq)
 
@@ -169,7 +169,7 @@ def test_mac(
 ):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
-        dim = 256,
+        dim = 16,
         depth = 2,
         num_persist_mem_tokens = num_persist_mem_tokens,
         num_longterm_mem_tokens = num_longterm_mem_tokens,
@@ -201,7 +201,7 @@ def test_mac_sampling(
 ):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
-        dim = 256,
+        dim = 16,
         depth = 4,
         segment_len = 32,
         num_persist_mem_tokens = 4,
@@ -235,12 +235,12 @@ def test_neural_mem_inference(
 ):
 
     mem = NeuralMemory(
-        dim = 384,
+        dim = 16,
         chunk_size = mem_chunk_size,
         gated_transition = gated_transition
     )
 
-    seq = torch.randn(2, seq_len, 384)
+    seq = torch.randn(2, seq_len, 16)
     parallel_retrieved, _ = mem(seq)
 
     assert seq.shape == parallel_retrieved.shape
@@ -282,7 +282,7 @@ def test_flex(
         pytest.skip()
 
     attn = SegmentedAttention(
-        dim = 512,
+        dim = 16,
         segment_len = 32,
         num_persist_mem_tokens = 1,
         num_longterm_mem_tokens = 1,
@@ -290,7 +290,7 @@ def test_flex(
         sliding = sliding
     ).cuda()
 
-    seq = torch.randn(1, seq_len, 512).cuda()
+    seq = torch.randn(1, seq_len, 16).cuda()
 
     out_flex, _ = attn(seq)
     out_non_flex, _ = attn(seq, disable_flex_attn = True)
@@ -307,8 +307,8 @@ def test_assoc_scan():
     seq_len = 128
     mid_point = seq_len // 2
 
-    gates = torch.randn(2, seq_len, 512).sigmoid()
-    inputs = torch.randn(2, seq_len, 512)
+    gates = torch.randn(2, seq_len, 16).sigmoid()
+    inputs = torch.randn(2, seq_len, 16)
 
     output = scan(gates, inputs)
 
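
Note on the test changes above: every model width is shrunk (384, 256, and 512 all become 16, with dim_head matched) and one allclose tolerance is loosened from 1e-6 to 1e-5, which reads as a CI speed and flakiness fix rather than a behavioral change. A minimal sketch of the pattern these tests exercise, assuming the package's top-level NeuralMemory export and an arbitrary chunk_size:

    import torch
    from titans_pytorch import NeuralMemory

    # width must match the trailing dimension of the input sequence
    mem = NeuralMemory(dim = 16, chunk_size = 16)

    seq = torch.randn(2, 64, 16)         # (batch, seq_len, dim)
    retrieved, state = mem(seq)

    assert retrieved.shape == seq.shape  # retrieval preserves shape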

--- titans_pytorch-0.3.1/titans_pytorch/memory_models.py
+++ titans_pytorch-0.3.3/titans_pytorch/memory_models.py
@@ -36,10 +36,14 @@ class MemoryMLP(Module):
     def __init__(
         self,
         dim,
-        depth
+        depth,
+        expansion_factor = 2.
     ):
         super().__init__()
-        self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])
+        dim_hidden = int(dim * expansion_factor)
+        dims = (dim, *((dim_hidden,) * (depth - 1)), dim)
+
+        self.weights = ParameterList([Parameter(torch.randn(dim_in, dim_out)) for dim_in, dim_out in zip(dims[:-1], dims[1:])])
 
         self.ln = LayerNorm(dim)
 
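
The MemoryMLP change above replaces depth identical (dim × dim) weight matrices with a genuine MLP: hidden layers widen to int(dim * expansion_factor), with the first and last matrices projecting dim → hidden and hidden → dim. A small worked example of the new shape logic (values illustrative only):

    # mirrors the dims construction in the diff above
    dim, depth, expansion_factor = 16, 3, 2.

    dim_hidden = int(dim * expansion_factor)             # 32
    dims = (dim, *((dim_hidden,) * (depth - 1)), dim)    # (16, 32, 32, 16)

    print(list(zip(dims[:-1], dims[1:])))                # [(16, 32), (32, 32), (32, 16)]

With depth = 1 this degenerates to a single (dim × dim) matrix, preserving the `kv` fast-weight special case noted in the comments below, and the NeuralMemory hunk that follows raises the default memory model's expansion_factor to 4.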

--- titans_pytorch-0.3.1/titans_pytorch/neural_memory.py
+++ titans_pytorch-0.3.3/titans_pytorch/neural_memory.py
@@ -299,7 +299,8 @@ class NeuralMemory(Module):
         accept_weight_residual = False,
         gated_transition = False,
         default_model_kwargs: dict = dict(
-            depth = 2
+            depth = 2,
+            expansion_factor = 4.
         )
     ):
         super().__init__()
@@ -689,16 +690,27 @@ class NeuralMemory(Module):
     def retrieve_memories(
         self,
         seq,
-        past_weights: dict[str, Tensor],
-        chunk_size = None,
-        need_pad = True
+        weights: dict[str, Tensor],
     ):
-        chunk_size = default(chunk_size, self.retrieve_chunk_size)
+        chunk_size = self.retrieve_chunk_size
+
+        weights_have_expanded_shape = dict_get_shape(weights) != self.init_weight_shape
+
         batch, seq_len = seq.shape[:2]
 
-        seq = self.retrieve_norm(seq)
+        # auto infer single token decoding, if there are only 1 set of weights and 1 token
+
+        is_one_token = seq_len == 1
+        is_one_weight = (not weights_have_expanded_shape) or next(iter(weights.values())).shape[1] == 1
+
+        is_single_token_decode = is_one_token and is_one_weight
+
+        if is_single_token_decode:
+            chunk_size = 1
+
+        # padding related, for chunked processing
 
-        need_pad = need_pad or chunk_size > 1
+        need_pad = chunk_size > 1 or not is_one_weight
 
         if need_pad:
             seq = pad_at_dim(seq, (1, 0), dim = 1)
@@ -713,7 +725,11 @@ class NeuralMemory(Module):
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
 
-        curr_weights = TensorDict(past_weights)
+        weights = TensorDict(weights)
+
+        # pre norm
+
+        seq = self.retrieve_norm(seq)
 
         # sequence Float['b n d'] to queries
 
@@ -729,14 +745,14 @@ class NeuralMemory(Module):
 
         # fetch values from memory model
 
-        if dict_get_shape(curr_weights) != self.init_weight_shape:
-            curr_weights = rearrange_dict_values(curr_weights, 'b n ... -> (b n) ...')
+        if weights_have_expanded_shape:
+            weights = rearrange_dict_values(weights, 'b n ... -> (b n) ...')
 
         queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)
 
         # forward functional call
 
-        values = functional_call(self.memory_model, dict(curr_weights), queries)
+        values = functional_call(self.memory_model, dict(weights), queries)
 
         # reconstitute batch dimension
 
@@ -884,22 +900,13 @@ class NeuralMemory(Module):
 
         # retrieve
 
-        need_pad = True
-        retrieve_chunk_size = None
-
         if is_single_token:
-            retrieve_chunk_size = 1
-            need_pad = False
-
             last_update, _ = next_neural_mem_state.states
-
             updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
 
         retrieved = self.retrieve_memories(
             seq,
-            updates,
-            chunk_size = retrieve_chunk_size,
-            need_pad = need_pad,
+            updates
        )
 
         return retrieved, next_neural_mem_state
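
The retrieve_memories refactor above moves decode-time bookkeeping into the method itself: it now detects single-token decoding (one token and one set of weights) and sets chunk_size = 1 internally, callers no longer pass chunk_size or need_pad, and the pre-norm is applied after padding instead of before. A hedged sketch of the token-by-token decode loop this simplifies, with the state keyword used as in the tests earlier in this diff (agreement with the parallel path is checked only by shape here):

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 16, chunk_size = 16)

    seq = torch.randn(2, 48, 16)
    parallel_retrieved, _ = mem(seq)      # whole-sequence (parallel) path

    state = None
    outputs = []

    for token in seq.unbind(dim = 1):
        one = token.unsqueeze(1)          # (b, 1, d) triggers the single-token decode path
        retrieved, state = mem(one, state = state)
        outputs.append(retrieved)

    sequential_retrieved = torch.cat(outputs, dim = 1)
    assert sequential_retrieved.shape == parallel_retrieved.shape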