titans-pytorch 0.1.21__tar.gz → 0.1.23__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.1.21 → titans_pytorch-0.1.23}/PKG-INFO +2 -3
- {titans_pytorch-0.1.21 → titans_pytorch-0.1.23}/README.md +1 -2
- {titans_pytorch-0.1.21 → titans_pytorch-0.1.23}/pyproject.toml +1 -1
- {titans_pytorch-0.1.21 → titans_pytorch-0.1.23}/tests/test_titans.py +56 -0
- {titans_pytorch-0.1.21 → titans_pytorch-0.1.23}/titans_pytorch/titans.py +108 -19
- {titans_pytorch-0.1.21 → titans_pytorch-0.1.23}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.21 → titans_pytorch-0.1.23}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.21 → titans_pytorch-0.1.23}/.gitignore +0 -0
- {titans_pytorch-0.1.21 → titans_pytorch-0.1.23}/LICENSE +0 -0
- {titans_pytorch-0.1.21 → titans_pytorch-0.1.23}/data/README.md +0 -0
- {titans_pytorch-0.1.21 → titans_pytorch-0.1.23}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.21 → titans_pytorch-0.1.23}/fig1.png +0 -0
- {titans_pytorch-0.1.21 → titans_pytorch-0.1.23}/fig2.png +0 -0
- {titans_pytorch-0.1.21 → titans_pytorch-0.1.23}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.1.21 → titans_pytorch-0.1.23}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.1.21 → titans_pytorch-0.1.23}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.1.21 → titans_pytorch-0.1.23}/train_mac.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: titans-pytorch
|
3
|
-
Version: 0.1.21
|
3
|
+
Version: 0.1.23
|
4
4
|
Summary: Titans
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
|
@@ -78,8 +78,7 @@ from titans_pytorch import NeuralMemory
|
|
78
78
|
|
79
79
|
mem = NeuralMemory(
|
80
80
|
dim = 384,
|
81
|
-
chunk_size = 64
|
82
|
-
pre_rmsnorm = True
|
81
|
+
chunk_size = 64 # set to smaller chunk size for better perf on smaller sequence lengths (but more memory usage)
|
83
82
|
).cuda()
|
84
83
|
|
85
84
|
seq = torch.randn(2, 1024, 384).cuda()
|
@@ -24,8 +24,7 @@ from titans_pytorch import NeuralMemory
|
|
24
24
|
|
25
25
|
mem = NeuralMemory(
|
26
26
|
dim = 384,
|
27
|
-
chunk_size = 64
|
28
|
-
pre_rmsnorm = True
|
27
|
+
chunk_size = 64 # set to smaller chunk size for better perf on smaller sequence lengths (but more memory usage)
|
29
28
|
).cuda()
|
30
29
|
|
31
30
|
seq = torch.randn(2, 1024, 384).cuda()
|
@@ -8,6 +8,9 @@ from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention, M
|
|
8
8
|
def exists(v):
    """Return True when `v` carries a value, i.e. it is anything other than None."""
    return not (v is None)
|
10
10
|
|
11
|
+
def diff(x, y):
    """Return the largest absolute elementwise difference between tensors `x` and `y`."""
    delta = x - y
    return delta.abs().amax()
|
13
|
+
|
11
14
|
@pytest.mark.parametrize('seq_len', (32, 1024, 77))
|
12
15
|
@pytest.mark.parametrize('silu', (False, True))
|
13
16
|
@pytest.mark.parametrize('learned_mem_model_weights', (False, True))
|
@@ -133,6 +136,39 @@ def test_mac_sampling(sliding):
|
|
133
136
|
|
134
137
|
assert torch.allclose(sampled, sampled_with_cache)
|
135
138
|
|
139
|
+
@pytest.mark.parametrize('seq_len', (2, 64))
def test_neural_mem_inference(
    seq_len
):
    # feeding tokens one at a time through `forward_inference` must reproduce
    # the retrievals of a single parallel `forward` call over the whole sequence

    mem = NeuralMemory(
        dim = 384,
        chunk_size = 64,
    )

    seq = torch.randn(2, seq_len, 384)
    parallel_retrieved = mem(seq)

    assert seq.shape == parallel_retrieved.shape

    # state threaded from one inference step to the next
    state = None
    pending_cache = None
    per_token_outputs = []

    for position, tok in enumerate(seq.unbind(dim = 1)):
        step_out, pending_cache, state = mem.forward_inference(
            tok,
            seq_index = position,
            cache_store_seq = pending_cache,
            mem_model_state = state
        )
        per_token_outputs.append(step_out)

    sequential_retrieved = torch.cat(per_token_outputs, dim = -2)

    assert torch.allclose(parallel_retrieved, sequential_retrieved, atol = 1e-5)
|
171
|
+
|
136
172
|
@pytest.mark.parametrize('seq_len', (1023, 17))
|
137
173
|
@pytest.mark.parametrize('sliding', (True, False))
|
138
174
|
def test_flex(
|
@@ -157,3 +193,23 @@ def test_flex(
|
|
157
193
|
out_non_flex, _ = attn(seq, disable_flex_attn = True)
|
158
194
|
|
159
195
|
assert torch.allclose(out_flex, out_non_flex, atol = 1e-5)
|
196
|
+
|
197
|
+
def test_assoc_scan():
    """
    Check that AssocScan can resume a scan from a carried-over state.

    Scanning the full sequence in one go must give the same final output as
    scanning the first half, then scanning the second half seeded (via `prev`)
    with the last output of the first half.
    """
    from titans_pytorch.titans import AssocScan

    scan = AssocScan()

    gates = torch.randn(2, 1024, 512).sigmoid()
    inputs = torch.randn(2, 1024, 512)

    output = scan(gates, inputs)

    gates1, gates2 = gates[:, :512], gates[:, 512:]
    inputs1, inputs2 = inputs[:, :512], inputs[:, 512:]

    first_half = scan(gates1, inputs1)

    # the state to resume from is the running scan output at the end of the
    # first half, not a raw input of the second half
    second_half = scan(gates2, inputs2, prev = first_half[:, -1])

    assert torch.allclose(output[:, -1], second_half[:, -1], atol = 1e-5)
|
@@ -41,6 +41,19 @@ def exists(v):
|
|
41
41
|
def default(v, d):
    """Return `v` when it is set (not None); otherwise fall back to `d`."""
    if exists(v):
        return v
    return d
|
43
43
|
|
44
|
+
def xnor(x, y):
    """
    Logical XNOR: `not (x ^ y)`.

    For booleans this is True exactly when both arguments agree.
    """
    differs = x ^ y
    return not differs
|
46
|
+
|
47
|
+
def safe_cat(inputs, dim = -2):
    """
    Concatenate the non-None entries of `inputs` along `dim`.

    Returns None when every entry is None, the single remaining entry itself
    (no copy) when only one is present, and `cat(...)` of the present entries
    otherwise.
    """
    present = tuple(t for t in inputs if exists(t))

    if not present:
        return None

    if len(present) == 1:
        return present[0]

    return cat(present, dim = dim)
|
56
|
+
|
44
57
|
def identity(t):
    """Pass-through: hand back the argument untouched."""
    result = t
    return result
|
46
59
|
|
@@ -311,7 +324,11 @@ class AssocScan(Module):
|
|
311
324
|
super().__init__()
|
312
325
|
self.use_accelerated = use_accelerated
|
313
326
|
|
314
|
-
def forward(self, gates, inputs):
|
327
|
+
def forward(self, gates, inputs, prev = None):
|
328
|
+
|
329
|
+
if exists(prev):
|
330
|
+
inputs, _ = pack([prev, inputs], 'b * d')
|
331
|
+
gates = pad_at_dim(gates, (1, 0), value = 1., dim = -2)
|
315
332
|
|
316
333
|
if not self.use_accelerated:
|
317
334
|
_, outputs = associative_scan(binary_operator, (gates, inputs))
|
@@ -366,6 +383,7 @@ class NeuralMemory(Module):
|
|
366
383
|
pre_rmsnorm = True,
|
367
384
|
post_rmsnorm = True,
|
368
385
|
qk_rmsnorm = False,
|
386
|
+
accept_value_residual = False,
|
369
387
|
learned_mem_model_weights = True,
|
370
388
|
max_grad_norm: float | None = None,
|
371
389
|
use_accelerated_scan = False,
|
@@ -399,7 +417,7 @@ class NeuralMemory(Module):
|
|
399
417
|
|
400
418
|
self.heads = heads
|
401
419
|
|
402
|
-
self.split_heads = Rearrange('b n (h d) ->
|
420
|
+
self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
|
403
421
|
self.merge_heads = Rearrange('b h n d -> b n (h d)')
|
404
422
|
self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
|
405
423
|
|
@@ -448,6 +466,14 @@ class NeuralMemory(Module):
|
|
448
466
|
self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
|
449
467
|
self.store_memory_loss_fn = store_memory_loss_fn
|
450
468
|
|
469
|
+
# value residual learning
|
470
|
+
|
471
|
+
self.learned_value_residual = Sequential(
|
472
|
+
LinearNoBias(dim, heads),
|
473
|
+
Rearrange('b n h -> b h n 1'),
|
474
|
+
nn.Sigmoid()
|
475
|
+
) if accept_value_residual else None
|
476
|
+
|
451
477
|
# empty memory embed
|
452
478
|
|
453
479
|
self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
|
@@ -529,8 +555,11 @@ class NeuralMemory(Module):
|
|
529
555
|
seq,
|
530
556
|
past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
|
531
557
|
return_aux_kv_loss = False,
|
532
|
-
chunk_size = None
|
558
|
+
chunk_size = None,
|
559
|
+
value_residual = None
|
533
560
|
):
|
561
|
+
assert xnor(exists(value_residual), exists(self.learned_value_residual))
|
562
|
+
|
534
563
|
seq_len, chunk_size = seq.shape[-2], default(chunk_size, self.store_chunk_size)
|
535
564
|
|
536
565
|
# handle edge case
|
@@ -585,9 +614,17 @@ class NeuralMemory(Module):
|
|
585
614
|
|
586
615
|
keys = self.k_norm(keys)
|
587
616
|
|
617
|
+
# maybe value residual learning
|
618
|
+
|
619
|
+
orig_values = values
|
620
|
+
|
621
|
+
if exists(self.learned_value_residual):
|
622
|
+
mix = self.learned_value_residual(seq)
|
623
|
+
values = values.lerp(value_residual, mix)
|
624
|
+
|
588
625
|
# take care of chunking
|
589
626
|
|
590
|
-
keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = chunk_size) for t in (keys, values))
|
627
|
+
keys, values = tuple(rearrange(t, 'b h (n c) d -> (b h n) c d', c = chunk_size) for t in (keys, values))
|
591
628
|
|
592
629
|
adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
|
593
630
|
|
@@ -645,15 +682,17 @@ class NeuralMemory(Module):
|
|
645
682
|
|
646
683
|
last_update = updates.apply(lambda t: t[:, -1])
|
647
684
|
|
685
|
+
output = (updates, orig_values)
|
686
|
+
|
648
687
|
if not return_aux_kv_loss:
|
649
|
-
return
|
688
|
+
return output
|
650
689
|
|
651
|
-
return
|
690
|
+
return output, aux_kv_recon_loss.mean()
|
652
691
|
|
653
692
|
def retrieve_memories(
|
654
693
|
self,
|
655
694
|
seq,
|
656
|
-
past_weights: dict[str, Tensor]
|
695
|
+
past_weights: dict[str, Tensor],
|
657
696
|
chunk_size = None
|
658
697
|
):
|
659
698
|
chunk_size = default(chunk_size, self.retrieve_chunk_size)
|
@@ -675,13 +714,7 @@ class NeuralMemory(Module):
|
|
675
714
|
# the parameters of the memory model stores the memories of the key / values
|
676
715
|
# when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
|
677
716
|
|
678
|
-
curr_weights = TensorDict(
|
679
|
-
|
680
|
-
if exists(past_weights):
|
681
|
-
past_weights = TensorDict(past_weights)
|
682
|
-
assert past_weights.keys() == curr_weights.keys()
|
683
|
-
|
684
|
-
curr_weights = curr_weights + past_weights
|
717
|
+
curr_weights = TensorDict(past_weights)
|
685
718
|
|
686
719
|
# sequence Float['b n d'] to queries
|
687
720
|
|
@@ -698,7 +731,7 @@ class NeuralMemory(Module):
|
|
698
731
|
# fetch values from memory model
|
699
732
|
|
700
733
|
curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
|
701
|
-
queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
|
734
|
+
queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)
|
702
735
|
|
703
736
|
# forward functional call
|
704
737
|
|
@@ -728,6 +761,56 @@ class NeuralMemory(Module):
|
|
728
761
|
|
729
762
|
return values[:, :seq_len]
|
730
763
|
|
764
|
+
def forward_inference(
    self,
    token: Tensor,
    seq_index = None, # the index of the token in the sequence, starts at 0
    mem_model_state = None,
    cache_store_seq = None
):
    """
    Process a single token for step-by-step (sequential) inference.

    The token is appended to `cache_store_seq`. When that cache holds
    `self.chunk_size` tokens, it is written into memory via `store_memories`
    and the cache is reset to None. The token is then used to query the
    current memory weights via `retrieve_memories` with `chunk_size = 1`.
    While `seq_index + 1 < self.chunk_size`, the result of
    `init_empty_memory_embed(batch, 1)` is returned instead of a retrieval.

    Args:
        token: the current token. A 2-d input is reshaped to (batch, 1, dim)
            via 'b d -> b 1 d'. A 3-d input is presumably already
            (batch, 1, dim) — TODO confirm.
        seq_index: 0-based position of `token` in the sequence; defaults to 0.
            The caller must keep it in sync with `cache_store_seq`: the early
            return is driven by `seq_index`, but storing is driven by the
            cache length.
        mem_model_state: the (weights, momentum) pair returned by a previous
            call, or None to create one via `self.init_weights_and_momentum()`.
        cache_store_seq: tokens accumulated since the last store, or None.

    Returns:
        (retrieved, cache_store_seq, mem_model_state). Pass the last two back
        in on the next call, together with `seq_index + 1`.

    NOTE(review): after a store only the weights are advanced; the momentum
    half of `mem_model_state` is passed through unchanged. Confirm this
    matches the parallel `forward` path.
    NOTE(review): `store_memories` is called without `value_residual`, so its
    `xnor(...)` assertion presumably fails when the module was built with
    `accept_value_residual = True`.
    NOTE(review): the store is triggered at `self.chunk_size`, but
    `store_memories` defaults its chunking to `self.store_chunk_size`. This
    assumes the two agree — verify against `__init__`.
    """
    seq_index = default(seq_index, 0)
    curr_seq_len = seq_index + 1
    batch = token.shape[0]

    if token.ndim == 2:
        token = rearrange(token, 'b d -> b 1 d')

    # init memory model if needed

    if not exists(mem_model_state):
        mem_model_state = self.init_weights_and_momentum()

    # increment the sequence cache which is at most the chunk size

    cache_store_seq = safe_cat((cache_store_seq, token), dim = -2)

    # early return empty memory, when no memories are stored for steps < first chunk size

    if curr_seq_len < self.chunk_size:
        empty_mem = self.init_empty_memory_embed(batch, 1)

        return empty_mem, cache_store_seq, mem_model_state

    # store if storage sequence cache hits the chunk size

    store_seq_cache_len = cache_store_seq.shape[-2]

    if store_seq_cache_len == self.chunk_size:
        # without return_aux_kv_loss, store_memories returns (updates, orig_values); the values are unused here
        updates, _ = self.store_memories(cache_store_seq, mem_model_state)

        past_weights, past_momentum = mem_model_state
        mem_model_state = (past_weights + updates, past_momentum)

        # the cached chunk has been written into memory; start a fresh cache
        cache_store_seq = None

    # retrieve

    past_weights, _ = mem_model_state

    # a single token is queried, hence chunk_size = 1
    retrieved = self.retrieve_memories(token, past_weights, chunk_size = 1)

    return retrieved, cache_store_seq, mem_model_state
|
813
|
+
|
731
814
|
def forward(
|
732
815
|
self,
|
733
816
|
seq,
|
@@ -735,7 +818,8 @@ class NeuralMemory(Module):
|
|
735
818
|
past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
|
736
819
|
return_aux_kv_loss = False,
|
737
820
|
chunk_size = None,
|
738
|
-
store_chunk_size = None
|
821
|
+
store_chunk_size = None,
|
822
|
+
return_values = False
|
739
823
|
):
|
740
824
|
batch, seq_len = seq.shape[:2]
|
741
825
|
|
@@ -756,13 +840,18 @@ class NeuralMemory(Module):
|
|
756
840
|
store_seq = default(store_seq, seq)
|
757
841
|
store_chunk_size = default(store_chunk_size, chunk_size)
|
758
842
|
|
759
|
-
updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)
|
843
|
+
(updates, values), aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)
|
760
844
|
|
761
845
|
past_weights, _ = past_state
|
762
846
|
|
763
847
|
retrieved = self.retrieve_memories(seq, past_weights + updates, chunk_size = chunk_size)
|
764
848
|
|
849
|
+
output = retrieved
|
850
|
+
|
851
|
+
if return_values:
|
852
|
+
output = (retrieved, values)
|
853
|
+
|
765
854
|
if not return_aux_kv_loss:
|
766
|
-
return
|
855
|
+
return output
|
767
856
|
|
768
|
-
return
|
857
|
+
return output, aux_kv_recon_loss
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|