titans-pytorch 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -652,13 +652,14 @@ class NeuralMemory(Module):
652
652
  next_last_update = TensorDict()
653
653
  next_last_momentum = TensorDict()
654
654
 
655
- for (param_name, surprise), (_, last_update), (_, last_momentum) in zip(surprises.items(), past_last_update.items(), past_last_momentum.items()):
655
+ for (param_name, surprise), (_, last_update) in zip(surprises.items(), past_last_update.items()):
656
656
 
657
657
  update = surprise
658
658
 
659
659
  # derive momentum with associative scan - eq (10)
660
660
 
661
661
  if has_momentum:
662
+ last_momentum = past_last_momentum[param_name]
662
663
  update = self.assoc_scan(adaptive_momentum, surprise, prev = last_momentum) # momentum is S / surprise in the paper
663
664
  momentum = update
664
665
  next_last_momentum[param_name] = momentum[:, -1]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.3.0
3
+ Version: 0.3.1
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -2,8 +2,8 @@ titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,27
2
2
  titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
3
  titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
4
4
  titans_pytorch/memory_models.py,sha256=Q9SAIyAbStF5Tz0EhvRbn3yAdE3nk3xKc1ndieIe714,4671
5
- titans_pytorch/neural_memory.py,sha256=NICWfFySGeduH4-Wgn6UavWl18pddIL1JrtpT_dhXDw,27689
6
- titans_pytorch-0.3.0.dist-info/METADATA,sha256=i-tiJAahxkvzYxHOWLiwP2z8DyEPq6mFgSQ_ThvpN9A,6815
7
- titans_pytorch-0.3.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.3.0.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.3.0.dist-info/RECORD,,
5
+ titans_pytorch/neural_memory.py,sha256=K1z7wtv366Y6-eEyXMFZ_j7D2frPl5RxfSgxzFYoFMc,27704
6
+ titans_pytorch-0.3.1.dist-info/METADATA,sha256=ZAxucKq2DZBtW-BI_O2sUQ5RXy11a7eu48yPpwnanpw,6815
7
+ titans_pytorch-0.3.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.3.1.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.3.1.dist-info/RECORD,,