titans-pytorch 0.4.5__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/neural_memory.py +3 -2
- {titans_pytorch-0.4.5.dist-info → titans_pytorch-0.4.6.dist-info}/METADATA +1 -1
- {titans_pytorch-0.4.5.dist-info → titans_pytorch-0.4.6.dist-info}/RECORD +5 -5
- {titans_pytorch-0.4.5.dist-info → titans_pytorch-0.4.6.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.4.5.dist-info → titans_pytorch-0.4.6.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/neural_memory.py
CHANGED
@@ -867,7 +867,8 @@ class NeuralMemory(Module):
|
|
867
867
|
detach_mem_state = False,
|
868
868
|
prev_weights = None,
|
869
869
|
store_mask: Tensor | None = None,
|
870
|
-
return_surprises = False
|
870
|
+
return_surprises = False,
|
871
|
+
ttt_batch_size: int | None = None
|
871
872
|
):
|
872
873
|
is_multi_input = self.qkv_receives_diff_views
|
873
874
|
|
@@ -904,7 +905,7 @@ class NeuralMemory(Module):
|
|
904
905
|
# compute split sizes of sequence
|
905
906
|
# for now manually update weights to last update at the correct boundaries
|
906
907
|
|
907
|
-
store_seq_len, chunk_size, batch_size = store_seq.shape[-2], self.chunk_size, self.batch_size
|
908
|
+
store_seq_len, chunk_size, batch_size = store_seq.shape[-2], self.chunk_size, default(ttt_batch_size, self.batch_size)
|
908
909
|
|
909
910
|
need_update_weights = exists(batch_size)
|
910
911
|
|
@@ -2,8 +2,8 @@ titans_pytorch/__init__.py,sha256=sVTOuRUkaIYabFExdLY6s1qXm1UwHHz_J19H8ZV-X74,33
|
|
2
2
|
titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
|
3
3
|
titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
|
4
4
|
titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
|
5
|
-
titans_pytorch/neural_memory.py,sha256=
|
6
|
-
titans_pytorch-0.4.
|
7
|
-
titans_pytorch-0.4.
|
8
|
-
titans_pytorch-0.4.
|
9
|
-
titans_pytorch-0.4.
|
5
|
+
titans_pytorch/neural_memory.py,sha256=CsB-Wd3T_N_HluRLjjwB2gPdlLymJmmhuxiRJhQuXkA,33377
|
6
|
+
titans_pytorch-0.4.6.dist-info/METADATA,sha256=aWhQMQrjBLzUPmvtH-LY47r4ayz_ts70gGyDMMyJ6Sc,6810
|
7
|
+
titans_pytorch-0.4.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.4.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.4.6.dist-info/RECORD,,
|
File without changes
|
File without changes
|