titans-pytorch 0.1.26__py3-none-any.whl → 0.1.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -741,10 +741,10 @@ class MemoryAsContextTransformer(Module):
741
741
 
742
742
  kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
743
743
 
744
+ next_neural_mem_cache = (seq_len, None, None, None)
744
745
  else:
745
746
  retrieved, next_neural_mem_cache = mem.forward_inference(
746
747
  mem_input,
747
- seq_index = seq_len - 1,
748
748
  state = next(neural_mem_caches, None)
749
749
  )
750
750
 
titans_pytorch/titans.py CHANGED
@@ -783,18 +783,16 @@ class NeuralMemory(Module):
783
783
  def forward_inference(
784
784
  self,
785
785
  token: Tensor,
786
- seq_index = None, # the index of the token in the sequence, starts at 0
787
786
  state = None,
788
787
  ):
789
788
 
790
789
  # unpack previous state
791
790
 
792
791
  if not exists(state):
793
- state = (None, None, None)
792
+ state = (0, None, None, None)
794
793
 
795
- cache_store_seq, past_states, updates = state
794
+ seq_index, cache_store_seq, past_states, updates = state
796
795
 
797
- seq_index = default(seq_index, 0)
798
796
  curr_seq_len = seq_index + 1
799
797
  batch = token.shape[0]
800
798
 
@@ -814,7 +812,7 @@ class NeuralMemory(Module):
814
812
  if curr_seq_len < self.chunk_size:
815
813
  empty_mem = self.init_empty_memory_embed(batch, 1)
816
814
 
817
- return empty_mem, (cache_store_seq, past_states, updates)
815
+ return empty_mem, (curr_seq_len, cache_store_seq, past_states, updates)
818
816
 
819
817
  # store if storage sequence cache hits the chunk size
820
818
 
@@ -842,7 +840,7 @@ class NeuralMemory(Module):
842
840
 
843
841
  # next state tuple
844
842
 
845
- next_state = (cache_store_seq, next_states, updates)
843
+ next_state = (curr_seq_len, cache_store_seq, next_states, updates)
846
844
 
847
845
  return retrieved, next_state
848
846
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.1.26
3
+ Version: 0.1.27
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,7 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
35
35
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
36
36
  Requires-Python: >=3.9
37
37
  Requires-Dist: accelerated-scan>=0.2.0
38
- Requires-Dist: axial-positional-embedding>=0.3.9
38
+ Requires-Dist: axial-positional-embedding>=0.3.10
39
39
  Requires-Dist: einops>=0.8.0
40
40
  Requires-Dist: einx>=0.3.0
41
41
  Requires-Dist: hyper-connections>=0.1.9
@@ -0,0 +1,8 @@
1
+ titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/mac_transformer.py,sha256=Staf9hRQ44QAL23bSGh4VSB8NeGtMri-JdiZdgirJiU,23587
4
+ titans_pytorch/titans.py,sha256=gjoDcTsvw5X2d1I2xq4cM45YJIBqtLFuws8_jVylW_4,25746
5
+ titans_pytorch-0.1.27.dist-info/METADATA,sha256=AZ5-_d9o_khm6jaky1zoKyXB1hDQNifbS061v_b4McQ,6815
6
+ titans_pytorch-0.1.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
+ titans_pytorch-0.1.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
+ titans_pytorch-0.1.27.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/mac_transformer.py,sha256=RkEGmVlQyK1opqylqt1VEFEc_Gd_pbAArcwfhphotXI,23564
4
- titans_pytorch/titans.py,sha256=a-BXTG6DdNXWhby6E4W2fdhwipuMQ12tSqSL10iLvfY,25826
5
- titans_pytorch-0.1.26.dist-info/METADATA,sha256=zogTDD7iLlxkPDzIeCap9GCgz2VNFUWjVF_K6K8H9kg,6814
6
- titans_pytorch-0.1.26.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
7
- titans_pytorch-0.1.26.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
8
- titans_pytorch-0.1.26.dist-info/RECORD,,