titans-pytorch 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- titans_pytorch/neural_memory.py +5 -4
- {titans_pytorch-0.4.1.dist-info → titans_pytorch-0.4.3.dist-info}/METADATA +1 -1
- {titans_pytorch-0.4.1.dist-info → titans_pytorch-0.4.3.dist-info}/RECORD +5 -5
- {titans_pytorch-0.4.1.dist-info → titans_pytorch-0.4.3.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.4.1.dist-info → titans_pytorch-0.4.3.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/neural_memory.py
CHANGED
@@ -652,6 +652,7 @@ class NeuralMemory(Module):
|
|
652
652
|
|
653
653
|
# surprises
|
654
654
|
|
655
|
+
adaptive_lr = rearrange(adaptive_lr, '(b h n) c -> b h (n c)', b = batch, h = heads)
|
655
656
|
unweighted_mem_model_loss = rearrange(unweighted_mem_model_loss, '(b h n) c -> b h (n c)', b = batch, h = heads)
|
656
657
|
|
657
658
|
# maybe softclamp grad norm
|
@@ -695,7 +696,7 @@ class NeuralMemory(Module):
|
|
695
696
|
if not return_surprises:
|
696
697
|
return output
|
697
698
|
|
698
|
-
return (*output, unweighted_mem_model_loss)
|
699
|
+
return (*output, (unweighted_mem_model_loss, adaptive_lr))
|
699
700
|
|
700
701
|
# momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
|
701
702
|
|
@@ -755,7 +756,7 @@ class NeuralMemory(Module):
|
|
755
756
|
if not return_surprises:
|
756
757
|
return updates, next_store_state
|
757
758
|
|
758
|
-
return updates, next_store_state, unweighted_mem_model_loss
|
759
|
+
return updates, next_store_state, (unweighted_mem_model_loss, adaptive_lr)
|
759
760
|
|
760
761
|
def retrieve_memories(
|
761
762
|
self,
|
@@ -939,7 +940,7 @@ class NeuralMemory(Module):
|
|
939
940
|
|
940
941
|
# whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch
|
941
942
|
|
942
|
-
surprises = None
|
943
|
+
surprises = (None, None)
|
943
944
|
gate = None
|
944
945
|
|
945
946
|
if exists(self.transition_gate):
|
@@ -966,7 +967,7 @@ class NeuralMemory(Module):
|
|
966
967
|
|
967
968
|
updates = accum_updates(updates, next_updates)
|
968
969
|
|
969
|
-
surprises = safe_cat(
|
970
|
+
surprises = tuple(safe_cat(args, dim = -1) for args in zip(surprises, chunk_surprises))
|
970
971
|
|
971
972
|
if is_last and not update_after_final_store:
|
972
973
|
continue
|
@@ -2,8 +2,8 @@ titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,29
|
|
2
2
|
titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
|
3
3
|
titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
|
4
4
|
titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
|
5
|
-
titans_pytorch/neural_memory.py,sha256=
|
6
|
-
titans_pytorch-0.4.
|
7
|
-
titans_pytorch-0.4.
|
8
|
-
titans_pytorch-0.4.
|
9
|
-
titans_pytorch-0.4.
|
5
|
+
titans_pytorch/neural_memory.py,sha256=HdBaRGURJ84Qy-a6PdfeQoc5ZzY7H0c5YHUASaSVu1A,32824
|
6
|
+
titans_pytorch-0.4.3.dist-info/METADATA,sha256=SIq5KS2xehsUAwuFpRSFNdnLbgamWUMLN5xj4MJGRe0,6816
|
7
|
+
titans_pytorch-0.4.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.4.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.4.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|