titans-pytorch 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
- titans_pytorch/neural_memory.py +3 -2
- {titans_pytorch-0.4.1.dist-info → titans_pytorch-0.4.2.dist-info}/METADATA +1 -1
- {titans_pytorch-0.4.1.dist-info → titans_pytorch-0.4.2.dist-info}/RECORD +5 -5
- {titans_pytorch-0.4.1.dist-info → titans_pytorch-0.4.2.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.4.1.dist-info → titans_pytorch-0.4.2.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/neural_memory.py
CHANGED
@@ -652,6 +652,7 @@ class NeuralMemory(Module):
|
|
652
652
|
|
653
653
|
# surprises
|
654
654
|
|
655
|
+
adaptive_lr = rearrange(adaptive_lr, '(b h n) c -> b h (n c)', b = batch, h = heads)
|
655
656
|
unweighted_mem_model_loss = rearrange(unweighted_mem_model_loss, '(b h n) c -> b h (n c)', b = batch, h = heads)
|
656
657
|
|
657
658
|
# maybe softclamp grad norm
|
@@ -695,7 +696,7 @@ class NeuralMemory(Module):
|
|
695
696
|
if not return_surprises:
|
696
697
|
return output
|
697
698
|
|
698
|
-
return (*output, unweighted_mem_model_loss)
|
699
|
+
return (*output, (unweighted_mem_model_loss, adaptive_lr))
|
699
700
|
|
700
701
|
# momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
|
701
702
|
|
@@ -755,7 +756,7 @@ class NeuralMemory(Module):
|
|
755
756
|
if not return_surprises:
|
756
757
|
return updates, next_store_state
|
757
758
|
|
758
|
-
return updates, next_store_state, unweighted_mem_model_loss
|
759
|
+
return updates, next_store_state, (unweighted_mem_model_loss, adaptive_lr)
|
759
760
|
|
760
761
|
def retrieve_memories(
|
761
762
|
self,
|
@@ -2,8 +2,8 @@ titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,29
|
|
2
2
|
titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
|
3
3
|
titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
|
4
4
|
titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
|
5
|
-
titans_pytorch/neural_memory.py,sha256=
|
6
|
-
titans_pytorch-0.4.
|
7
|
-
titans_pytorch-0.4.
|
8
|
-
titans_pytorch-0.4.
|
9
|
-
titans_pytorch-0.4.
|
5
|
+
titans_pytorch/neural_memory.py,sha256=D7jzi2SjcVj89F3Ws-zyOp04mCg5sJuUFXC6GPRdiz8,32789
|
6
|
+
titans_pytorch-0.4.2.dist-info/METADATA,sha256=HNJZM3kvMlnRLVN9i4hLecWSL93q0Fg7nqq8xz-BT2o,6816
|
7
|
+
titans_pytorch-0.4.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.4.2.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.4.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|