titans-pytorch 0.1.21__py3-none-any.whl → 0.1.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/titans.py +41 -10
- {titans_pytorch-0.1.21.dist-info → titans_pytorch-0.1.22.dist-info}/METADATA +2 -3
- titans_pytorch-0.1.22.dist-info/RECORD +8 -0
- titans_pytorch-0.1.21.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.21.dist-info → titans_pytorch-0.1.22.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.21.dist-info → titans_pytorch-0.1.22.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
@@ -41,6 +41,9 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def xnor(x, y):
+    return not (x ^ y)
+
 def identity(t):
     return t
 
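The new xnor helper returns True exactly when its two arguments agree in truth value; since bool ^ bool is logical XOR in Python, not (x ^ y) is its negation. A minimal sketch of the semantics and of how the assert added later in this diff reads (exists is already defined in the file):

def exists(v):
    return v is not None

def xnor(x, y):
    return not (x ^ y)

# agree -> True, disagree -> False
assert xnor(True, True) and xnor(False, False)
assert not xnor(True, False)

# the assert added in store_memories reads as: a value_residual must be
# passed if and only if the module was built with accept_value_residual
value_residual, learned_value_residual = None, None
assert xnor(exists(value_residual), exists(learned_value_residual))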
@@ -366,6 +369,7 @@ class NeuralMemory(Module):
         pre_rmsnorm = True,
         post_rmsnorm = True,
         qk_rmsnorm = False,
+        accept_value_residual = False,
         learned_mem_model_weights = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
@@ -399,7 +403,7 @@ class NeuralMemory(Module):
 
         self.heads = heads
 
-        self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
+        self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
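The split_heads change keeps the head axis separate (b h n d) rather than folding it into the batch, which is what lets the per-head residual gate added below broadcast cleanly; the chunking rearranges later in this diff change to match. A shape check under that reading (the pre-change pattern is reconstructed from the old chunking pattern 'b (n c) d', and all sizes are illustrative):

import torch
from einops.layers.torch import Rearrange

b, n, h, d = 2, 8, 4, 16
x = torch.randn(b, n, h * d)

old_split = Rearrange('b n (h d) -> (b h) n d', h = h)  # heads folded into batch (pre-diff reading)
new_split = Rearrange('b n (h d) -> b h n d', h = h)    # heads kept as their own axis (this diff)

assert old_split(x).shape == (b * h, n, d)
assert new_split(x).shape == (b, h, n, d)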
@@ -448,6 +452,14 @@ class NeuralMemory(Module):
         self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
         self.store_memory_loss_fn = store_memory_loss_fn
 
+        # value residual learning
+
+        self.learned_value_residual = Sequential(
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> b h n 1'),
+            nn.Sigmoid()
+        ) if accept_value_residual else None
+
         # empty memory embed
 
         self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
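The gate added here is a bias-free linear map producing one logit per head per position, reshaped so it broadcasts against values of shape (batch, heads, seq, dim_head), then squashed by a sigmoid into (0, 1). A standalone sketch, assuming LinearNoBias is nn.Linear with bias = False and using illustrative sizes:

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, heads = 384, 4

learned_value_residual = nn.Sequential(
    nn.Linear(dim, heads, bias = False),  # one gate logit per head and position
    Rearrange('b n h -> b h n 1'),        # align with (batch, heads, seq, dim_head)
    nn.Sigmoid()                          # mixing coefficient in (0, 1)
)

seq = torch.randn(2, 64, dim)
mix = learned_value_residual(seq)
assert mix.shape == (2, heads, 64, 1)     # broadcasts over the per-head feature dim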
@@ -529,8 +541,11 @@ class NeuralMemory(Module):
         seq,
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
         return_aux_kv_loss = False,
-        chunk_size = None
+        chunk_size = None,
+        value_residual = None
     ):
+        assert xnor(exists(value_residual), exists(self.learned_value_residual))
+
         seq_len, chunk_size = seq.shape[-2], default(chunk_size, self.store_chunk_size)
 
         # handle edge case
@@ -585,9 +600,17 @@ class NeuralMemory(Module):
 
         keys = self.k_norm(keys)
 
+        # maybe value residual learning
+
+        orig_values = values
+
+        if exists(self.learned_value_residual):
+            mix = self.learned_value_residual(seq)
+            values = values.lerp(value_residual, mix)
+
         # take care of chunking
 
-        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = chunk_size) for t in (keys, values))
+        keys, values = tuple(rearrange(t, 'b h (n c) d -> (b h n) c d', c = chunk_size) for t in (keys, values))
 
         adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
 
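Tensor.lerp(end, weight) computes self + weight * (end - self), so a gate of 0 keeps the freshly projected values and a gate of 1 replaces them with the residual values passed in; orig_values is stashed before mixing so that downstream consumers always receive the unmixed values. A quick numerical check with illustrative shapes:

import torch

values = torch.randn(2, 4, 64, 96)         # (batch, heads, seq, dim_head)
value_residual = torch.randn_like(values)  # values handed over from an earlier memory
mix = torch.rand(2, 4, 64, 1)              # sigmoid gate, broadcasts over dim_head

mixed = values.lerp(value_residual, mix)
assert torch.allclose(mixed, values * (1 - mix) + value_residual * mix, atol = 1e-6)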
@@ -645,10 +668,12 @@ class NeuralMemory(Module):
 
         last_update = updates.apply(lambda t: t[:, -1])
 
+        output = (updates, orig_values)
+
         if not return_aux_kv_loss:
-            return updates
+            return output
 
-        return updates, aux_kv_recon_loss.mean()
+        return output, aux_kv_recon_loss.mean()
 
     def retrieve_memories(
         self,
@@ -698,7 +723,7 @@ class NeuralMemory(Module):
         # fetch values from memory model
 
         curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
-        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
+        queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)
 
         # forward functional call
 
@@ -735,7 +760,8 @@ class NeuralMemory(Module):
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
         return_aux_kv_loss = False,
         chunk_size = None,
-        store_chunk_size = None
+        store_chunk_size = None,
+        return_values = False
     ):
         batch, seq_len = seq.shape[:2]
 
@@ -756,13 +782,18 @@ class NeuralMemory(Module):
         store_seq = default(store_seq, seq)
         store_chunk_size = default(store_chunk_size, chunk_size)
 
-        updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)
+        (updates, values), aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)
 
         past_weights, _ = past_state
 
         retrieved = self.retrieve_memories(seq, past_weights + updates, chunk_size = chunk_size)
 
+        output = retrieved
+
+        if return_values:
+            output = (retrieved, values)
+
         if not return_aux_kv_loss:
-            return retrieved
+            return output
 
-        return retrieved, aux_kv_recon_loss
+        return output, aux_kv_recon_loss
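Taken together, the new keywords let one NeuralMemory hand its values to another. A hypothetical wiring sketch: only accept_value_residual, return_values, and the value_residual keyword of store_memories come from this diff; the two-module setup and sizes are illustrative (dim and chunk_size follow the README example):

import torch
from titans_pytorch import NeuralMemory

mem_first = NeuralMemory(dim = 384, chunk_size = 64)
mem_later = NeuralMemory(dim = 384, chunk_size = 64, accept_value_residual = True)

seq = torch.randn(2, 1024, 384)

# the first memory returns its original (pre-mix) values alongside the retrieval
retrieved, values = mem_first(seq, return_values = True)

# a later memory would mix those values in through its learned gate when storing;
# per the assert above, passing value_residual requires accept_value_residual = True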
{titans_pytorch-0.1.21.dist-info → titans_pytorch-0.1.22.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.21
+Version: 0.1.22
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -78,8 +78,7 @@ from titans_pytorch import NeuralMemory
 
 mem = NeuralMemory(
     dim = 384,
-    chunk_size = 64
-    pre_rmsnorm = True
+    chunk_size = 64
 ).cuda()
 
 seq = torch.randn(2, 1024, 384).cuda()
titans_pytorch-0.1.22.dist-info/RECORD
ADDED
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=Ejq1r3GQQnlT1Fo4McaOOie19t1HjwVlYbD90GLQCYI,22859
+titans_pytorch/titans.py,sha256=7PGnZxdKq6T6e51RL7-QV43wp-46YmrytTZLt0McMco,23407
+titans_pytorch-0.1.22.dist-info/METADATA,sha256=HCOAqLK605c8R2mvgQ80kwE9jayZ2CwJqHLsJtFx7Vs,6718
+titans_pytorch-0.1.22.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.22.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.22.dist-info/RECORD,,
titans_pytorch-0.1.21.dist-info/RECORD
REMOVED
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=Ejq1r3GQQnlT1Fo4McaOOie19t1HjwVlYbD90GLQCYI,22859
-titans_pytorch/titans.py,sha256=YYt6O5EiBVvyxWM4R1JuLLJH3bGm1V-74aB7VhbsWQ0,22577
-titans_pytorch-0.1.21.dist-info/METADATA,sha256=ixbJisycB0MgSIcOvDRM1PIMs3l1TM_AmQ88aWZYEsM,6742
-titans_pytorch-0.1.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.21.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.21.dist-info/RECORD,,
{titans_pytorch-0.1.21.dist-info → titans_pytorch-0.1.22.dist-info}/WHEEL
File without changes

{titans_pytorch-0.1.21.dist-info → titans_pytorch-0.1.22.dist-info}/licenses/LICENSE
File without changes