titans-pytorch 0.1.30__py3-none-any.whl → 0.1.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/__init__.py +1 -1
- titans_pytorch/mac_transformer.py +33 -19
- titans_pytorch/{titans.py → neural_memory.py} +53 -8
- {titans_pytorch-0.1.30.dist-info → titans_pytorch-0.1.32.dist-info}/METADATA +2 -2
- titans_pytorch-0.1.32.dist-info/RECORD +8 -0
- titans_pytorch-0.1.30.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.30.dist-info → titans_pytorch-0.1.32.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.30.dist-info → titans_pytorch-0.1.32.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/__init__.py
CHANGED
@@ -65,7 +65,7 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
|
|
65
65
|
|
66
66
|
# proposed neural memory
|
67
67
|
|
68
|
-
from titans_pytorch.
|
68
|
+
from titans_pytorch.neural_memory import NeuralMemory
|
69
69
|
|
70
70
|
# constants
|
71
71
|
|
@@ -106,7 +106,11 @@ def pad_at_dim(t, pad, dim = -1, value = 0.):
|
|
106
106
|
zeros = ((0, 0) * dims_from_right)
|
107
107
|
return F.pad(t, (*zeros, *pad), value = value)
|
108
108
|
|
109
|
-
def pad_and_segment_with_inverse(
|
109
|
+
def pad_and_segment_with_inverse(
|
110
|
+
seq,
|
111
|
+
segment_len,
|
112
|
+
fold_into_batch = True,
|
113
|
+
):
|
110
114
|
batch, seq_len = seq.shape[:2]
|
111
115
|
next_seq_len_mult = round_up_multiple(seq_len, segment_len)
|
112
116
|
|
@@ -119,11 +123,15 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
|
|
119
123
|
if fold_into_batch:
|
120
124
|
seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
|
121
125
|
|
122
|
-
|
126
|
+
shape = seq.shape
|
127
|
+
|
128
|
+
def inverse(out):
|
129
|
+
unchanged_shape = out.shape == shape
|
130
|
+
|
123
131
|
if fold_into_batch:
|
124
132
|
out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)
|
125
133
|
|
126
|
-
if needs_pad and
|
134
|
+
if needs_pad and unchanged_shape:
|
127
135
|
out = out[..., :-padding, :]
|
128
136
|
|
129
137
|
return out
|
@@ -582,13 +590,8 @@ class MemoryAsContextTransformer(Module):
|
|
582
590
|
self,
|
583
591
|
seq_index
|
584
592
|
):
|
585
|
-
total_segment_len = self.attn_window_size
|
586
|
-
|
587
|
-
seq = seq_index + 1
|
588
|
-
seq -= int((seq % total_segment_len) == 0)
|
589
|
-
last_segment_len = round_down_multiple(seq, total_segment_len)
|
590
|
-
segment_seq = seq - last_segment_len
|
591
|
-
return (segment_seq - self.segment_len) > 0
|
593
|
+
total_segment_len, segment_len = self.attn_window_size, self.segment_len
|
594
|
+
return ((seq_index % total_segment_len + 1) - segment_len) > 0
|
592
595
|
|
593
596
|
def seq_len_with_longterm_mem(
|
594
597
|
self,
|
@@ -597,7 +600,7 @@ class MemoryAsContextTransformer(Module):
|
|
597
600
|
assert seq_len > 0
|
598
601
|
|
599
602
|
segment_len, num_mem = self.segment_len, self.num_longterm_mem_tokens
|
600
|
-
return
|
603
|
+
return ((seq_len - 1) // segment_len) * num_mem + seq_len
|
601
604
|
|
602
605
|
@torch.no_grad()
|
603
606
|
def sample(
|
@@ -695,7 +698,7 @@ class MemoryAsContextTransformer(Module):
|
|
695
698
|
mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
|
696
699
|
x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')
|
697
700
|
|
698
|
-
x = inverse_segment(x
|
701
|
+
x = inverse_segment(x)
|
699
702
|
|
700
703
|
# splice out unneeded tokens from padding for longterm mems
|
701
704
|
|
@@ -723,9 +726,9 @@ class MemoryAsContextTransformer(Module):
|
|
723
726
|
is_inferencing = exists(cache)
|
724
727
|
|
725
728
|
if not exists(cache):
|
726
|
-
cache = (None, None)
|
729
|
+
cache = (seq_len_with_mem - 1, None, None)
|
727
730
|
|
728
|
-
kv_caches, neural_mem_caches = cache
|
731
|
+
inference_seq_index, kv_caches, neural_mem_caches = cache
|
729
732
|
|
730
733
|
kv_caches = iter(default(kv_caches, []))
|
731
734
|
neural_mem_caches = iter(default(neural_mem_caches, []))
|
@@ -744,7 +747,8 @@ class MemoryAsContextTransformer(Module):
|
|
744
747
|
# when inferencing, only do one token at a time
|
745
748
|
|
746
749
|
if is_inferencing:
|
747
|
-
|
750
|
+
ind = inference_seq_index
|
751
|
+
x = x[:, ind:(ind + 1)]
|
748
752
|
|
749
753
|
# expand and reduce streams for hyper connections
|
750
754
|
|
@@ -763,14 +767,13 @@ class MemoryAsContextTransformer(Module):
|
|
763
767
|
mem_input, add_residual = mem_hyper_conn(x)
|
764
768
|
|
765
769
|
if not is_inferencing:
|
766
|
-
retrieved, mem_kv_aux_loss = mem(
|
770
|
+
(retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
|
767
771
|
mem_input,
|
768
772
|
return_aux_kv_loss = True
|
769
773
|
)
|
770
774
|
|
771
775
|
kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
|
772
776
|
|
773
|
-
next_neural_mem_cache = (seq_len, None, None, None)
|
774
777
|
else:
|
775
778
|
retrieved, next_neural_mem_cache = mem.forward_inference(
|
776
779
|
mem_input,
|
@@ -817,6 +820,17 @@ class MemoryAsContextTransformer(Module):
|
|
817
820
|
if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
|
818
821
|
next_kv_caches = next_kv_caches[..., 0:0, :]
|
819
822
|
|
823
|
+
next_cache = (
|
824
|
+
inference_seq_index + 1,
|
825
|
+
next_kv_caches,
|
826
|
+
next_neural_mem_caches
|
827
|
+
)
|
828
|
+
|
829
|
+
is_longterm_mem = self.seq_index_is_longterm(inference_seq_index)
|
830
|
+
|
831
|
+
if is_inferencing and is_longterm_mem:
|
832
|
+
return None, next_cache
|
833
|
+
|
820
834
|
# hyper connection reducing of streams
|
821
835
|
|
822
836
|
x = self.reduce_streams(x)
|
@@ -843,7 +857,7 @@ class MemoryAsContextTransformer(Module):
|
|
843
857
|
if not return_cache:
|
844
858
|
return logits
|
845
859
|
|
846
|
-
return logits,
|
860
|
+
return logits, next_cache
|
847
861
|
|
848
862
|
ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
|
849
863
|
|
@@ -38,8 +38,11 @@ LinearNoBias = partial(Linear, bias = False)
|
|
38
38
|
def exists(v):
|
39
39
|
return v is not None
|
40
40
|
|
41
|
-
def default(
|
42
|
-
|
41
|
+
def default(*args):
|
42
|
+
for arg in args:
|
43
|
+
if exists(arg):
|
44
|
+
return arg
|
45
|
+
return None
|
43
46
|
|
44
47
|
def xnor(x, y):
|
45
48
|
return not (x ^ y)
|
@@ -468,7 +471,12 @@ class NeuralMemory(Module):
|
|
468
471
|
weighted_loss = loss * loss_weights
|
469
472
|
return weighted_loss.sum(), weighted_loss.mean()
|
470
473
|
|
471
|
-
|
474
|
+
# two functions
|
475
|
+
|
476
|
+
grad_fn = grad(forward_and_loss, has_aux = True)
|
477
|
+
|
478
|
+
self.per_sample_grad_fn = vmap(grad_fn, in_dims = (None, 0, 0, 0))
|
479
|
+
self.per_sample_grad_fn_expanded_weights = vmap(grad_fn, in_dims = (0,) * 4)
|
472
480
|
|
473
481
|
# queries for retrieving from the model
|
474
482
|
|
@@ -561,6 +569,7 @@ class NeuralMemory(Module):
|
|
561
569
|
seq,
|
562
570
|
weights: dict[str, Tensor],
|
563
571
|
past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
|
572
|
+
prev_layer_updates: dict[str, Tensor] | None = None,
|
564
573
|
return_aux_kv_loss = False,
|
565
574
|
chunk_size = None,
|
566
575
|
value_residual = None
|
@@ -583,10 +592,25 @@ class NeuralMemory(Module):
|
|
583
592
|
|
584
593
|
seq = seq[:, :round_down_seq_len]
|
585
594
|
|
595
|
+
# per sample grad function
|
596
|
+
|
597
|
+
per_sample_grad_fn = self.per_sample_grad_fn
|
598
|
+
|
586
599
|
# weights of the memory network
|
587
600
|
|
588
601
|
weights = TensorDict(weights)
|
589
602
|
|
603
|
+
# allow for neural memory of a previous layer and the past to produce gradients that become the weights of the current one generating the surprise
|
604
|
+
# think this is necessary otherwise the memory model is static (unless if paper is misunderstood)
|
605
|
+
# improvise (or perhaps correcting to) a solution
|
606
|
+
|
607
|
+
if exists(prev_layer_updates):
|
608
|
+
prev_layer_updates = TensorDict(weights)
|
609
|
+
|
610
|
+
weights = weights + prev_layer_updates
|
611
|
+
|
612
|
+
per_sample_grad_fn = self.per_sample_grad_fn_expanded_weights # the weights will now have a batch * chunk dimension
|
613
|
+
|
590
614
|
# derive learned hparams for optimization of memory network
|
591
615
|
|
592
616
|
adaptive_lr = self.to_adaptive_step(seq)
|
@@ -635,7 +659,7 @@ class NeuralMemory(Module):
|
|
635
659
|
|
636
660
|
# get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
|
637
661
|
|
638
|
-
grads, aux_kv_recon_loss =
|
662
|
+
grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
|
639
663
|
|
640
664
|
grads = TensorDict(grads)
|
641
665
|
|
@@ -781,6 +805,7 @@ class NeuralMemory(Module):
|
|
781
805
|
|
782
806
|
return values[:, :seq_len]
|
783
807
|
|
808
|
+
@torch.no_grad()
|
784
809
|
def forward_inference(
|
785
810
|
self,
|
786
811
|
token: Tensor,
|
@@ -854,13 +879,18 @@ class NeuralMemory(Module):
|
|
854
879
|
return_aux_kv_loss = False,
|
855
880
|
chunk_size = None,
|
856
881
|
store_chunk_size = None,
|
857
|
-
return_values = False
|
882
|
+
return_values = False,
|
883
|
+
return_next_state = False
|
858
884
|
):
|
859
885
|
batch, seq_len = seq.shape[:2]
|
860
886
|
|
861
887
|
if seq_len < self.retrieve_chunk_size:
|
862
888
|
out = self.init_empty_memory_embed(batch, seq_len)
|
863
889
|
|
890
|
+
next_store_state = (seq_len, seq, None, None)
|
891
|
+
|
892
|
+
out = (out, next_store_state)
|
893
|
+
|
864
894
|
if not return_aux_kv_loss:
|
865
895
|
return out
|
866
896
|
|
@@ -870,16 +900,31 @@ class NeuralMemory(Module):
|
|
870
900
|
mem_model_weights = self.init_weights()
|
871
901
|
|
872
902
|
store_seq = default(store_seq, seq)
|
873
|
-
|
903
|
+
|
904
|
+
store_seq_len = store_seq.shape[-2]
|
905
|
+
store_chunk_size = default(store_chunk_size, chunk_size, self.store_chunk_size)
|
906
|
+
remainder = store_seq_len % store_chunk_size
|
874
907
|
|
875
908
|
(updates, next_state, values), aux_kv_recon_loss = self.store_memories(store_seq, mem_model_weights, chunk_size = store_chunk_size, return_aux_kv_loss = True)
|
876
909
|
|
877
910
|
retrieved = self.retrieve_memories(seq, mem_model_weights + updates, chunk_size = chunk_size)
|
878
911
|
|
879
|
-
|
912
|
+
# determine state for the storing of memories
|
913
|
+
# for transformer-xl like training with neural memory as well as inferencing with initial prompt
|
914
|
+
|
915
|
+
cache_store_seq = None
|
916
|
+
|
917
|
+
if remainder > 0:
|
918
|
+
cache_store_seq = store_seq[:, -remainder:]
|
919
|
+
|
920
|
+
updates = updates.apply(lambda t: t[:, -1:])
|
921
|
+
|
922
|
+
next_store_state = (seq_len, cache_store_seq, next_state, updates)
|
923
|
+
|
924
|
+
output = (retrieved, next_store_state)
|
880
925
|
|
881
926
|
if return_values:
|
882
|
-
output = (
|
927
|
+
output = (*output, values)
|
883
928
|
|
884
929
|
if not return_aux_kv_loss:
|
885
930
|
return output
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: titans-pytorch
|
3
|
-
Version: 0.1.30
|
3
|
+
Version: 0.1.32
|
4
4
|
Summary: Titans
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
|
@@ -82,7 +82,7 @@ mem = NeuralMemory(
|
|
82
82
|
).cuda()
|
83
83
|
|
84
84
|
seq = torch.randn(2, 1024, 384).cuda()
|
85
|
-
retrieved = mem(seq)
|
85
|
+
retrieved, mem_state = mem(seq)
|
86
86
|
|
87
87
|
assert seq.shape == retrieved.shape
|
88
88
|
```
|
@@ -0,0 +1,8 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=rMT99CPQFH4Gudp0FmVPWGKfhBf6xksGEaFEcOVdqjs,230
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=Cui-hCl6X4UVGmuyoKCSKWbag9Yrc-a2MrfVkHM-z0A,24828
|
4
|
+
titans_pytorch/neural_memory.py,sha256=Vfo1z1VztPDDXgFxjkiyOP29daDE7KTdnZeWXifvCJI,27456
|
5
|
+
titans_pytorch-0.1.32.dist-info/METADATA,sha256=_HPPht8nhLwH9GzLyZI-fh8JBSEoSxkENCSU2xuU_6A,6826
|
6
|
+
titans_pytorch-0.1.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
+
titans_pytorch-0.1.32.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
+
titans_pytorch-0.1.32.dist-info/RECORD,,
|
@@ -1,8 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=RRLdVa8z-2IWbhhmRGfoNBycwaL32aMbpqutzmSQqpc,24575
|
4
|
-
titans_pytorch/titans.py,sha256=6B8ioP26RTja5kVFMsorAnM9CcxIUySJS9RZBlDPI2s,25825
|
5
|
-
titans_pytorch-0.1.30.dist-info/METADATA,sha256=o5flkZ0hNhZE06bSKVEFpbrkhuWB9putcaL_MZ0sJHA,6815
|
6
|
-
titans_pytorch-0.1.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
-
titans_pytorch-0.1.30.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
-
titans_pytorch-0.1.30.dist-info/RECORD,,
|
File without changes
|
File without changes
|