titans-pytorch 0.4.1__tar.gz → 0.4.2__tar.gz

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.4.1
3
+ Version: 0.4.2
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "titans-pytorch"
3
- version = "0.4.1"
3
+ version = "0.4.2"
4
4
  description = "Titans"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -85,9 +85,9 @@ def test_return_surprises():
85
85
 
86
86
  seq = torch.randn(4, 64, 384)
87
87
 
88
- _, _, surprises = mem(seq, return_surprises = True)
88
+ _, _, (surprises, adaptive_lr) = mem(seq, return_surprises = True)
89
89
 
90
- assert surprises.shape == (4, 4, 64)
90
+ assert all([t.shape == (4, 4, 64) for t in (surprises, adaptive_lr)])
91
91
 
92
92
  @pytest.mark.parametrize('learned_momentum_combine', (False, True))
93
93
  @pytest.mark.parametrize('learned_combine_include_zeroth', (False, True))
@@ -652,6 +652,7 @@ class NeuralMemory(Module):
652
652
 
653
653
  # surprises
654
654
 
655
+ adaptive_lr = rearrange(adaptive_lr, '(b h n) c -> b h (n c)', b = batch, h = heads)
655
656
  unweighted_mem_model_loss = rearrange(unweighted_mem_model_loss, '(b h n) c -> b h (n c)', b = batch, h = heads)
656
657
 
657
658
  # maybe softclamp grad norm
@@ -695,7 +696,7 @@ class NeuralMemory(Module):
695
696
  if not return_surprises:
696
697
  return output
697
698
 
698
- return (*output, unweighted_mem_model_loss)
699
+ return (*output, (unweighted_mem_model_loss, adaptive_lr))
699
700
 
700
701
  # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
701
702
 
@@ -755,7 +756,7 @@ class NeuralMemory(Module):
755
756
  if not return_surprises:
756
757
  return updates, next_store_state
757
758
 
758
- return updates, next_store_state, unweighted_mem_model_loss
759
+ return updates, next_store_state, (unweighted_mem_model_loss, adaptive_lr)
759
760
 
760
761
  def retrieve_memories(
761
762
  self,
File without changes
File without changes
File without changes
File without changes