titans-pytorch 0.4.0.tar.gz → 0.4.2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: titans-pytorch
- Version: 0.4.0
+ Version: 0.4.2
  Summary: Titans
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
  [project]
  name = "titans-pytorch"
- version = "0.4.0"
+ version = "0.4.2"
  description = "Titans"
  authors = [
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -74,6 +74,21 @@ def test_titans(

  assert seq.shape == retrieved.shape

+ def test_return_surprises():
+
+ mem = NeuralMemory(
+ dim = 384,
+ chunk_size = 2,
+ dim_head = 64,
+ heads = 4,
+ )
+
+ seq = torch.randn(4, 64, 384)
+
+ _, _, (surprises, adaptive_lr) = mem(seq, return_surprises = True)
+
+ assert all([t.shape == (4, 4, 64) for t in (surprises, adaptive_lr)])
+
  @pytest.mark.parametrize('learned_momentum_combine', (False, True))
  @pytest.mark.parametrize('learned_combine_include_zeroth', (False, True))
  def test_titans_second_order_momentum(
@@ -353,11 +353,11 @@ class NeuralMemory(Module):
  pred = functional_call(self.memory_model, params, inputs)
  loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
  weighted_loss = loss * loss_weights
- return weighted_loss.sum()
+ return weighted_loss.sum(), loss

  # two functions

- grad_fn = grad(forward_and_loss)
+ grad_fn = grad(forward_and_loss, has_aux = True)

  self.per_sample_grad_fn = vmap(grad_fn, in_dims = (0, 0, 0, 0))

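The pairing of `return weighted_loss.sum(), loss` with `grad(..., has_aux = True)` is the standard `torch.func` idiom for carrying an auxiliary value (here the unweighted per-element loss) out of a differentiated function. A minimal standalone sketch of that idiom follows, with a toy linear memory and made-up shapes rather than the repo's actual `forward_and_loss`:

```python
# Minimal sketch of torch.func.grad(..., has_aux = True): the wrapped function must
# return (scalar_to_differentiate, aux); grad then returns (grads, aux) unchanged.
# The toy linear "memory" and all shapes below are illustrative only.
import torch
from torch.func import grad, vmap

def forward_and_loss(weights, inputs, target):
    pred = inputs @ weights
    loss = (pred - target).pow(2)      # per-element loss, the auxiliary value we want back
    return loss.sum(), loss            # differentiate the sum, also surface the raw loss

grad_fn = grad(forward_and_loss, has_aux = True)

# per-sample gradients, mirroring the vmap(grad_fn, in_dims = ...) pattern in the diff
per_sample_grad_fn = vmap(grad_fn, in_dims = (0, 0, 0))

weights = torch.randn(8, 16, 4)
inputs = torch.randn(8, 32, 16)
target = torch.randn(8, 32, 4)

grads, unweighted_loss = per_sample_grad_fn(weights, inputs, target)
assert grads.shape == weights.shape          # one gradient per sample
assert unweighted_loss.shape == (8, 32, 4)   # the aux value comes back batched as well
```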
@@ -526,6 +526,7 @@ class NeuralMemory(Module):
  seq_index = 0,
  prev_weights = None,
  mask: Tensor | None = None,
+ return_surprises = True
  ):
  if self.qkv_receives_diff_views:
  _, batch, seq_len = seq.shape[:3]
@@ -645,10 +646,15 @@ class NeuralMemory(Module):

  # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

- grads = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)
+ grads, unweighted_mem_model_loss = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)

  grads = TensorDict(grads)

+ # surprises
+
+ adaptive_lr = rearrange(adaptive_lr, '(b h n) c -> b h (n c)', b = batch, h = heads)
+ unweighted_mem_model_loss = rearrange(unweighted_mem_model_loss, '(b h n) c -> b h (n c)', b = batch, h = heads)
+
  # maybe softclamp grad norm

  if exists(self.max_grad_norm):
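The `(b h n) c -> b h (n c)` pattern unflattens the per-chunk values, which the per-sample grad call sees as one flat row per (batch, head, chunk), back into a per-head sequence. A shape-only sketch with `einops.rearrange`, borrowing the sizes implied by the new test (batch 4, heads 4, 32 chunks of size 2):

```python
# Shape sketch of the '(b h n) c -> b h (n c)' rearrange. The concrete sizes are
# taken from the new test's config (batch 4, heads 4, seq 64, chunk_size 2).
import torch
from einops import rearrange

b, h, n, c = 4, 4, 32, 2                       # batch, heads, chunks, chunk size

flat = torch.randn(b * h * n, c)               # one row per (batch, head, chunk)
per_head = rearrange(flat, '(b h n) c -> b h (n c)', b = b, h = h)

assert per_head.shape == (4, 4, 64)            # matches the test's expected shape
```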
@@ -687,7 +693,10 @@ class NeuralMemory(Module):

  output = (updates, next_store_state)

- return output
+ if not return_surprises:
+ return output
+
+ return (*output, (unweighted_mem_model_loss, adaptive_lr))

  # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

@@ -744,7 +753,10 @@ class NeuralMemory(Module):

  # return updates to neural memory at all chunked timesteps + neural mem cache / state to be fed back

- return updates, next_store_state
+ if not return_surprises:
+ return updates, next_store_state
+
+ return updates, next_store_state, (unweighted_mem_model_loss, adaptive_lr)

  def retrieve_memories(
  self,
@@ -843,7 +855,8 @@ class NeuralMemory(Module):
  store_seq = None,
  state: NeuralMemState | None = None,
  prev_weights = None,
- store_mask: Tensor | None = None
+ store_mask: Tensor | None = None,
+ return_surprises = False
  ):
  is_multi_input = self.qkv_receives_diff_views

@@ -927,6 +940,7 @@ class NeuralMemory(Module):

  # whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch

+ surprises = None
  gate = None

  if exists(self.transition_gate):
@@ -937,13 +951,14 @@ class NeuralMemory(Module):

  # store

- next_updates, next_neural_mem_state = self.store_memories(
+ next_updates, next_neural_mem_state, chunk_surprises = self.store_memories(
  store_seq_chunk,
  weights,
  seq_index = seq_index,
  past_state = past_state,
  prev_weights = prev_weights,
- mask = maybe_store_mask
+ mask = maybe_store_mask,
+ return_surprises = True
  )

  weights = next_neural_mem_state.weights
@@ -952,6 +967,8 @@ class NeuralMemory(Module):

  updates = accum_updates(updates, next_updates)

+ surprises = safe_cat((surprises, chunk_surprises), dim = -1)
+
  if is_last and not update_after_final_store:
  continue

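`surprises` starts as `None` and is extended once per store chunk, so `safe_cat` is being used as a cat-with-optional-accumulator. Below is a hypothetical stand-in for that idiom; it is not the repo's `safe_cat` (which may, for instance, also handle tuples of tensors), just the common pattern of skipping `None` before `torch.cat`:

```python
# Hypothetical illustration of the accumulate-by-concatenation pattern used for
# `surprises` above. NOT the repo's safe_cat: skip None entries, then torch.cat
# the rest along the last (sequence) dim.
import torch

def cat_skipping_none(tensors, dim = -1):
    tensors = [t for t in tensors if t is not None]
    if len(tensors) == 0:
        return None
    return torch.cat(tensors, dim = dim)

acc = None                                   # accumulator before the first store chunk
chunk1 = torch.randn(4, 4, 32)
chunk2 = torch.randn(4, 4, 32)

acc = cat_skipping_none((acc, chunk1))       # first chunk just passes through
acc = cat_skipping_none((acc, chunk2))       # later chunks extend the sequence dim
assert acc.shape == (4, 4, 64)
```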
@@ -986,4 +1003,9 @@ class NeuralMemory(Module):
  updates
  )

- return retrieved, next_neural_mem_state
+ # returning
+
+ if not return_surprises:
+ return retrieved, next_neural_mem_state
+
+ return retrieved, next_neural_mem_state, surprises
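At the public `forward`, the new keyword defaults to `False`, so existing callers keep getting the two-tuple. A usage sketch mirroring the new `test_return_surprises` test above, assuming the package's usual `from titans_pytorch import NeuralMemory` import:

```python
# Usage sketch mirroring the new test: the third return value is the
# (unweighted memory-model loss, adaptive lr) pair, each shaped (batch, heads, seq).
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 2,
    dim_head = 64,
    heads = 4,
)

seq = torch.randn(4, 64, 384)

retrieved, state = mem(seq)                   # default: unchanged two-tuple
retrieved, state, (surprises, adaptive_lr) = mem(seq, return_surprises = True)

assert surprises.shape == adaptive_lr.shape == (4, 4, 64)
```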