titans-pytorch 0.1.22__tar.gz → 0.1.23__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/PKG-INFO +2 -2
- {titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/README.md +1 -1
- {titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/pyproject.toml +1 -1
- {titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/tests/test_titans.py +56 -0
- {titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/titans_pytorch/titans.py +67 -9
- {titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/.gitignore +0 -0
- {titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/LICENSE +0 -0
- {titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/data/README.md +0 -0
- {titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/fig1.png +0 -0
- {titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/fig2.png +0 -0
- {titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/train_mac.py +0 -0
{titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.22
+Version: 0.1.23
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -78,7 +78,7 @@ from titans_pytorch import NeuralMemory
 
 mem = NeuralMemory(
     dim = 384,
-    chunk_size = 64
+    chunk_size = 64 # set to smaller chunk size for better perf on smaller sequence lengths (but more memory usage)
 ).cuda()
 
 seq = torch.randn(2, 1024, 384).cuda()
{titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/README.md

@@ -24,7 +24,7 @@ from titans_pytorch import NeuralMemory
 
 mem = NeuralMemory(
     dim = 384,
-    chunk_size = 64
+    chunk_size = 64 # set to smaller chunk size for better perf on smaller sequence lengths (but more memory usage)
 ).cuda()
 
 seq = torch.randn(2, 1024, 384).cuda()
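As an aside on the comment added above: memories are written to the neural memory once per chunk, so a smaller chunk_size updates the memory weights more often, which tends to help on short sequences at the cost of extra compute and memory. A minimal sketch of that usage (the chunk_size and sequence length below are illustrative assumptions, not values from the release):

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 16    # illustrative smaller chunk, not a value from the diff
)

short_seq = torch.randn(2, 128, 384)   # a sequence much shorter than the README's 1024

retrieved = mem(short_seq)
assert retrieved.shape == short_seq.shape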
{titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/tests/test_titans.py

@@ -8,6 +8,9 @@ from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention, M
 def exists(v):
     return v is not None
 
+def diff(x, y):
+    return (x - y).abs().amax()
+
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
 @pytest.mark.parametrize('silu', (False, True))
 @pytest.mark.parametrize('learned_mem_model_weights', (False, True))
@@ -133,6 +136,39 @@ def test_mac_sampling(sliding):
 
     assert torch.allclose(sampled, sampled_with_cache)
 
+@pytest.mark.parametrize('seq_len', (2, 64))
+def test_neural_mem_inference(
+    seq_len
+):
+    mem = NeuralMemory(
+        dim = 384,
+        chunk_size = 64,
+    )
+
+    seq = torch.randn(2, seq_len, 384)
+    parallel_retrieved = mem(seq)
+
+    assert seq.shape == parallel_retrieved.shape
+
+    mem_model_state = None
+    cache_store_seq = None
+    sequential_retrieved = []
+
+    for ind, token in enumerate(seq.unbind(dim = 1)):
+
+        one_retrieved, cache_store_seq, mem_model_state = mem.forward_inference(
+            token,
+            seq_index = ind,
+            cache_store_seq = cache_store_seq,
+            mem_model_state = mem_model_state
+        )
+
+        sequential_retrieved.append(one_retrieved)
+
+    sequential_retrieved = torch.cat(sequential_retrieved, dim = -2)
+
+    assert torch.allclose(parallel_retrieved, sequential_retrieved, atol = 1e-5)
+
 @pytest.mark.parametrize('seq_len', (1023, 17))
 @pytest.mark.parametrize('sliding', (True, False))
 def test_flex(
@@ -157,3 +193,23 @@ def test_flex(
     out_non_flex, _ = attn(seq, disable_flex_attn = True)
 
     assert torch.allclose(out_flex, out_non_flex, atol = 1e-5)
+
+def test_assoc_scan():
+    from titans_pytorch.titans import AssocScan
+    import torch.nn.functional as F
+
+    scan = AssocScan()
+
+    gates = torch.randn(2, 1024, 512).sigmoid()
+    inputs = torch.randn(2, 1024, 512)
+
+    output = scan(gates, inputs)
+
+    gates1, gates2 = gates[:, :512], gates[:, 512:]
+    inputs1, inputs2 = inputs[:, :512], inputs[:, 512:]
+
+    first_half = scan(gates1, inputs1)
+
+    second_half = scan(gates2, inputs2, prev = first_half[:, -1])
+
+    assert torch.allclose(output[:, -1], second_half[:, -1], atol = 1e-5)
{titans_pytorch-0.1.22 → titans_pytorch-0.1.23}/titans_pytorch/titans.py

@@ -44,6 +44,16 @@ def default(v, d):
 def xnor(x, y):
     return not (x ^ y)
 
+def safe_cat(inputs, dim = -2):
+    inputs = tuple(filter(exists, inputs))
+
+    if len(inputs) == 0:
+        return None
+    elif len(inputs) == 1:
+        return inputs[0]
+
+    return cat(inputs, dim = dim)
+
 def identity(t):
     return t
 
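The helper above is used by the new inference path further down to grow the per-step token cache without special-casing the first step, where the cache is still None. A quick standalone sketch of its behavior (shapes are illustrative):

import torch
from titans_pytorch.titans import safe_cat

cache = None
token = torch.randn(2, 1, 384)

cache = safe_cat((cache, token))   # None entries are filtered out, so this is just `token`
cache = safe_cat((cache, token))   # later tokens are concatenated along the sequence dim (-2)

assert cache.shape == (2, 2, 384)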
@@ -314,7 +324,11 @@ class AssocScan(Module):
         super().__init__()
         self.use_accelerated = use_accelerated
 
-    def forward(self, gates, inputs):
+    def forward(self, gates, inputs, prev = None):
+
+        if exists(prev):
+            inputs, _ = pack([prev, inputs], 'b * d')
+            gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
 
         if not self.use_accelerated:
             _, outputs = associative_scan(binary_operator, (gates, inputs))
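For intuition on the new prev argument: AssocScan computes the recurrence out_t = gate_t * out_{t-1} + input_t, and prepending prev with a gate of 1 makes it act as the initial state, so a scan can be resumed chunk by chunk. A tiny reference sketch with a plain loop (an assumed re-implementation for illustration, not the library's scan):

import torch

def naive_scan(gates, inputs, prev = None):
    # out_t = gate_t * out_{t-1} + input_t, with prev as the optional initial state
    out = prev if prev is not None else torch.zeros_like(inputs[:, 0])
    outs = []
    for t in range(inputs.shape[1]):
        out = gates[:, t] * out + inputs[:, t]
        outs.append(out)
    return torch.stack(outs, dim = 1)

gates = torch.rand(2, 8, 4)
inputs = torch.randn(2, 8, 4)

full = naive_scan(gates, inputs)

first = naive_scan(gates[:, :4], inputs[:, :4])
second = naive_scan(gates[:, 4:], inputs[:, 4:], prev = first[:, -1])   # resume from the first chunk's last state

assert torch.allclose(full[:, -1], second[:, -1], atol = 1e-6)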
@@ -678,7 +692,7 @@ class NeuralMemory(Module):
     def retrieve_memories(
         self,
         seq,
-        past_weights: dict[str, Tensor]
+        past_weights: dict[str, Tensor],
         chunk_size = None
     ):
         chunk_size = default(chunk_size, self.retrieve_chunk_size)
@@ -700,13 +714,7 @@ class NeuralMemory(Module):
         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
 
-        curr_weights = TensorDict(
-
-        if exists(past_weights):
-            past_weights = TensorDict(past_weights)
-            assert past_weights.keys() == curr_weights.keys()
-
-            curr_weights = curr_weights + past_weights
+        curr_weights = TensorDict(past_weights)
 
         # sequence Float['b n d'] to queries
 
@@ -753,6 +761,56 @@ class NeuralMemory(Module):
 
         return values[:, :seq_len]
 
+    def forward_inference(
+        self,
+        token: Tensor,
+        seq_index = None, # the index of the token in the sequence, starts at 0
+        mem_model_state = None,
+        cache_store_seq = None
+    ):
+        seq_index = default(seq_index, 0)
+        curr_seq_len = seq_index + 1
+        batch = token.shape[0]
+
+        if token.ndim == 2:
+            token = rearrange(token, 'b d -> b 1 d')
+
+        # init memory model if needed
+
+        if not exists(mem_model_state):
+            mem_model_state = self.init_weights_and_momentum()
+
+        # increment the sequence cache which is at most the chunk size
+
+        cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)
+
+        # early return empty memory, when no memories are stored for steps < first chunk size
+
+        if curr_seq_len < self.chunk_size:
+            empty_mem = self.init_empty_memory_embed(batch, 1)
+
+            return empty_mem, cache_store_seq, mem_model_state
+
+        # store if storage sequence cache hits the chunk size
+
+        store_seq_cache_len = cache_store_seq.shape[-2]
+
+        if store_seq_cache_len == self.chunk_size:
+            updates, _ = self.store_memories(cache_store_seq, mem_model_state)
+
+            past_weights, past_momentum = mem_model_state
+            mem_model_state = (past_weights + updates, past_momentum)
+
+            cache_store_seq = None
+
+        # retrieve
+
+        past_weights, _ = mem_model_state
+
+        retrieved = self.retrieve_memories(token, past_weights, chunk_size = 1)
+
+        return retrieved, cache_store_seq, mem_model_state
+
     def forward(
         self,
         seq,
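The new forward_inference method enables token-by-token decoding against the neural memory: tokens accumulate in cache_store_seq until a full chunk is buffered, that chunk is then written into the memory weights via store_memories, and retrieval always runs with chunk_size = 1. A minimal usage sketch mirroring test_neural_mem_inference above (CPU, illustrative sizes):

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(dim = 384, chunk_size = 64)

seq = torch.randn(2, 256, 384)

mem_model_state = None    # (weights, momentum) of the memory model, initialized lazily on the first call
cache_store_seq = None    # buffered tokens not yet written to memory (at most chunk_size of them)
retrieved_per_step = []

for ind, token in enumerate(seq.unbind(dim = 1)):
    retrieved, cache_store_seq, mem_model_state = mem.forward_inference(
        token,
        seq_index = ind,
        cache_store_seq = cache_store_seq,
        mem_model_state = mem_model_state
    )
    retrieved_per_step.append(retrieved)

retrieved_seq = torch.cat(retrieved_per_step, dim = -2)
assert retrieved_seq.shape == seq.shape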