titans-pytorch 0.1.26__py3-none-any.whl → 0.1.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +1 -1
- titans_pytorch/titans.py +4 -6
- {titans_pytorch-0.1.26.dist-info → titans_pytorch-0.1.27.dist-info}/METADATA +2 -2
- titans_pytorch-0.1.27.dist-info/RECORD +8 -0
- titans_pytorch-0.1.26.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.26.dist-info → titans_pytorch-0.1.27.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.26.dist-info → titans_pytorch-0.1.27.dist-info}/licenses/LICENSE +0 -0
@@ -741,10 +741,10 @@ class MemoryAsContextTransformer(Module):
|
|
741
741
|
|
742
742
|
kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
|
743
743
|
|
744
|
+
next_neural_mem_cache = (seq_len, None, None, None)
|
744
745
|
else:
|
745
746
|
retrieved, next_neural_mem_cache = mem.forward_inference(
|
746
747
|
mem_input,
|
747
|
-
seq_index = seq_len - 1,
|
748
748
|
state = next(neural_mem_caches, None)
|
749
749
|
)
|
750
750
|
|
titans_pytorch/titans.py
CHANGED
@@ -783,18 +783,16 @@ class NeuralMemory(Module):
|
|
783
783
|
def forward_inference(
|
784
784
|
self,
|
785
785
|
token: Tensor,
|
786
|
-
seq_index = None, # the index of the token in the sequence, starts at 0
|
787
786
|
state = None,
|
788
787
|
):
|
789
788
|
|
790
789
|
# unpack previous state
|
791
790
|
|
792
791
|
if not exists(state):
|
793
|
-
state = (None, None, None)
|
792
|
+
state = (0, None, None, None)
|
794
793
|
|
795
|
-
cache_store_seq, past_states, updates = state
|
794
|
+
seq_index, cache_store_seq, past_states, updates = state
|
796
795
|
|
797
|
-
seq_index = default(seq_index, 0)
|
798
796
|
curr_seq_len = seq_index + 1
|
799
797
|
batch = token.shape[0]
|
800
798
|
|
@@ -814,7 +812,7 @@ class NeuralMemory(Module):
|
|
814
812
|
if curr_seq_len < self.chunk_size:
|
815
813
|
empty_mem = self.init_empty_memory_embed(batch, 1)
|
816
814
|
|
817
|
-
return empty_mem, (cache_store_seq, past_states, updates)
|
815
|
+
return empty_mem, (curr_seq_len, cache_store_seq, past_states, updates)
|
818
816
|
|
819
817
|
# store if storage sequence cache hits the chunk size
|
820
818
|
|
@@ -842,7 +840,7 @@ class NeuralMemory(Module):
|
|
842
840
|
|
843
841
|
# next state tuple
|
844
842
|
|
845
|
-
next_state = (cache_store_seq, next_states, updates)
|
843
|
+
next_state = (curr_seq_len, cache_store_seq, next_states, updates)
|
846
844
|
|
847
845
|
return retrieved, next_state
|
848
846
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: titans-pytorch
|
3
|
-
Version: 0.1.26
|
3
|
+
Version: 0.1.27
|
4
4
|
Summary: Titans
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
|
@@ -35,7 +35,7 @@ Classifier: Programming Language :: Python :: 3.9
|
|
35
35
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
36
36
|
Requires-Python: >=3.9
|
37
37
|
Requires-Dist: accelerated-scan>=0.2.0
|
38
|
-
Requires-Dist: axial-positional-embedding>=0.3.
|
38
|
+
Requires-Dist: axial-positional-embedding>=0.3.10
|
39
39
|
Requires-Dist: einops>=0.8.0
|
40
40
|
Requires-Dist: einx>=0.3.0
|
41
41
|
Requires-Dist: hyper-connections>=0.1.9
|
@@ -0,0 +1,8 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=Staf9hRQ44QAL23bSGh4VSB8NeGtMri-JdiZdgirJiU,23587
|
4
|
+
titans_pytorch/titans.py,sha256=gjoDcTsvw5X2d1I2xq4cM45YJIBqtLFuws8_jVylW_4,25746
|
5
|
+
titans_pytorch-0.1.27.dist-info/METADATA,sha256=AZ5-_d9o_khm6jaky1zoKyXB1hDQNifbS061v_b4McQ,6815
|
6
|
+
titans_pytorch-0.1.27.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
+
titans_pytorch-0.1.27.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
+
titans_pytorch-0.1.27.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=RkEGmVlQyK1opqylqt1VEFEc_Gd_pbAArcwfhphotXI,23564
|
4
|
-
titans_pytorch/titans.py,sha256=a-BXTG6DdNXWhby6E4W2fdhwipuMQ12tSqSL10iLvfY,25826
|
5
|
-
titans_pytorch-0.1.26.dist-info/METADATA,sha256=zogTDD7iLlxkPDzIeCap9GCgz2VNFUWjVF_K6K8H9kg,6814
|
6
|
-
titans_pytorch-0.1.26.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
-
titans_pytorch-0.1.26.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
-
titans_pytorch-0.1.26.dist-info/RECORD,,
|
File without changes
|
File without changes
|