titans-pytorch 0.3.15__py3-none-any.whl → 0.3.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -483,6 +483,7 @@ class MemoryAsContextTransformer(Module):
483
483
  num_longterm_mem_tokens = 0,
484
484
  num_persist_mem_tokens = 0,
485
485
  neural_memory_batch_size = None,
486
+ neural_memory_qkv_receives_diff_views = False,
486
487
  dim_head = 64,
487
488
  heads = 8,
488
489
  ff_mult = 4,
@@ -560,13 +561,14 @@ class MemoryAsContextTransformer(Module):
560
561
  mem_hyper_conn = None
561
562
 
562
563
  if layer in neural_memory_layers:
563
- mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)
564
+ mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output, num_input_views = 3 if neural_memory_qkv_receives_diff_views else 1)
564
565
 
565
566
  mem = NeuralMemory(
566
567
  dim = dim,
567
568
  chunk_size = self.neural_memory_segment_len,
568
569
  batch_size = neural_memory_batch_size,
569
570
  model = deepcopy(neural_memory_model),
571
+ qkv_receives_diff_views = neural_memory_qkv_receives_diff_views,
570
572
  accept_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
571
573
  **neural_memory_kwargs
572
574
  )
@@ -231,6 +231,7 @@ class NeuralMemory(Module):
231
231
  momentum_order = 1,
232
232
  learned_momentum_combine = False,
233
233
  learned_combine_include_zeroth = False,
234
+ qkv_receives_diff_views = False, # to address an issue raised by a phd student (who will be credited if experiments are green). basically the issue raised is that the memory MLP is only learning Wk @ Wv linear mapping and that may not be expressive enough. we will use hyper connections to allow the network to choose different previous layer inputs as keys / values and see if that does anything
234
235
  pre_rmsnorm = True,
235
236
  post_rmsnorm = False,
236
237
  qk_rmsnorm = False,
@@ -265,6 +266,10 @@ class NeuralMemory(Module):
265
266
 
266
267
  self.assoc_scan = AssocScan(use_accelerated = use_accelerated_scan)
267
268
 
269
+ # key values receiving different views
270
+
271
+ self.qkv_receives_diff_views = qkv_receives_diff_views
272
+
268
273
  # norms
269
274
 
270
275
  self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
@@ -358,7 +363,9 @@ class NeuralMemory(Module):
358
363
 
359
364
  # keys and values for storing to the model
360
365
 
361
- self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
366
+ self.to_keys = Sequential(LinearNoBias(dim, dim_inner), activation)
367
+ self.to_values = Sequential(LinearNoBias(dim, dim_inner), activation)
368
+
362
369
  self.store_memory_loss_fn = store_memory_loss_fn
363
370
 
364
371
  # `chunk_size` refers to chunk size used for storing to memory model weights
@@ -504,7 +511,14 @@ class NeuralMemory(Module):
504
511
  seq_index = 0,
505
512
  prev_weights = None
506
513
  ):
507
- batch, seq_len, heads, chunk_size = *seq.shape[:2], self.heads, self.store_chunk_size
514
+ if self.qkv_receives_diff_views:
515
+ _, batch, seq_len = seq.shape[:3]
516
+ else:
517
+ batch, seq_len = seq.shape[:2]
518
+
519
+ # shapes and variables
520
+
521
+ heads, chunk_size = self.heads, self.store_chunk_size
508
522
 
509
523
  # curtail sequence by multiple of the chunk size
510
524
  # only a complete chunk of the sequence provides the memory for the next chunk
@@ -512,7 +526,7 @@ class NeuralMemory(Module):
512
526
  round_down_seq_len = round_down_multiple(seq_len, chunk_size)
513
527
  num_chunks = round_down_seq_len // chunk_size
514
528
 
515
- seq, remainder = seq[:, :round_down_seq_len], seq[:, round_down_seq_len:]
529
+ seq, remainder = seq[..., :round_down_seq_len, :], seq[..., round_down_seq_len:, :]
516
530
 
517
531
  next_seq_len_index = seq_index + round_down_seq_len
518
532
 
@@ -528,10 +542,19 @@ class NeuralMemory(Module):
528
542
 
529
543
  weights_for_surprise = repeat_dict_values(weights, 'b ... -> b n ...', n = num_chunks)
530
544
 
531
- # derive learned hparams for optimization of memory network
545
+ # initial norm
532
546
 
533
547
  seq = self.store_norm(seq)
534
548
 
549
+ # handle keys and values coming from different sequences from hyper connection
550
+
551
+ values_seq = seq
552
+
553
+ if self.qkv_receives_diff_views:
554
+ seq, values_seq = seq
555
+
556
+ # derive learned hparams for optimization of memory network
557
+
535
558
  adaptive_lr = self.to_adaptive_step(seq)
536
559
  adaptive_lr = self.adaptive_step_transform(adaptive_lr)
537
560
 
@@ -555,7 +578,8 @@ class NeuralMemory(Module):
555
578
 
556
579
  # keys and values
557
580
 
558
- keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
581
+ keys = self.to_keys(seq)
582
+ values = self.to_values(values_seq)
559
583
 
560
584
  # maybe multi head
561
585
 
@@ -796,10 +820,23 @@ class NeuralMemory(Module):
796
820
  state: NeuralMemState | None = None,
797
821
  prev_weights = None
798
822
  ):
799
- if seq.ndim == 2:
800
- seq = rearrange(seq, 'b d -> b 1 d')
823
+ is_multi_input = self.qkv_receives_diff_views
824
+
825
+ # handle single token
826
+
827
+ if seq.ndim == 2 or (is_multi_input and seq.ndim == 3):
828
+ seq = rearrange(seq, '... b d -> ... b 1 d')
829
+
830
+ is_single_token = seq.shape[-2] == 1
801
831
 
802
- is_single_token = seq.shape[1] == 1
832
+ # if different views for qkv, then
833
+
834
+ if is_multi_input:
835
+ retrieve_seq, seq = seq[0], seq[1:]
836
+ else:
837
+ retrieve_seq = seq
838
+
839
+ # handle previous state init
803
840
 
804
841
  if not exists(state):
805
842
  state = (0, None, None, None, None)
@@ -815,8 +852,6 @@ class NeuralMemory(Module):
815
852
  if exists(cache_store_seq):
816
853
  store_seq = safe_cat((cache_store_seq, store_seq))
817
854
 
818
- # functions
819
-
820
855
  # compute split sizes of sequence
821
856
  # for now manually update weights to last update at the correct boundaries
822
857
 
@@ -916,7 +951,7 @@ class NeuralMemory(Module):
916
951
  updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
917
952
 
918
953
  retrieved = self.retrieve_memories(
919
- seq,
954
+ retrieve_seq,
920
955
  updates
921
956
  )
922
957
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.3.15
3
+ Version: 0.3.20
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -38,7 +38,7 @@ Requires-Dist: accelerated-scan>=0.2.0
38
38
  Requires-Dist: axial-positional-embedding>=0.3.10
39
39
  Requires-Dist: einops>=0.8.0
40
40
  Requires-Dist: einx>=0.3.0
41
- Requires-Dist: hyper-connections>=0.1.10
41
+ Requires-Dist: hyper-connections>=0.1.11
42
42
  Requires-Dist: ninja
43
43
  Requires-Dist: rotary-embedding-torch
44
44
  Requires-Dist: tensordict
@@ -0,0 +1,9 @@
1
+ titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
2
+ titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
3
+ titans_pytorch/mac_transformer.py,sha256=M45tLp_CzYVm9R33zkUHmWdzIuaNxNYQLCTAtyMechg,25294
4
+ titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
5
+ titans_pytorch/neural_memory.py,sha256=hLMDbLosxgQZvvfuJM3X7w23npcoL9AWBQfbtQHOSiA,30637
6
+ titans_pytorch-0.3.20.dist-info/METADATA,sha256=JILf0r0bT4KXAZD4V2nmYRIr5Y7bGGjjNlpfIuT0UNs,6817
7
+ titans_pytorch-0.3.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.3.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.3.20.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
2
- titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
3
- titans_pytorch/mac_transformer.py,sha256=HIB3S3JBA8Fe1EBITvDZSHXtn-1_fF1rwlw-MzqagKY,25085
4
- titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
5
- titans_pytorch/neural_memory.py,sha256=VmUAS1xOM0ZfearWIzQrX_P7HI69viuwrg9M7BQByeE,29349
6
- titans_pytorch-0.3.15.dist-info/METADATA,sha256=RPw9JXenAI7cGpVP3hQZlj0OA5-xsvXvXHvxyhWdgpg,6817
7
- titans_pytorch-0.3.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.3.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.3.15.dist-info/RECORD,,