titans-pytorch 0.1.22__py3-none-any.whl → 0.1.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/titans.py +67 -9
- {titans_pytorch-0.1.22.dist-info → titans_pytorch-0.1.23.dist-info}/METADATA +2 -2
- titans_pytorch-0.1.23.dist-info/RECORD +8 -0
- titans_pytorch-0.1.22.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.22.dist-info → titans_pytorch-0.1.23.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.22.dist-info → titans_pytorch-0.1.23.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
@@ -44,6 +44,16 @@ def default(v, d):
 def xnor(x, y):
     return not (x ^ y)

+def safe_cat(inputs, dim = -2):
+    inputs = tuple(filter(exists, inputs))
+
+    if len(inputs) == 0:
+        return None
+    elif len(inputs) == 1:
+        return inputs[0]
+
+    return cat(inputs, dim = dim)
+
 def identity(t):
     return t

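The new `safe_cat` helper lets the inference path below treat an absent cache and a present cache uniformly. A standalone sketch of its contract (re-declaring `exists` and `cat` locally, since the package imports them from its own namespace):

import torch
from torch import cat

def exists(v):
    return v is not None

def safe_cat(inputs, dim = -2):
    inputs = tuple(filter(exists, inputs))  # drop any Nones

    if len(inputs) == 0:
        return None
    elif len(inputs) == 1:
        return inputs[0]

    return cat(inputs, dim = dim)

cache = None
token = torch.randn(2, 1, 384)

cache = safe_cat((cache, token))  # cache was None -> just the token, (2, 1, 384)
cache = safe_cat((cache, token))  # both present  -> concatenated, (2, 2, 384)
assert cache.shape == (2, 2, 384)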
@@ -314,7 +324,11 @@ class AssocScan(Module):
         super().__init__()
         self.use_accelerated = use_accelerated

-    def forward(self, gates, inputs):
+    def forward(self, gates, inputs, prev = None):
+
+        if exists(prev):
+            inputs, _ = pack([prev, inputs], 'b * d')
+            gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)

         if not self.use_accelerated:
             _, outputs = associative_scan(binary_operator, (gates, inputs))
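Threading `prev` through works because a gated scan can be seeded by prepending the previous state as an extra element whose gate is 1: the padded first step reproduces `prev` exactly, and every later step continues the recurrence. A reference check of that identity in plain torch, assuming the scan computes h_t = g_t * h_{t-1} + x_t:

import torch

# reference recurrence: h_t = gates_t * h_{t-1} + inputs_t
def scan_ref(gates, inputs, prev = None):
    h = prev if prev is not None else torch.zeros_like(inputs[:, 0])
    outs = []
    for t in range(inputs.shape[1]):
        h = gates[:, t] * h + inputs[:, t]
        outs.append(h)
    return torch.stack(outs, dim = 1)

b, n, d = 2, 4, 8
gates, inputs, prev = torch.rand(b, n, d), torch.randn(b, n, d), torch.randn(b, d)

# prepend prev as an input whose gate is 1 (the plain-torch analogue of the
# pack + pad_at_dim above), then drop the extra first output
padded_gates = torch.cat((torch.ones(b, 1, d), gates), dim = 1)
padded_inputs = torch.cat((prev[:, None, :], inputs), dim = 1)
seeded = scan_ref(padded_gates, padded_inputs)[:, 1:]

assert torch.allclose(seeded, scan_ref(gates, inputs, prev = prev), atol = 1e-6)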
@@ -678,7 +692,7 @@ class NeuralMemory(Module):
     def retrieve_memories(
         self,
         seq,
-        past_weights: dict[str, Tensor]
+        past_weights: dict[str, Tensor],
         chunk_size = None
     ):
         chunk_size = default(chunk_size, self.retrieve_chunk_size)
@@ -700,13 +714,7 @@ class NeuralMemory(Module):
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper

-        curr_weights = TensorDict(
-
-        if exists(past_weights):
-            past_weights = TensorDict(past_weights)
-            assert past_weights.keys() == curr_weights.keys()
-
-            curr_weights = curr_weights + past_weights
+        curr_weights = TensorDict(past_weights)

         # sequence Float['b n d'] to queries

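The comment in this hunk is the single-matrix case in miniature: storing writes outer products of keys and values into one matrix, and retrieval is a query-matrix product. A few lines of plain torch to make the `q @ (kv)` remark concrete (illustrative only, not the package's actual memory model):

import torch

d = 8
k = torch.randn(16, d)   # keys of 16 stored items
v = torch.randn(16, d)   # values of 16 stored items
q = torch.randn(1, d)    # a retrieval query

kv = k.t() @ v           # fast weight memory: sum of outer products k_t^T v_t
out = q @ kv             # fetching memories is q @ (kv), per the comment above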
@@ -753,6 +761,56 @@ class NeuralMemory(Module):

         return values[:, :seq_len]

+    def forward_inference(
+        self,
+        token: Tensor,
+        seq_index = None, # the index of the token in the sequence, starts at 0
+        mem_model_state = None,
+        cache_store_seq = None
+    ):
+        seq_index = default(seq_index, 0)
+        curr_seq_len = seq_index + 1
+        batch = token.shape[0]
+
+        if token.ndim == 2:
+            token = rearrange(token, 'b d -> b 1 d')
+
+        # init memory model if needed
+
+        if not exists(mem_model_state):
+            mem_model_state = self.init_weights_and_momentum()
+
+        # increment the sequence cache which is at most the chunk size
+
+        cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)
+
+        # early return empty memory, when no memories are stored for steps < first chunk size
+
+        if curr_seq_len < self.chunk_size:
+            empty_mem = self.init_empty_memory_embed(batch, 1)
+
+            return empty_mem, cache_store_seq, mem_model_state
+
+        # store if storage sequence cache hits the chunk size
+
+        store_seq_cache_len = cache_store_seq.shape[-2]
+
+        if store_seq_cache_len == self.chunk_size:
+            updates, _ = self.store_memories(cache_store_seq, mem_model_state)
+
+            past_weights, past_momentum = mem_model_state
+            mem_model_state = (past_weights + updates, past_momentum)
+
+            cache_store_seq = None
+
+        # retrieve
+
+        past_weights, _ = mem_model_state
+
+        retrieved = self.retrieve_memories(token, past_weights, chunk_size = 1)
+
+        return retrieved, cache_store_seq, mem_model_state
+
     def forward(
         self,
         seq,
{titans_pytorch-0.1.22.dist-info → titans_pytorch-0.1.23.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.22
+Version: 0.1.23
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -78,7 +78,7 @@ from titans_pytorch import NeuralMemory

 mem = NeuralMemory(
     dim = 384,
-    chunk_size = 64
+    chunk_size = 64 # set to smaller chunk size for better perf on smaller sequence lengths (but more memory usage)
 ).cuda()

 seq = torch.randn(2, 1024, 384).cuda()
titans_pytorch-0.1.23.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=Ejq1r3GQQnlT1Fo4McaOOie19t1HjwVlYbD90GLQCYI,22859
+titans_pytorch/titans.py,sha256=WbagKMYDs-3NoW2j_pAyHEnvR9QzH3A9WntHuV_FKOo,25109
+titans_pytorch-0.1.23.dist-info/METADATA,sha256=H7QbLscawNObHGeoTbnKbf-NOqkMqWCu4yWeZJ0yKMA,6814
+titans_pytorch-0.1.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.23.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.23.dist-info/RECORD,,
titans_pytorch-0.1.22.dist-info/RECORD
DELETED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=Ejq1r3GQQnlT1Fo4McaOOie19t1HjwVlYbD90GLQCYI,22859
-titans_pytorch/titans.py,sha256=7PGnZxdKq6T6e51RL7-QV43wp-46YmrytTZLt0McMco,23407
-titans_pytorch-0.1.22.dist-info/METADATA,sha256=HCOAqLK605c8R2mvgQ80kwE9jayZ2CwJqHLsJtFx7Vs,6718
-titans_pytorch-0.1.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.22.dist-info/RECORD,,
{titans_pytorch-0.1.22.dist-info → titans_pytorch-0.1.23.dist-info}/WHEEL
File without changes
{titans_pytorch-0.1.22.dist-info → titans_pytorch-0.1.23.dist-info}/licenses/LICENSE
File without changes