titans-pytorch: titans_pytorch-0.3.14-py3-none-any.whl → titans_pytorch-0.3.19-py3-none-any.whl

This diff shows the contents of two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
@@ -483,6 +483,7 @@ class MemoryAsContextTransformer(Module):
483
483
  num_longterm_mem_tokens = 0,
484
484
  num_persist_mem_tokens = 0,
485
485
  neural_memory_batch_size = None,
486
+ neural_memory_kv_receives_diff_views = False,
486
487
  dim_head = 64,
487
488
  heads = 8,
488
489
  ff_mult = 4,
@@ -494,7 +495,6 @@ class MemoryAsContextTransformer(Module):
494
495
  sliding_window_attn = False,
495
496
  neural_mem_weight_residual = False,
496
497
  token_emb: Module | None = None,
497
- abs_pos_emb: Module | None = None
498
498
  ):
499
499
  super().__init__()
500
500
 
@@ -503,10 +503,9 @@ class MemoryAsContextTransformer(Module):
503
503
 
504
504
  self.token_emb = token_emb
505
505
 
506
- if not exists(abs_pos_emb):
507
- abs_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
506
+ # absolute positions
508
507
 
509
- self.abs_pos_emb = abs_pos_emb
508
+ self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
510
509
 
511
510
  # long term mem tokens
512
511
 
@@ -562,13 +561,14 @@ class MemoryAsContextTransformer(Module):
562
561
  mem_hyper_conn = None
563
562
 
564
563
  if layer in neural_memory_layers:
565
- mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)
564
+ mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output, num_input_views = 2 if neural_memory_kv_receives_diff_views else 1)
566
565
 
567
566
  mem = NeuralMemory(
568
567
  dim = dim,
569
568
  chunk_size = self.neural_memory_segment_len,
570
569
  batch_size = neural_memory_batch_size,
571
570
  model = deepcopy(neural_memory_model),
571
+ kv_receives_diff_views = neural_memory_kv_receives_diff_views,
572
572
  accept_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
573
573
  **neural_memory_kwargs
574
574
  )
@@ -231,6 +231,7 @@ class NeuralMemory(Module):
231
231
  momentum_order = 1,
232
232
  learned_momentum_combine = False,
233
233
  learned_combine_include_zeroth = False,
234
+ kv_receives_diff_views = False, # to address an issue raised by a phd student (who will be credited if experiments are green). basically the issue raised is that the memory MLP is only learning Wk @ Wv linear mapping and that may not be expressive enough. we will use hyper connections to allow the network to choose different previous layer inputs as keys / values and see if that does anything
234
235
  pre_rmsnorm = True,
235
236
  post_rmsnorm = False,
236
237
  qk_rmsnorm = False,
@@ -265,6 +266,10 @@ class NeuralMemory(Module):
265
266
 
266
267
  self.assoc_scan = AssocScan(use_accelerated = use_accelerated_scan)
267
268
 
269
+ # key values receiving different views
270
+
271
+ self.kv_receives_diff_views = kv_receives_diff_views
272
+
268
273
  # norms
269
274
 
270
275
  self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
@@ -358,7 +363,9 @@ class NeuralMemory(Module):
358
363
 
359
364
  # keys and values for storing to the model
360
365
 
361
- self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
366
+ self.to_keys = Sequential(LinearNoBias(dim, dim_inner), activation)
367
+ self.to_values = Sequential(LinearNoBias(dim, dim_inner), activation)
368
+
362
369
  self.store_memory_loss_fn = store_memory_loss_fn
363
370
 
364
371
  # `chunk_size` refers to chunk size used for storing to memory model weights
@@ -504,7 +511,14 @@ class NeuralMemory(Module):
504
511
  seq_index = 0,
505
512
  prev_weights = None
506
513
  ):
507
- batch, seq_len, heads, chunk_size = *seq.shape[:2], self.heads, self.store_chunk_size
514
+ if self.kv_receives_diff_views:
515
+ _, batch, seq_len = seq.shape[:3]
516
+ else:
517
+ batch, seq_len = seq.shape[:2]
518
+
519
+ # shapes and variables
520
+
521
+ heads, chunk_size = self.heads, self.store_chunk_size
508
522
 
509
523
  # curtail sequence by multiple of the chunk size
510
524
  # only a complete chunk of the sequence provides the memory for the next chunk
@@ -512,7 +526,7 @@ class NeuralMemory(Module):
512
526
  round_down_seq_len = round_down_multiple(seq_len, chunk_size)
513
527
  num_chunks = round_down_seq_len // chunk_size
514
528
 
515
- seq, remainder = seq[:, :round_down_seq_len], seq[:, round_down_seq_len:]
529
+ seq, remainder = seq[..., :round_down_seq_len, :], seq[..., round_down_seq_len:, :]
516
530
 
517
531
  next_seq_len_index = seq_index + round_down_seq_len
518
532
 
@@ -528,10 +542,19 @@ class NeuralMemory(Module):
528
542
 
529
543
  weights_for_surprise = repeat_dict_values(weights, 'b ... -> b n ...', n = num_chunks)
530
544
 
531
- # derive learned hparams for optimization of memory network
545
+ # initial norm
532
546
 
533
547
  seq = self.store_norm(seq)
534
548
 
549
+ # handle keys and values coming from different sequences from hyper connection
550
+
551
+ values_seq = seq
552
+
553
+ if self.kv_receives_diff_views:
554
+ seq, values_seq = seq
555
+
556
+ # derive learned hparams for optimization of memory network
557
+
535
558
  adaptive_lr = self.to_adaptive_step(seq)
536
559
  adaptive_lr = self.adaptive_step_transform(adaptive_lr)
537
560
 
@@ -555,7 +578,8 @@ class NeuralMemory(Module):
555
578
 
556
579
  # keys and values
557
580
 
558
- keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
581
+ keys = self.to_keys(seq)
582
+ values = self.to_values(values_seq)
559
583
 
560
584
  # maybe multi head
561
585
 
@@ -915,6 +939,9 @@ class NeuralMemory(Module):
915
939
  last_update, _ = next_neural_mem_state.states
916
940
  updates = rearrange_dict_values(last_update, 'b ... -> b 1 ...')
917
941
 
942
+ if self.kv_receives_diff_views:
943
+ seq = seq[0]
944
+
918
945
  retrieved = self.retrieve_memories(
919
946
  seq,
920
947
  updates
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.3.14
3
+ Version: 0.3.19
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -38,7 +38,7 @@ Requires-Dist: accelerated-scan>=0.2.0
38
38
  Requires-Dist: axial-positional-embedding>=0.3.10
39
39
  Requires-Dist: einops>=0.8.0
40
40
  Requires-Dist: einx>=0.3.0
41
- Requires-Dist: hyper-connections>=0.1.10
41
+ Requires-Dist: hyper-connections>=0.1.11
42
42
  Requires-Dist: ninja
43
43
  Requires-Dist: rotary-embedding-torch
44
44
  Requires-Dist: tensordict
@@ -0,0 +1,9 @@
1
+ titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
2
+ titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
3
+ titans_pytorch/mac_transformer.py,sha256=4N8WcoqPYNhMGGQAZjpm-djVsLnU7VADH_l06qFPuOk,25290
4
+ titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
5
+ titans_pytorch/neural_memory.py,sha256=KKj8r3oRoDYf-vFQFnI4Rd4DfMH1f5QFs4vdcS35og8,30374
6
+ titans_pytorch-0.3.19.dist-info/METADATA,sha256=5h4f5gsO1emX5LEwe8cgpH35Rtjydz2UTlpK4DKSntI,6817
7
+ titans_pytorch-0.3.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.3.19.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.3.19.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
2
- titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
3
- titans_pytorch/mac_transformer.py,sha256=F04B88GaH0wHseUIWaX6VFhOSsk_3XDQ1E8e6pvqKgQ,25170
4
- titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
5
- titans_pytorch/neural_memory.py,sha256=VmUAS1xOM0ZfearWIzQrX_P7HI69viuwrg9M7BQByeE,29349
6
- titans_pytorch-0.3.14.dist-info/METADATA,sha256=1reoUZhvKaFPR6U0QXqJOziyss0HwHhwfJUf7oU7t-s,6817
7
- titans_pytorch-0.3.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.3.14.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.3.14.dist-info/RECORD,,