titans-pytorch 0.4.4__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.
@@ -864,9 +864,11 @@ class NeuralMemory(Module):
864
864
  seq,
865
865
  store_seq = None,
866
866
  state: NeuralMemState | None = None,
867
+ detach_mem_state = False,
867
868
  prev_weights = None,
868
869
  store_mask: Tensor | None = None,
869
- return_surprises = False
870
+ return_surprises = False,
871
+ ttt_batch_size: int | None = None
870
872
  ):
871
873
  is_multi_input = self.qkv_receives_diff_views
872
874
 
@@ -903,7 +905,7 @@ class NeuralMemory(Module):
903
905
  # compute split sizes of sequence
904
906
  # for now manually update weights to last update at the correct boundaries
905
907
 
906
- store_seq_len, chunk_size, batch_size = store_seq.shape[-2], self.chunk_size, self.batch_size
908
+ store_seq_len, chunk_size, batch_size = store_seq.shape[-2], self.chunk_size, default(ttt_batch_size, self.batch_size)
907
909
 
908
910
  need_update_weights = exists(batch_size)
909
911
 
@@ -1013,6 +1015,11 @@ class NeuralMemory(Module):
1013
1015
  updates
1014
1016
  )
1015
1017
 
1018
+ # maybe detach
1019
+
1020
+ if detach_mem_state:
1021
+ next_neural_mem_state = mem_state_detach(next_neural_mem_state)
1022
+
1016
1023
  # returning
1017
1024
 
1018
1025
  if not return_surprises:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.4.4
3
+ Version: 0.4.6
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -2,8 +2,8 @@ titans_pytorch/__init__.py,sha256=sVTOuRUkaIYabFExdLY6s1qXm1UwHHz_J19H8ZV-X74,33
2
2
  titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
3
3
  titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
4
4
  titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
5
- titans_pytorch/neural_memory.py,sha256=eDC8rl8S241zp8pxKrIEi_6nFm6CEoaH9K4hnDfgzu8,33145
6
- titans_pytorch-0.4.4.dist-info/METADATA,sha256=CWciTl1VeOvwyL_lqr0JsdmDDIjXG85N8ykwd3w2TxQ,6810
7
- titans_pytorch-0.4.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.4.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.4.4.dist-info/RECORD,,
5
+ titans_pytorch/neural_memory.py,sha256=CsB-Wd3T_N_HluRLjjwB2gPdlLymJmmhuxiRJhQuXkA,33377
6
+ titans_pytorch-0.4.6.dist-info/METADATA,sha256=aWhQMQrjBLzUPmvtH-LY47r4ayz_ts70gGyDMMyJ6Sc,6810
7
+ titans_pytorch-0.4.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.4.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.4.6.dist-info/RECORD,,