titans-pytorch 0.4.4__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/neural_memory.py +9 -2
- {titans_pytorch-0.4.4.dist-info → titans_pytorch-0.4.6.dist-info}/METADATA +1 -1
- {titans_pytorch-0.4.4.dist-info → titans_pytorch-0.4.6.dist-info}/RECORD +5 -5
- {titans_pytorch-0.4.4.dist-info → titans_pytorch-0.4.6.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.4.4.dist-info → titans_pytorch-0.4.6.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/neural_memory.py
CHANGED
@@ -864,9 +864,11 @@ class NeuralMemory(Module):
|
|
864
864
|
seq,
|
865
865
|
store_seq = None,
|
866
866
|
state: NeuralMemState | None = None,
|
867
|
+
detach_mem_state = False,
|
867
868
|
prev_weights = None,
|
868
869
|
store_mask: Tensor | None = None,
|
869
|
-
return_surprises = False
|
870
|
+
return_surprises = False,
|
871
|
+
ttt_batch_size: int | None = None
|
870
872
|
):
|
871
873
|
is_multi_input = self.qkv_receives_diff_views
|
872
874
|
|
@@ -903,7 +905,7 @@ class NeuralMemory(Module):
|
|
903
905
|
# compute split sizes of sequence
|
904
906
|
# for now manually update weights to last update at the correct boundaries
|
905
907
|
|
906
|
-
store_seq_len, chunk_size, batch_size = store_seq.shape[-2], self.chunk_size, self.batch_size
|
908
|
+
store_seq_len, chunk_size, batch_size = store_seq.shape[-2], self.chunk_size, default(ttt_batch_size, self.batch_size)
|
907
909
|
|
908
910
|
need_update_weights = exists(batch_size)
|
909
911
|
|
@@ -1013,6 +1015,11 @@ class NeuralMemory(Module):
|
|
1013
1015
|
updates
|
1014
1016
|
)
|
1015
1017
|
|
1018
|
+
# maybe detach
|
1019
|
+
|
1020
|
+
if detach_mem_state:
|
1021
|
+
next_neural_mem_state = mem_state_detach(next_neural_mem_state)
|
1022
|
+
|
1016
1023
|
# returning
|
1017
1024
|
|
1018
1025
|
if not return_surprises:
|
@@ -2,8 +2,8 @@ titans_pytorch/__init__.py,sha256=sVTOuRUkaIYabFExdLY6s1qXm1UwHHz_J19H8ZV-X74,33
|
|
2
2
|
titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
|
3
3
|
titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
|
4
4
|
titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
|
5
|
-
titans_pytorch/neural_memory.py,sha256=
|
6
|
-
titans_pytorch-0.4.
|
7
|
-
titans_pytorch-0.4.
|
8
|
-
titans_pytorch-0.4.
|
9
|
-
titans_pytorch-0.4.
|
5
|
+
titans_pytorch/neural_memory.py,sha256=CsB-Wd3T_N_HluRLjjwB2gPdlLymJmmhuxiRJhQuXkA,33377
|
6
|
+
titans_pytorch-0.4.6.dist-info/METADATA,sha256=aWhQMQrjBLzUPmvtH-LY47r4ayz_ts70gGyDMMyJ6Sc,6810
|
7
|
+
titans_pytorch-0.4.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.4.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.4.6.dist-info/RECORD,,
|
File without changes
|
File without changes
|