titans-pytorch 0.4.1__tar.gz → 0.4.3__tar.gz

This diff shows the changes between publicly released versions of the package, as published to their respective public registries, and is provided for informational purposes only.
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.4.1
+Version: 0.4.3
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.4.1"
+version = "0.4.3"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_titans.py

@@ -85,9 +85,9 @@ def test_return_surprises():
 
     seq = torch.randn(4, 64, 384)
 
-    _, _, surprises = mem(seq, return_surprises = True)
+    _, _, (surprises, adaptive_lr) = mem(seq, return_surprises = True)
 
-    assert surprises.shape == (4, 4, 64)
+    assert all([t.shape == (4, 4, 64) for t in (surprises, adaptive_lr)])
 
 @pytest.mark.parametrize('learned_momentum_combine', (False, True))
 @pytest.mark.parametrize('learned_combine_include_zeroth', (False, True))

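The test change above pins down the new return contract: with return_surprises = True, the third value returned by NeuralMemory is now a (surprises, adaptive_lr) pair rather than a lone surprises tensor, and both tensors share the (batch, heads, seq_len) layout asserted as (4, 4, 64). A minimal usage sketch, assuming a NeuralMemory configured roughly as in the test suite (the constructor arguments below are illustrative assumptions, not shown in this diff):

    import torch
    from titans_pytorch import NeuralMemory

    # constructor arguments are assumed for illustration only
    mem = NeuralMemory(dim = 384, chunk_size = 64, heads = 4)

    seq = torch.randn(4, 64, 384)

    # pre-0.4.3:  _, _, surprises = mem(seq, return_surprises = True)
    # from 0.4.3: the third element is a (surprises, adaptive_lr) pair
    # (names for the first two returned values are illustrative)
    retrieved, state, (surprises, adaptive_lr) = mem(seq, return_surprises = True)

    # both carry one value per head per token: (batch, heads, seq_len),
    # i.e. (4, 4, 64) per the updated test
    print(surprises.shape, adaptive_lr.shape)
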
titans_pytorch/neural_memory.py

@@ -652,6 +652,7 @@ class NeuralMemory(Module):
 
         # surprises
 
+        adaptive_lr = rearrange(adaptive_lr, '(b h n) c -> b h (n c)', b = batch, h = heads)
         unweighted_mem_model_loss = rearrange(unweighted_mem_model_loss, '(b h n) c -> b h (n c)', b = batch, h = heads)
 
         # maybe softclamp grad norm

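The added line mirrors, for adaptive_lr, the reshape already applied to unweighted_mem_model_loss: rows flattened over (batch * heads * chunks) are regrouped, and the chunk and intra-chunk axes are merged into one sequence axis. A standalone einops sketch of that exact pattern, with toy sizes (the dimensions are assumptions for illustration):

    import torch
    from einops import rearrange

    b, h, n, c = 2, 4, 8, 16  # batch, heads, num chunks, chunk size (toy values)

    # per-chunk learning rates arrive flattened over (batch * heads * chunks) rows
    adaptive_lr = torch.randn(b * h * n, c)

    # split rows back into batch and heads, merge chunks into the sequence axis
    adaptive_lr = rearrange(adaptive_lr, '(b h n) c -> b h (n c)', b = b, h = h)

    print(adaptive_lr.shape)  # torch.Size([2, 4, 128])
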
@@ -695,7 +696,7 @@ class NeuralMemory(Module):
         if not return_surprises:
             return output
 
-        return (*output, unweighted_mem_model_loss)
+        return (*output, (unweighted_mem_model_loss, adaptive_lr))
 
     # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
 
@@ -755,7 +756,7 @@ class NeuralMemory(Module):
         if not return_surprises:
             return updates, next_store_state
 
-        return updates, next_store_state, unweighted_mem_model_loss
+        return updates, next_store_state, (unweighted_mem_model_loss, adaptive_lr)
 
     def retrieve_memories(
         self,
@@ -939,7 +940,7 @@ class NeuralMemory(Module):
 
         # whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch
 
-        surprises = None
+        surprises = (None, None)
         gate = None
 
         if exists(self.transition_gate):
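Seeding surprises with (None, None) instead of None lets the per-chunk loop below accumulate each element of the pair independently: on the first chunk there is nothing to concatenate yet, and safe_cat is expected to pass a lone tensor through. A minimal sketch of the assumed safe_cat semantics (the library's own helper may differ in detail):

    import torch

    def safe_cat(tensors, dim = -1):
        # assumed behavior: ignore Nones, so a fresh accumulator is a no-op
        tensors = [t for t in tensors if t is not None]

        if len(tensors) == 0:
            return None

        if len(tensors) == 1:
            return tensors[0]

        return torch.cat(tensors, dim = dim)
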
@@ -966,7 +967,7 @@ class NeuralMemory(Module):
 
             updates = accum_updates(updates, next_updates)
 
-            surprises = safe_cat((surprises, chunk_surprises), dim = -1)
+            surprises = tuple(safe_cat(args, dim = -1) for args in zip(surprises, chunk_surprises))
 
             if is_last and not update_after_final_store:
                 continue
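With surprises now a pair, the single safe_cat call becomes an element-wise one: zip lines up the running loss buffer with each chunk's loss, and the running lr buffer with each chunk's lr. A toy run of the same accumulation pattern, under the safe_cat semantics sketched above (shapes are assumptions):

    import torch

    def safe_cat(tensors, dim = -1):
        # same assumed helper as above: drop Nones before concatenating
        tensors = [t for t in tensors if t is not None]
        if not tensors:
            return None
        return tensors[0] if len(tensors) == 1 else torch.cat(tensors, dim = dim)

    surprises = (None, None)  # (unweighted_mem_model_loss, adaptive_lr)

    for _ in range(3):  # stand-in for the per-chunk store loop
        chunk_surprises = (torch.randn(2, 4, 8), torch.randn(2, 4, 8))

        # concatenate loss with loss and lr with lr along the sequence axis;
        # on the first pass the chunk tensors simply pass through
        surprises = tuple(safe_cat(args, dim = -1) for args in zip(surprises, chunk_surprises))

    print([t.shape for t in surprises])  # two tensors of shape (2, 4, 24)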