titans-pytorch 0.4.5__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -867,7 +867,8 @@ class NeuralMemory(Module):
867
867
  detach_mem_state = False,
868
868
  prev_weights = None,
869
869
  store_mask: Tensor | None = None,
870
- return_surprises = False
870
+ return_surprises = False,
871
+ ttt_batch_size: int | None = None
871
872
  ):
872
873
  is_multi_input = self.qkv_receives_diff_views
873
874
 
@@ -904,7 +905,7 @@ class NeuralMemory(Module):
904
905
  # compute split sizes of sequence
905
906
  # for now manually update weights to last update at the correct boundaries
906
907
 
907
- store_seq_len, chunk_size, batch_size = store_seq.shape[-2], self.chunk_size, self.batch_size
908
+ store_seq_len, chunk_size, batch_size = store_seq.shape[-2], self.chunk_size, default(ttt_batch_size, self.batch_size)
908
909
 
909
910
  need_update_weights = exists(batch_size)
910
911
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.4.5
3
+ Version: 0.4.6
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -2,8 +2,8 @@ titans_pytorch/__init__.py,sha256=sVTOuRUkaIYabFExdLY6s1qXm1UwHHz_J19H8ZV-X74,33
2
2
  titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
3
3
  titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
4
4
  titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
5
- titans_pytorch/neural_memory.py,sha256=N5SzitdjxA6mkPki_xacrzdsHHaFpU8Dy5JNDRyrFtk,33309
6
- titans_pytorch-0.4.5.dist-info/METADATA,sha256=Hju2qlfW5Su3-VdIAyy0pjhmdJFXDg4adcpR0D5UTQ8,6810
7
- titans_pytorch-0.4.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.4.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.4.5.dist-info/RECORD,,
5
+ titans_pytorch/neural_memory.py,sha256=CsB-Wd3T_N_HluRLjjwB2gPdlLymJmmhuxiRJhQuXkA,33377
6
+ titans_pytorch-0.4.6.dist-info/METADATA,sha256=aWhQMQrjBLzUPmvtH-LY47r4ayz_ts70gGyDMMyJ6Sc,6810
7
+ titans_pytorch-0.4.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.4.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.4.6.dist-info/RECORD,,