titans-pytorch 0.3.19__tar.gz → 0.3.20__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.3.19
3
+ Version: 0.3.20
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "titans-pytorch"
3
- version = "0.3.19"
3
+ version = "0.3.20"
4
4
  description = "Titans"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -200,7 +200,7 @@ def test_neural_mem_chaining_with_batch_size():
200
200
  @pytest.mark.parametrize('neural_mem_segment_len', (8, 16))
201
201
  @pytest.mark.parametrize('neural_mem_weight_residual', (False, True))
202
202
  @pytest.mark.parametrize('neural_mem_batch_size', (None, 64))
203
- @pytest.mark.parametrize('neural_mem_kv_receives_diff_views', (False, True))
203
+ @pytest.mark.parametrize('neural_mem_qkv_receives_diff_views', (False, True))
204
204
  @pytest.mark.parametrize('neural_mem_momentum', (False, True))
205
205
  def test_mac(
206
206
  seq_len,
@@ -210,7 +210,7 @@ def test_mac(
210
210
  neural_mem_segment_len,
211
211
  neural_mem_weight_residual,
212
212
  neural_mem_batch_size,
213
- neural_mem_kv_receives_diff_views,
213
+ neural_mem_qkv_receives_diff_views,
214
214
  neural_mem_momentum
215
215
  ):
216
216
  transformer = MemoryAsContextTransformer(
@@ -223,7 +223,7 @@ def test_mac(
223
223
  neural_mem_gate_attn_output = neural_mem_gate_attn_output,
224
224
  neural_memory_segment_len = neural_mem_segment_len,
225
225
  neural_memory_batch_size = neural_mem_batch_size,
226
- neural_memory_kv_receives_diff_views = neural_mem_kv_receives_diff_views,
226
+ neural_memory_qkv_receives_diff_views = neural_mem_qkv_receives_diff_views,
227
227
  neural_mem_weight_residual = neural_mem_weight_residual,
228
228
  neural_memory_kwargs = dict(
229
229
  momentum = neural_mem_momentum
@@ -483,7 +483,7 @@ class MemoryAsContextTransformer(Module):
483
483
  num_longterm_mem_tokens = 0,
484
484
  num_persist_mem_tokens = 0,
485
485
  neural_memory_batch_size = None,
486
- neural_memory_kv_receives_diff_views = False,
486
+ neural_memory_qkv_receives_diff_views = False,
487
487
  dim_head = 64,
488
488
  heads = 8,
489
489
  ff_mult = 4,
@@ -561,14 +561,14 @@ class MemoryAsContextTransformer(Module):
561
561
  mem_hyper_conn = None
562
562
 
563
563
  if layer in neural_memory_layers:
564
- mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output, num_input_views = 2 if neural_memory_kv_receives_diff_views else 1)
564
+ mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output, num_input_views = 3 if neural_memory_qkv_receives_diff_views else 1)
565
565
 
566
566
  mem = NeuralMemory(
567
567
  dim = dim,
568
568
  chunk_size = self.neural_memory_segment_len,
569
569
  batch_size = neural_memory_batch_size,
570
570
  model = deepcopy(neural_memory_model),
571
- kv_receives_diff_views = neural_memory_kv_receives_diff_views,
571
+ qkv_receives_diff_views = neural_memory_qkv_receives_diff_views,
572
572
  accept_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
573
573
  **neural_memory_kwargs
574
574
  )
@@ -231,7 +231,7 @@ class NeuralMemory(Module):
231
231
  momentum_order = 1,
232
232
  learned_momentum_combine = False,
233
233
  learned_combine_include_zeroth = False,
234
- kv_receives_diff_views = False, # to address an issue raised by a phd student (who will be credited if experiments are green). basically the issue raised is that the memory MLP is only learning Wk @ Wv linear mapping and that may not be expressive enough. we will use hyper connections to allow the network to choose different previous layer inputs as keys / values and see if that does anything
234
+ qkv_receives_diff_views = False, # to address an issue raised by a phd student (who will be credited if experiments are green). basically the issue raised is that the memory MLP is only learning Wk @ Wv linear mapping and that may not be expressive enough. we will use hyper connections to allow the network to choose different previous layer inputs as queries / keys / values and see if that does anything
235
235
  pre_rmsnorm = True,
236
236
  post_rmsnorm = False,
237
237
  qk_rmsnorm = False,
@@ -268,7 +268,7 @@ class NeuralMemory(Module):
268
268
 
269
269
  # queries, keys and values receiving different views
270
270
 
271
- self.kv_receives_diff_views = kv_receives_diff_views
271
+ self.qkv_receives_diff_views = qkv_receives_diff_views
272
272
 
273
273
  # norms
274
274
 
@@ -511,7 +511,7 @@ class NeuralMemory(Module):
511
511
  seq_index = 0,
512
512
  prev_weights = None
513
513
  ):
514
- if self.kv_receives_diff_views:
514
+ if self.qkv_receives_diff_views:
515
515
  _, batch, seq_len = seq.shape[:3]
516
516
  else:
517
517
  batch, seq_len = seq.shape[:2]
@@ -550,7 +550,7 @@ class NeuralMemory(Module):
550
550
 
551
551
  values_seq = seq
552
552
 
553
- if self.kv_receives_diff_views:
553
+ if self.qkv_receives_diff_views:
554
554
  seq, values_seq = seq
555
555
 
556
556
  # derive learned hparams for optimization of memory network
@@ -820,10 +820,23 @@ class NeuralMemory(Module):
820
820
  state: NeuralMemState | None = None,
821
821
  prev_weights = None
822
822
  ):
823
- if seq.ndim == 2:
824
- seq = rearrange(seq, 'b d -> b 1 d')
823
+ is_multi_input = self.qkv_receives_diff_views
825
824
 
826
- is_single_token = seq.shape[1] == 1
825
+ # handle single token
826
+
827
+ if seq.ndim == 2 or (is_multi_input and seq.ndim == 3):
828
+ seq = rearrange(seq, '... b d -> ... b 1 d')
829
+
830
+ is_single_token = seq.shape[-2] == 1
831
+
832
+ # if qkv receive different views, the first view is used for retrieval and the remaining views carry on as `seq`
833
+
834
+ if is_multi_input:
835
+ retrieve_seq, seq = seq[0], seq[1:]
836
+ else:
837
+ retrieve_seq = seq
838
+
839
+ # handle previous state init
827
840
 
828
841
  if not exists(state):
829
842
  state = (0, None, None, None, None)
@@ -839,8 +852,6 @@ class NeuralMemory(Module):
839
852
  if exists(cache_store_seq):
840
853
  store_seq = safe_cat((cache_store_seq, store_seq))
841
854
 
842
- # functions
843
-
844
855
  # compute split sizes of sequence
845
856
  # for now manually update weights to last update at the correct boundaries
846
857
 
@@ -939,11 +950,8 @@ class NeuralMemory(Module):
939
950
  last_update, _ = next_neural_mem_state.states
940
951
  updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
941
952
 
942
- if self.kv_receives_diff_views:
943
- seq = seq[0]
944
-
945
953
  retrieved = self.retrieve_memories(
946
- seq,
954
+ retrieve_seq,
947
955
  updates
948
956
  )
949
957
 
File without changes