titans-pytorch 0.3.25__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -46,7 +46,7 @@ def create_mac_block_mask(seq_len, window_size, persist_mem_len, sliding = False
 
  # einstein notation related
 
- from einops import repeat, rearrange, pack, unpack
+ from einops import repeat, rearrange, pack, unpack, einsum
  from einops.layers.torch import Rearrange
 
  # b - batch
@@ -521,9 +521,7 @@ class MemoryAsContextTransformer(Module):
  self.sliding_window_attn = sliding_window_attn
  self.attn_window_size = segment_len + num_longterm_mem_tokens
 
- # hyper conection
-
- assert not (num_residual_streams <= 1 and neural_memory_qkv_receives_diff_views), 'allow neural memory queries, keys, values to be derived from different combinations of the residual streams can only work if hyper connections has greater than 1 residual stream'
+ # hyper connection
 
  init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, dim = dim, add_stream_embed = True, disable = num_residual_streams == 1)
 
@@ -560,17 +558,28 @@ class MemoryAsContextTransformer(Module):
  )
 
  mem = None
+ mem_qkv_layer_selector = None
  mem_hyper_conn = None
 
  if layer in neural_memory_layers:
- mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output, num_input_views = 3 if neural_memory_qkv_receives_diff_views else 1)
+ mem_hyper_conn = init_hyper_conn(add_branch_out_to_residual = not neural_mem_gate_attn_output)
+
+ if not is_first and neural_memory_qkv_receives_diff_views:
+ num_layer_choices = (layer - 1) * 4 + 1 # for each layer, have memory input select from attn inp, attn out, ff inp, and ff out - plus one for the current point in the residual stream (memory input)
+
+ mem_qkv_layer_selector = nn.Sequential(
+ nn.RMSNorm(dim),
+ nn.Linear(dim, 3 * num_layer_choices),
+ Rearrange('... (views layers) -> views ... layers', views = 3),
+ nn.Softmax(dim = -1)
+ )
 
  mem = NeuralMemory(
  dim = dim,
  chunk_size = self.neural_memory_segment_len,
  batch_size = neural_memory_batch_size,
  model = deepcopy(neural_memory_model),
- qkv_receives_diff_views = neural_memory_qkv_receives_diff_views,
+ qkv_receives_diff_views = True,
  accept_weight_residual = neural_mem_weight_residual and not is_first_neural_mem,
  **neural_memory_kwargs
  )
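The new `mem_qkv_layer_selector` gives each neural memory layer a learned, per-token softmax over earlier hidden states (attention/feedforward inputs and outputs, plus the current point in the residual stream), one distribution each for queries, keys and values. A minimal standalone sketch of what that module computes; `dim` and the number of candidate layers below are illustrative values, not taken from the package:

    # Standalone sketch of the selector added above (illustrative sizes only).
    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    dim = 512
    num_layer_choices = 1 + 4  # current residual point + attn in/out and ff in/out of one earlier layer

    selector = nn.Sequential(
        nn.RMSNorm(dim),
        nn.Linear(dim, 3 * num_layer_choices),
        Rearrange('... (views layers) -> views ... layers', views = 3),
        nn.Softmax(dim = -1),
    )

    mem_input = torch.randn(2, 16, dim)  # (batch, seq, dim)
    weights = selector(mem_input)        # (3, batch, seq, num_layer_choices), softmax over the layer axis
    assert weights.shape == (3, 2, 16, num_layer_choices)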
@@ -581,9 +590,12 @@ class MemoryAsContextTransformer(Module):
 
  self.layers.append(ModuleList([
  mem_hyper_conn,
+ init_hyper_conn(),
+ init_hyper_conn(),
+ mem_qkv_layer_selector,
  mem,
- init_hyper_conn(branch = attn),
- init_hyper_conn(branch = ff)
+ attn,
+ ff,
  ]))
 
  self.norm = nn.RMSNorm(dim)
@@ -763,6 +775,10 @@ class MemoryAsContextTransformer(Module):
 
  mem_weight_residual = None
 
+ # layers for the neural mem to select the qkv inputs from
+
+ mem_input_layers = []
+
  # when inferencing, only do one token at a time
 
  if is_inferencing:
@@ -773,7 +789,7 @@ class MemoryAsContextTransformer(Module):
 
  x = self.expand_streams(x)
 
- for mem_hyper_conn, mem, attn, ff in self.layers:
+ for mem_hyper_conn, attn_hyper_conn, ff_hyper_conn, mem_qkv_layer_selector, mem, attn, ff in self.layers:
 
  retrieved = None
  attn_out_gates = None
@@ -785,8 +801,19 @@ class MemoryAsContextTransformer(Module):
 
  mem_input, add_residual = mem_hyper_conn(x)
 
+ if not exists(mem_qkv_layer_selector):
+ qkv_mem_input = stack((mem_input, mem_input, mem_input))
+ else:
+ layers_to_choose_from = stack((mem_input, *mem_input_layers))
+
+ # let the current `mem_input` select the 3 layers for qkv
+
+ selected = mem_qkv_layer_selector(mem_input)
+
+ qkv_mem_input = einsum(layers_to_choose_from, selected, 'l b n d, v b n l -> v b n d')
+
  retrieved, next_neural_mem_cache = mem.forward(
- mem_input,
+ qkv_mem_input,
  state = next(neural_mem_caches, None),
  prev_weights = mem_weight_residual
  )
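The `einsum` line above blends the stacked candidate hidden states with those softmax weights, producing one mixed input each for queries, keys and values. An illustrative shape check with toy tensors (not the transformer's actual state):

    # Toy shape check for the qkv mixing einsum; l = candidate layers,
    # v = 3 views (q, k, v), b/n/d = batch/sequence/feature dims (made-up sizes).
    import torch
    from einops import einsum

    l, v, b, n, d = 5, 3, 2, 16, 512
    layers_to_choose_from = torch.randn(l, b, n, d)
    selected = torch.randn(v, b, n, l).softmax(dim = -1)

    qkv_mem_input = einsum(layers_to_choose_from, selected, 'l b n d, v b n l -> v b n d')
    assert qkv_mem_input.shape == (v, b, n, d)  # one weighted-average input per view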
@@ -801,8 +828,12 @@ class MemoryAsContextTransformer(Module):
 
  # attention
 
- x, (values, next_kv_cache) = attn(
- x,
+ attn_in, add_residual = attn_hyper_conn(x)
+
+ mem_input_layers.append(attn_in)
+
+ attn_out, (values, next_kv_cache) = attn(
+ attn_in,
  value_residual = value_residual,
  disable_flex_attn = disable_flex_attn,
  flex_attn_fn = flex_attn_fn,
@@ -810,8 +841,12 @@ class MemoryAsContextTransformer(Module):
  cache = next(kv_caches, None)
  )
 
+ mem_input_layers.append(attn_out)
+
  value_residual = default(value_residual, values)
 
+ x = add_residual(attn_out)
+
  # caches
 
  next_kv_caches.append(next_kv_cache)
@@ -819,7 +854,15 @@ class MemoryAsContextTransformer(Module):
 
  # feedforward
 
- x = ff(x)
+ ff_in, add_ff_residual = ff_hyper_conn(x)
+
+ mem_input_layers.append(ff_in)
+
+ ff_out = ff(ff_in)
+
+ mem_input_layers.append(ff_out)
+
+ x = add_ff_residual(ff_out)
 
  # taking care of cache first
  # for early return when processing long term mem tokens during inference
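Attention and feedforward now follow the same explicit branch-in / branch-out convention as the memory path: the hyper connection returns the branch input plus a closure that folds the branch output back into the residual streams, and both tensors are appended to `mem_input_layers` so later memory layers can select over them. A toy stand-in for that calling convention (the real modules come from the hyper-connections dependency; `toy_hyper_conn` below is hypothetical):

    # Hypothetical stand-in illustrating the branch-in / branch-out pattern used above.
    import torch
    from torch import nn

    def toy_hyper_conn(x):
        # hand back the branch input and a closure that adds the branch output to the residual
        def add_residual(branch_out):
            return x + branch_out
        return x, add_residual

    ff = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8))

    x = torch.randn(2, 16, 8)
    mem_input_layers = []

    ff_in, add_ff_residual = toy_hyper_conn(x)
    mem_input_layers.append(ff_in)   # pre-branch state, later selectable by neural memory
    ff_out = ff(ff_in)
    mem_input_layers.append(ff_out)  # post-branch state, later selectable by neural memory
    x = add_ff_residual(ff_out)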
@@ -353,11 +353,11 @@ class NeuralMemory(Module):
  pred = functional_call(self.memory_model, params, inputs)
  loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
  weighted_loss = loss * loss_weights
- return weighted_loss.sum()
+ return weighted_loss.sum(), loss
 
  # two functions
 
- grad_fn = grad(forward_and_loss)
+ grad_fn = grad(forward_and_loss, has_aux = True)
 
  self.per_sample_grad_fn = vmap(grad_fn, in_dims = (0, 0, 0, 0))
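`has_aux = True` tells `torch.func.grad` that the wrapped function returns `(value_to_differentiate, aux)`, so the gradient call now hands back `(grads, aux)` and the unweighted per-element loss (the surprise) is surfaced without a second forward pass. A minimal standalone illustration, not the NeuralMemory internals:

    # Minimal torch.func.grad example with has_aux = True (toy shapes).
    import torch
    from torch.func import grad

    def forward_and_loss(w, x, y):
        pred = x @ w
        per_elem = (pred - y).pow(2)
        return per_elem.sum(), per_elem  # scalar to differentiate, auxiliary per-element loss

    grad_fn = grad(forward_and_loss, has_aux = True)

    w = torch.randn(4, 1)
    x = torch.randn(8, 4)
    y = torch.randn(8, 1)

    g, per_elem = grad_fn(w, x, y)  # grads w.r.t. w, plus the aux output untouched
    assert g.shape == w.shape and per_elem.shape == (8, 1)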
@@ -526,6 +526,7 @@ class NeuralMemory(Module):
  seq_index = 0,
  prev_weights = None,
  mask: Tensor | None = None,
+ return_surprises = True
  ):
  if self.qkv_receives_diff_views:
  _, batch, seq_len = seq.shape[:3]
@@ -645,10 +646,14 @@ class NeuralMemory(Module):
 
  # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
- grads = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)
+ grads, unweighted_mem_model_loss = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)
 
  grads = TensorDict(grads)
 
+ # surprises
+
+ unweighted_mem_model_loss = rearrange(unweighted_mem_model_loss, '(b h n) c -> b h (n c)', b = batch, h = heads)
+
  # maybe softclamp grad norm
 
  if exists(self.max_grad_norm):
@@ -687,7 +692,10 @@ class NeuralMemory(Module):
 
  output = (updates, next_store_state)
 
- return output
+ if not return_surprises:
+ return output
+
+ return (*output, unweighted_mem_model_loss)
 
  # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
 
@@ -744,7 +752,10 @@ class NeuralMemory(Module):
 
  # return updates to neural memory at all chunked timesteps + neural mem cache / state to be fed back
 
- return updates, next_store_state
+ if not return_surprises:
+ return updates, next_store_state
+
+ return updates, next_store_state, unweighted_mem_model_loss
 
  def retrieve_memories(
  self,
@@ -843,7 +854,8 @@ class NeuralMemory(Module):
  store_seq = None,
  state: NeuralMemState | None = None,
  prev_weights = None,
- store_mask: Tensor | None = None
+ store_mask: Tensor | None = None,
+ return_surprises = False
  ):
  is_multi_input = self.qkv_receives_diff_views
 
@@ -927,6 +939,7 @@ class NeuralMemory(Module):
 
  # whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch
 
+ surprises = None
  gate = None
 
  if exists(self.transition_gate):
@@ -937,13 +950,14 @@ class NeuralMemory(Module):
 
  # store
 
- next_updates, next_neural_mem_state = self.store_memories(
+ next_updates, next_neural_mem_state, chunk_surprises = self.store_memories(
  store_seq_chunk,
  weights,
  seq_index = seq_index,
  past_state = past_state,
  prev_weights = prev_weights,
- mask = maybe_store_mask
+ mask = maybe_store_mask,
+ return_surprises = True
  )
 
  weights = next_neural_mem_state.weights
@@ -952,6 +966,8 @@ class NeuralMemory(Module):
 
  updates = accum_updates(updates, next_updates)
 
+ surprises = safe_cat((surprises, chunk_surprises), dim = -1)
+
  if is_last and not update_after_final_store:
  continue
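`safe_cat` is the library's helper for accumulating the per-chunk surprises; judging from the `surprises = None` initialization above, it presumably concatenates along the given dim while ignoring `None` entries. A sketch of that assumed behavior, not copied from the package:

    # Assumed behavior of safe_cat: concatenate, dropping None entries
    # so the first chunk passes straight through.
    import torch

    def safe_cat(tensors, dim = -1):
        tensors = [t for t in tensors if t is not None]
        if not tensors:
            return None
        if len(tensors) == 1:
            return tensors[0]
        return torch.cat(tensors, dim = dim)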
@@ -986,4 +1002,9 @@ class NeuralMemory(Module):
  updates
  )
 
- return retrieved, next_neural_mem_state
+ # returning
+
+ if not return_surprises:
+ return retrieved, next_neural_mem_state
+
+ return retrieved, next_neural_mem_state, surprises
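Net effect on the public API: `NeuralMemory.forward` still returns `(retrieved, state)` by default, and returns a third surprises tensor when `return_surprises = True`. A hedged usage sketch; the constructor arguments follow the project README, and all sizes are illustrative:

    # Usage sketch for the new flag; dim/chunk_size/shapes are illustrative.
    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 384, chunk_size = 64)
    seq = torch.randn(2, 1024, 384)

    retrieved, state = mem(seq)                                      # unchanged default return
    retrieved, state, surprises = mem(seq, return_surprises = True)  # plus unweighted memory-model loss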
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.3.25
+ Version: 0.4.1
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,9 @@
+ titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
+ titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
+ titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
+ titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
+ titans_pytorch/neural_memory.py,sha256=io5fvLWpOTzx8mkDA9sg3Mkc7-aeugUJoDCniryiuYE,32666
+ titans_pytorch-0.4.1.dist-info/METADATA,sha256=XwduHOXOJvjaWJhdYUq-1jhVq2zNKJBwMH1VWopxv5Y,6816
+ titans_pytorch-0.4.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ titans_pytorch-0.4.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+ titans_pytorch-0.4.1.dist-info/RECORD,,
@@ -1,9 +0,0 @@
- titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
- titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
- titans_pytorch/mac_transformer.py,sha256=grD327B3OCIy7d23jNUWIoUo1bIgXUqD26dXWCjdi28,25565
- titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
- titans_pytorch/neural_memory.py,sha256=uh5NbtAAzfPeZPFe7uhgnpUF6qyP0zjP0eXPIgY5pfc,31929
- titans_pytorch-0.3.25.dist-info/METADATA,sha256=SZwazbaNFe1GstoF45zI_aNMpzgXAqv4mOh78gMN5-U,6817
- titans_pytorch-0.3.25.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- titans_pytorch-0.3.25.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
- titans_pytorch-0.3.25.dist-info/RECORD,,