titans-pytorch 0.3.15__tar.gz → 0.3.19__tar.gz

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.15
+Version: 0.3.19
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -38,7 +38,7 @@ Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.10
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
-Requires-Dist: hyper-connections>=0.1.10
+Requires-Dist: hyper-connections>=0.1.11
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.3.15"
+version = "0.3.19"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -29,7 +29,7 @@ dependencies = [
     "axial_positional_embedding>=0.3.10",
     "einops>=0.8.0",
     "einx>=0.3.0",
-    "hyper-connections>=0.1.10",
+    "hyper-connections>=0.1.11",
     "Ninja",
     "rotary-embedding-torch",
     "tensordict",

@@ -200,6 +200,7 @@ def test_neural_mem_chaining_with_batch_size():
 @pytest.mark.parametrize('neural_mem_segment_len', (8, 16))
 @pytest.mark.parametrize('neural_mem_weight_residual', (False, True))
 @pytest.mark.parametrize('neural_mem_batch_size', (None, 64))
+@pytest.mark.parametrize('neural_mem_kv_receives_diff_views', (False, True))
 @pytest.mark.parametrize('neural_mem_momentum', (False, True))
 def test_mac(
     seq_len,
@@ -209,6 +210,7 @@ def test_mac(
     neural_mem_segment_len,
     neural_mem_weight_residual,
     neural_mem_batch_size,
+    neural_mem_kv_receives_diff_views,
     neural_mem_momentum
 ):
     transformer = MemoryAsContextTransformer(
@@ -221,6 +223,7 @@ def test_mac(
         neural_mem_gate_attn_output = neural_mem_gate_attn_output,
         neural_memory_segment_len = neural_mem_segment_len,
         neural_memory_batch_size = neural_mem_batch_size,
+        neural_memory_kv_receives_diff_views = neural_mem_kv_receives_diff_views,
         neural_mem_weight_residual = neural_mem_weight_residual,
         neural_memory_kwargs = dict(
            momentum = neural_mem_momentum
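
For orientation, a minimal end-to-end sketch of how the new flag is meant to be used, mirroring the test above. Apart from `neural_memory_kv_receives_diff_views`, the constructor arguments and their values are illustrative assumptions, not taken from this diff:

    import torch
    from titans_pytorch import MemoryAsContextTransformer

    # all hyperparameters below are small, assumed values for illustration;
    # only neural_memory_kv_receives_diff_views is the option added in this release
    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 64,
        depth = 2,
        segment_len = 32,
        num_persist_mem_tokens = 4,
        num_longterm_mem_tokens = 4,
        neural_memory_kv_receives_diff_views = True
    )

    ids = torch.randint(0, 256, (1, 512))
    logits = transformer(ids)   # expected shape: (1, 512, 256)
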
@@ -483,6 +483,7 @@ class MemoryAsContextTransformer(Module):
         num_longterm_mem_tokens = 0,
         num_persist_mem_tokens = 0,
         neural_memory_batch_size = None,
+        neural_memory_kv_receives_diff_views = False,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -560,13 +561,14 @@ class MemoryAsContextTransformer(Module):
             mem_hyper_conn = None

             if layer in neural_memory_layers:
-                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)
+                mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output, num_input_views = 2 if neural_memory_kv_receives_diff_views else 1)

                 mem = NeuralMemory(
                     dim = dim,
                     chunk_size = self.neural_memory_segment_len,
                     batch_size = neural_memory_batch_size,
                     model = deepcopy(neural_memory_model),
+                    kv_receives_diff_views = neural_memory_kv_receives_diff_views,
                     accept_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
                     **neural_memory_kwargs
                 )
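
Conceptually, passing `num_input_views = 2` asks the hyper connection to hand the memory branch two differently mixed views of the residual stream(s), so that keys and values need not be derived from the same tensor; this is presumably also why the dependency floor on hyper-connections moved to 0.1.11 in this release. The sketch below illustrates the idea only; it is not the hyper-connections package's implementation, and every name in it is made up:

    import torch

    # assume the hyper connection maintains several residual streams and
    # learns, per view, a weighting over them (purely illustrative)
    num_streams, batch, seq_len, dim = 4, 2, 64, 32
    residual_streams = torch.randn(num_streams, batch, seq_len, dim)

    num_input_views = 2
    mix = torch.randn(num_input_views, num_streams).softmax(dim = -1)

    # each view is its own weighted combination of the residual streams,
    # stacked along a leading "view" dimension for the memory branch
    views = torch.einsum('v s, s b n d -> v b n d', mix, residual_streams)
    assert views.shape == (num_input_views, batch, seq_len, dim)
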
@@ -231,6 +231,7 @@ class NeuralMemory(Module):
         momentum_order = 1,
         learned_momentum_combine = False,
         learned_combine_include_zeroth = False,
+        kv_receives_diff_views = False, # to address an issue raised by a phd student (who will be credited if experiments are green). basically the issue raised is that the memory MLP is only learning Wk @ Wv linear mapping and that may not be expressive enough. we will use hyper connections to allow the network to choose different previous layer inputs as keys / values and see if that does anything
         pre_rmsnorm = True,
         post_rmsnorm = False,
         qk_rmsnorm = False,
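
A rough reading of the concern in the comment above, stated slightly more explicitly (this is an interpretation, and the comment itself flags the idea as experimental): if keys and values are projections of the same input x, say k = W_k x and v = W_v x, then the map M the memory is trained to fit must satisfy

    M(W_k x) ≈ W_v x   for all x,

and for a linear memory the best such M is essentially W_v composed with a (pseudo)inverse of W_k, a single fixed matrix that does not depend on the sequence content. Sourcing keys and values from different previous-layer views removes that constraint, which is what `kv_receives_diff_views` is meant to test.
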
@@ -265,6 +266,10 @@ class NeuralMemory(Module):

         self.assoc_scan = AssocScan(use_accelerated = use_accelerated_scan)

+        # key values receiving different views
+
+        self.kv_receives_diff_views = kv_receives_diff_views
+
         # norms

         self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
@@ -358,7 +363,9 @@ class NeuralMemory(Module):

         # keys and values for storing to the model

-        self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
+        self.to_keys = Sequential(LinearNoBias(dim, dim_inner), activation)
+        self.to_values = Sequential(LinearNoBias(dim, dim_inner), activation)
+
         self.store_memory_loss_fn = store_memory_loss_fn

         # `chunk_size` refers to chunk size used for storing to memory model weights
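
Splitting the fused `to_keys_values` projection into `to_keys` and `to_values` is weight-for-weight equivalent to the old code when both receive the same input; the point of the split is that they can now receive different inputs. A small sketch of that equivalence (the activation from the diff is omitted for brevity, and the layer names here are stand-ins):

    import torch
    from torch import nn

    dim, dim_inner = 32, 64
    x = torch.randn(2, 16, dim)

    # old: one fused projection, chunked into keys and values
    fused = nn.Linear(dim, dim_inner * 2, bias = False)
    keys_old, values_old = fused(x).chunk(2, dim = -1)

    # new: two independent projections, which may receive different inputs
    to_keys   = nn.Linear(dim, dim_inner, bias = False)
    to_values = nn.Linear(dim, dim_inner, bias = False)

    # copying the fused weights over reproduces the old behaviour exactly
    with torch.no_grad():
        to_keys.weight.copy_(fused.weight[:dim_inner])
        to_values.weight.copy_(fused.weight[dim_inner:])

    assert torch.allclose(to_keys(x), keys_old, atol = 1e-6)
    assert torch.allclose(to_values(x), values_old, atol = 1e-6)
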
@@ -504,7 +511,14 @@ class NeuralMemory(Module):
         seq_index = 0,
         prev_weights = None
     ):
-        batch, seq_len, heads, chunk_size = *seq.shape[:2], self.heads, self.store_chunk_size
+        if self.kv_receives_diff_views:
+            _, batch, seq_len = seq.shape[:3]
+        else:
+            batch, seq_len = seq.shape[:2]
+
+        # shapes and variables
+
+        heads, chunk_size = self.heads, self.store_chunk_size

         # curtail sequence by multiple of the chunk size
         # only a complete chunk of the sequence provides the memory for the next chunk
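
When the flag is set, `store_memories` evidently receives the two views stacked along a leading dimension, which is why batch and sequence sizes are now read from `shape[:3]` and why the next hunk switches to ellipsis-based slicing. A shape-only sketch (all sizes arbitrary):

    import torch

    batch, seq_len, dim, chunk = 2, 50, 32, 16

    single  = torch.randn(batch, seq_len, dim)       # kv_receives_diff_views = False
    stacked = torch.randn(2, batch, seq_len, dim)    # kv_receives_diff_views = True

    # shape handling mirroring the hunk above
    _, b, n = stacked.shape[:3]
    assert (b, n) == (batch, seq_len)

    # ellipsis slicing truncates the sequence axis regardless of the extra leading dim
    n_rounded = (seq_len // chunk) * chunk
    assert single[..., :n_rounded, :].shape  == (batch, n_rounded, dim)
    assert stacked[..., :n_rounded, :].shape == (2, batch, n_rounded, dim)
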
@@ -512,7 +526,7 @@ class NeuralMemory(Module):
         round_down_seq_len = round_down_multiple(seq_len, chunk_size)
         num_chunks = round_down_seq_len // chunk_size

-        seq, remainder = seq[:, :round_down_seq_len], seq[:, round_down_seq_len:]
+        seq, remainder = seq[..., :round_down_seq_len, :], seq[..., round_down_seq_len:, :]

         next_seq_len_index = seq_index + round_down_seq_len

@@ -528,10 +542,19 @@ class NeuralMemory(Module):

         weights_for_surprise = repeat_dict_values(weights, 'b ... -> b n ...', n = num_chunks)

-        # derive learned hparams for optimization of memory network
+        # initial norm

         seq = self.store_norm(seq)

+        # handle keys and values coming from different sequences from hyper connection
+
+        values_seq = seq
+
+        if self.kv_receives_diff_views:
+            seq, values_seq = seq
+
+        # derive learned hparams for optimization of memory network
+
         adaptive_lr = self.to_adaptive_step(seq)
         adaptive_lr = self.adaptive_step_transform(adaptive_lr)

@@ -555,7 +578,8 @@ class NeuralMemory(Module):

         # keys and values

-        keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
+        keys = self.to_keys(seq)
+        values = self.to_values(values_seq)

         # maybe multi head
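
Taken together with the previous hunk, the single-view path is unchanged (keys and values are projected from the same normalized sequence), while the two-view path unpacks the stack and projects keys and values from different tensors. A compressed sketch of just that control flow, with plain linear layers standing in for the package's `to_keys` / `to_values`:

    import torch
    from torch import nn

    dim, dim_inner = 32, 64
    to_keys   = nn.Linear(dim, dim_inner, bias = False)   # stand-in for self.to_keys
    to_values = nn.Linear(dim, dim_inner, bias = False)   # stand-in for self.to_values

    def project_kv(seq, kv_receives_diff_views):
        # values default to the same view as keys, mirroring the diff
        values_seq = seq
        if kv_receives_diff_views:
            seq, values_seq = seq    # unpack the (2, b, n, d) stack along dim 0
        return to_keys(seq), to_values(values_seq)

    one_view  = torch.randn(2, 16, dim)
    two_views = torch.stack((one_view, torch.randn(2, 16, dim)))

    k1, v1 = project_kv(one_view, kv_receives_diff_views = False)
    k2, v2 = project_kv(two_views, kv_receives_diff_views = True)
    assert k1.shape == v1.shape == k2.shape == v2.shape == (2, 16, dim_inner)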
 
@@ -915,6 +939,9 @@ class NeuralMemory(Module):
         last_update, _ = next_neural_mem_state.states
         updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')

+        if self.kv_receives_diff_views:
+            seq = seq[0]
+
         retrieved = self.retrieve_memories(
             seq,
             updates
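
On retrieval only one sequence is needed as the query source, so when two views were stacked for storage the first one is used. A one-line illustration (shapes arbitrary):

    import torch

    stacked = torch.randn(2, 4, 128, 64)   # (views, batch, seq, dim)
    query_source = stacked[0]              # retrieval queries come from the first view
    assert query_source.shape == (4, 128, 64)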