titans-pytorch 0.4.4__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/neural_memory.py +6 -0
- {titans_pytorch-0.4.4.dist-info → titans_pytorch-0.4.5.dist-info}/METADATA +1 -1
- {titans_pytorch-0.4.4.dist-info → titans_pytorch-0.4.5.dist-info}/RECORD +5 -5
- {titans_pytorch-0.4.4.dist-info → titans_pytorch-0.4.5.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.4.4.dist-info → titans_pytorch-0.4.5.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/neural_memory.py
CHANGED
@@ -864,6 +864,7 @@ class NeuralMemory(Module):
|
|
864
864
|
seq,
|
865
865
|
store_seq = None,
|
866
866
|
state: NeuralMemState | None = None,
|
867
|
+
detach_mem_state = False,
|
867
868
|
prev_weights = None,
|
868
869
|
store_mask: Tensor | None = None,
|
869
870
|
return_surprises = False
|
@@ -1013,6 +1014,11 @@ class NeuralMemory(Module):
|
|
1013
1014
|
updates
|
1014
1015
|
)
|
1015
1016
|
|
1017
|
+
# maybe detach
|
1018
|
+
|
1019
|
+
if detach_mem_state:
|
1020
|
+
next_neural_mem_state = mem_state_detach(next_neural_mem_state)
|
1021
|
+
|
1016
1022
|
# returning
|
1017
1023
|
|
1018
1024
|
if not return_surprises:
|
@@ -2,8 +2,8 @@ titans_pytorch/__init__.py,sha256=sVTOuRUkaIYabFExdLY6s1qXm1UwHHz_J19H8ZV-X74,33
|
|
2
2
|
titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
|
3
3
|
titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
|
4
4
|
titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
|
5
|
-
titans_pytorch/neural_memory.py,sha256=
|
6
|
-
titans_pytorch-0.4.
|
7
|
-
titans_pytorch-0.4.
|
8
|
-
titans_pytorch-0.4.
|
9
|
-
titans_pytorch-0.4.
|
5
|
+
titans_pytorch/neural_memory.py,sha256=N5SzitdjxA6mkPki_xacrzdsHHaFpU8Dy5JNDRyrFtk,33309
|
6
|
+
titans_pytorch-0.4.5.dist-info/METADATA,sha256=Hju2qlfW5Su3-VdIAyy0pjhmdJFXDg4adcpR0D5UTQ8,6810
|
7
|
+
titans_pytorch-0.4.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.4.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.4.5.dist-info/RECORD,,
|
File without changes
|
File without changes
|