titans-pytorch 0.4.0-py3-none-any.whl → 0.4.1-py3-none-any.whl

This diff compares the published contents of titans-pytorch 0.4.0 and 0.4.1 as released on PyPI. The change is small and focused: NeuralMemory now surfaces a per-head "surprise" signal (the unweighted memory-model loss of eq (12) in the Titans paper), threading it from the inner per-sample gradient step out through store_memories and the public forward.
--- a/titans_pytorch/neural_memory.py
+++ b/titans_pytorch/neural_memory.py
@@ -353,11 +353,11 @@ class NeuralMemory(Module):
             pred = functional_call(self.memory_model, params, inputs)
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
             weighted_loss = loss * loss_weights
-            return weighted_loss.sum()
+            return weighted_loss.sum(), loss

         # two functions

-        grad_fn = grad(forward_and_loss)
+        grad_fn = grad(forward_and_loss, has_aux = True)

         self.per_sample_grad_fn = vmap(grad_fn, in_dims = (0, 0, 0, 0))

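This first hunk is the core mechanism. torch.func.grad(fn, has_aux = True) expects fn to return (scalar_to_differentiate, aux) and returns (grads, aux), so the unweighted per-position loss rides along with the per-sample gradients, with no second forward pass. A minimal, self-contained sketch of the pattern, with a toy linear memory standing in for the package's memory_model (all names below are illustrative, not package code):

    import torch
    from torch.func import grad, vmap

    def forward_and_loss(params, key, target):
        # toy stand-in for functional_call(self.memory_model, params, inputs)
        pred = key @ params['w']
        loss = (pred - target).pow(2).mean(dim = -1)  # unweighted, per position
        # first output is differentiated; second is passed through as aux
        return loss.sum(), loss

    grad_fn = grad(forward_and_loss, has_aux = True)

    # vmap over a leading per-sample dim batches both the grads and the aux loss
    per_sample_grad_fn = vmap(grad_fn, in_dims = (0, 0, 0))

    params = dict(w = torch.randn(4, 2, 3))                    # 4 parameter sets
    keys, targets = torch.randn(4, 5, 2), torch.randn(4, 5, 3)

    grads, unweighted_loss = per_sample_grad_fn(params, keys, targets)
    print(grads['w'].shape, unweighted_loss.shape)  # (4, 2, 3) and (4, 5)
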
@@ -526,6 +526,7 @@ class NeuralMemory(Module):
         seq_index = 0,
         prev_weights = None,
         mask: Tensor | None = None,
+        return_surprises = True
     ):
         if self.qkv_receives_diff_views:
             _, batch, seq_len = seq.shape[:3]
@@ -645,10 +646,14 @@ class NeuralMemory(Module):

         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

-        grads = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)
+        grads, unweighted_mem_model_loss = self.per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)

         grads = TensorDict(grads)

+        # surprises
+
+        unweighted_mem_model_loss = rearrange(unweighted_mem_model_loss, '(b h n) c -> b h (n c)', b = batch, h = heads)
+
         # maybe softclamp grad norm

         if exists(self.max_grad_norm):
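
per_sample_grad_fn is vmapped over a leading dimension into which batch, heads, and chunks have all been folded, so the auxiliary loss comes back as ((b·h·n), c), one row per chunk. The rearrange above regroups it into (batch, heads, seq), putting each head's surprise along the sequence. A quick shape check with toy sizes, assuming the '(b h n) c' packing shown in the hunk:

    import torch
    from einops import rearrange

    batch, heads, num_chunks, chunk_size = 2, 4, 3, 8

    # one loss row per (batch, head, chunk) sample, as returned by the vmapped grad fn
    unweighted_mem_model_loss = torch.randn(batch * heads * num_chunks, chunk_size)

    surprises = rearrange(unweighted_mem_model_loss, '(b h n) c -> b h (n c)', b = batch, h = heads)

    print(surprises.shape)  # torch.Size([2, 4, 24]), i.e. (batch, heads, seq_len)
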
@@ -687,7 +692,10 @@ class NeuralMemory(Module):

             output = (updates, next_store_state)

-            return output
+            if not return_surprises:
+                return output
+
+            return (*output, unweighted_mem_model_loss)

         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates

@@ -744,7 +752,10 @@ class NeuralMemory(Module):

         # return updates to neural memory at all chunked timesteps + neural mem cache / state to be fed back

-        return updates, next_store_state
+        if not return_surprises:
+            return updates, next_store_state
+
+        return updates, next_store_state, unweighted_mem_model_loss

     def retrieve_memories(
         self,
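
Note the asymmetric defaults: store_memories takes return_surprises = True (its one internal caller, forward, always wants the aux value), while forward below defaults to False so the public API is unchanged. Both return sites use the same opt-in tuple pattern; a toy sketch:

    def compute(x, return_aux = False):
        out, aux = x * 2, x * x

        if not return_aux:
            return out      # legacy call sites unpack exactly as before

        return out, aux     # new call sites opt in explicitly

    y = compute(3)
    y, extra = compute(3, return_aux = True)
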
@@ -843,7 +854,8 @@ class NeuralMemory(Module):
         store_seq = None,
         state: NeuralMemState | None = None,
         prev_weights = None,
-        store_mask: Tensor | None = None
+        store_mask: Tensor | None = None,
+        return_surprises = False
     ):
         is_multi_input = self.qkv_receives_diff_views

@@ -927,6 +939,7 @@ class NeuralMemory(Module):

         # whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch

+        surprises = None
         gate = None

         if exists(self.transition_gate):
@@ -937,13 +950,14 @@ class NeuralMemory(Module):

             # store

-            next_updates, next_neural_mem_state = self.store_memories(
+            next_updates, next_neural_mem_state, chunk_surprises = self.store_memories(
                 store_seq_chunk,
                 weights,
                 seq_index = seq_index,
                 past_state = past_state,
                 prev_weights = prev_weights,
-                mask = maybe_store_mask
+                mask = maybe_store_mask,
+                return_surprises = True
             )

             weights = next_neural_mem_state.weights
@@ -952,6 +966,8 @@ class NeuralMemory(Module):

             updates = accum_updates(updates, next_updates)

+            surprises = safe_cat((surprises, chunk_surprises), dim = -1)
+
             if is_last and not update_after_final_store:
                 continue

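safe_cat itself is not part of this diff. Its usage here (surprises starts as None, then each chunk's surprise is concatenated along the last, sequence dim) implies a None-tolerant concat; a minimal sketch consistent with that usage, assuming the helper simply drops None entries:

    import torch

    def safe_cat(tensors, dim = -1):
        # concatenate, silently dropping None placeholders
        tensors = [t for t in tensors if t is not None]

        if len(tensors) == 0:
            return None

        if len(tensors) == 1:
            return tensors[0]

        return torch.cat(tensors, dim = dim)

    surprises = None

    for _ in range(3):
        chunk_surprises = torch.randn(2, 4, 8)  # (batch, heads, chunk_len)
        surprises = safe_cat((surprises, chunk_surprises), dim = -1)

    print(surprises.shape)  # torch.Size([2, 4, 24])
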
@@ -986,4 +1002,9 @@ class NeuralMemory(Module):
             updates
         )

-        return retrieved, next_neural_mem_state
+        # returning
+
+        if not return_surprises:
+            return retrieved, next_neural_mem_state
+
+        return retrieved, next_neural_mem_state, surprises
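
End to end, 0.4.1 lets callers opt into the surprise trace from the public forward. A usage sketch following the repository README (the constructor arguments come from the README's example; the surprises shape is inferred from the rearrange above, so treat it as an assumption):

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 384, chunk_size = 64)

    seq = torch.randn(2, 1024, 384)

    retrieved, state = mem(seq)                                      # 0.4.0-style call, unchanged
    retrieved, state, surprises = mem(seq, return_surprises = True)  # new opt-in

    # surprises: unweighted memory-model loss per head along the sequence
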
--- a/titans_pytorch-0.4.0.dist-info/METADATA
+++ b/titans_pytorch-0.4.1.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.4.0
+Version: 0.4.1
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
--- a/titans_pytorch-0.4.0.dist-info/RECORD
+++ b/titans_pytorch-0.4.1.dist-info/RECORD
@@ -2,8 +2,8 @@ titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,29
 titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
 titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
 titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
-titans_pytorch/neural_memory.py,sha256=uh5NbtAAzfPeZPFe7uhgnpUF6qyP0zjP0eXPIgY5pfc,31929
-titans_pytorch-0.4.0.dist-info/METADATA,sha256=uOklaPv-y-eSpgnvrgVZ-ZL4TpeBg7r_EJxwJbdKyO0,6816
-titans_pytorch-0.4.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.4.0.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.4.0.dist-info/RECORD,,
+titans_pytorch/neural_memory.py,sha256=io5fvLWpOTzx8mkDA9sg3Mkc7-aeugUJoDCniryiuYE,32666
+titans_pytorch-0.4.1.dist-info/METADATA,sha256=XwduHOXOJvjaWJhdYUq-1jhVq2zNKJBwMH1VWopxv5Y,6816
+titans_pytorch-0.4.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.4.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.4.1.dist-info/RECORD,,
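
The METADATA and RECORD hunks are the mechanical part of the release: the version bump, plus a new digest and size for neural_memory.py. RECORD digests are sha256 in urlsafe base64 with the padding stripped (per the wheel spec); an entry can be reproduced from an unzipped wheel like so:

    import base64, hashlib
    from pathlib import Path

    path = 'titans_pytorch/neural_memory.py'
    data = Path(path).read_bytes()

    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=')
    print(f'{path},sha256={digest.decode()},{len(data)}')  # should match the RECORD line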