titans-pytorch 0.1.31__py3-none-any.whl → 0.1.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/__init__.py +1 -1
- titans_pytorch/mac_transformer.py +15 -8
- titans_pytorch/{titans.py → neural_memory.py} +86 -13
- {titans_pytorch-0.1.31.dist-info → titans_pytorch-0.1.33.dist-info}/METADATA +2 -2
- titans_pytorch-0.1.33.dist-info/RECORD +8 -0
- titans_pytorch-0.1.31.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.31.dist-info → titans_pytorch-0.1.33.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.31.dist-info → titans_pytorch-0.1.33.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/__init__.py
CHANGED
titans_pytorch/mac_transformer.py
CHANGED

@@ -65,7 +65,7 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # proposed neural memory
 
-from titans_pytorch.titans import NeuralMemory
+from titans_pytorch.neural_memory import NeuralMemory
 
 # constants
 
@@ -106,7 +106,11 @@ def pad_at_dim(t, pad, dim = -1, value = 0.):
     zeros = ((0, 0) * dims_from_right)
     return F.pad(t, (*zeros, *pad), value = value)
 
-def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
+def pad_and_segment_with_inverse(
+    seq,
+    segment_len,
+    fold_into_batch = True,
+):
     batch, seq_len = seq.shape[:2]
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)
 
@@ -119,11 +123,15 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
     if fold_into_batch:
         seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
 
-
+    shape = seq.shape
+
+    def inverse(out):
+        unchanged_shape = out.shape == shape
+
         if fold_into_batch:
             out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)
 
-        if needs_pad and
+        if needs_pad and unchanged_shape:
             out = out[..., :-padding, :]
 
         return out
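For orientation, here is a minimal, self-contained sketch of how the reworked helper behaves after this hunk. The padding arithmetic (`round_up_multiple`, `F.pad`, `needs_pad`) is filled in as an assumption, since the diff only shows the changed portion of the function; the new part is that `inverse` records the segmented shape and only strips the right padding when the tensor it is given still has that shape.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def round_up_multiple(n, mult):
    # assumed helper: round n up to the nearest multiple of mult
    return ((n + mult - 1) // mult) * mult

def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
    batch, seq_len = seq.shape[:2]
    next_seq_len_mult = round_up_multiple(seq_len, segment_len)

    padding = next_seq_len_mult - seq_len
    needs_pad = padding > 0

    if needs_pad:
        seq = F.pad(seq, (0, 0, 0, padding))   # right-pad the sequence dimension

    if fold_into_batch:
        seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)

    shape = seq.shape   # remember the shape handed to the caller

    def inverse(out):
        # only un-pad if the caller returns a tensor of the same shape it received
        unchanged_shape = out.shape == shape

        if fold_into_batch:
            out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)

        if needs_pad and unchanged_shape:
            out = out[..., :-padding, :]

        return out

    return seq, inverse

segments, inverse_segment = pad_and_segment_with_inverse(torch.randn(2, 100, 16), segment_len = 32)
restored = inverse_segment(segments)
assert restored.shape == (2, 100, 16)   # padding from 100 -> 128 was stripped again
```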
@@ -690,7 +698,7 @@ class MemoryAsContextTransformer(Module):
         mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
         x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')
 
-        x = inverse_segment(x
+        x = inverse_segment(x)
 
         # splice out unneeded tokens from padding for longterm mems
 
@@ -759,14 +767,13 @@ class MemoryAsContextTransformer(Module):
            mem_input, add_residual = mem_hyper_conn(x)
 
            if not is_inferencing:
-                retrieved, mem_kv_aux_loss = mem(
+                (retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
                    mem_input,
                    return_aux_kv_loss = True
                )
 
                kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
 
-                next_neural_mem_cache = (seq_len, None, None, None)
            else:
                retrieved, next_neural_mem_cache = mem.forward_inference(
                    mem_input,
@@ -836,7 +843,7 @@ class MemoryAsContextTransformer(Module):
 
        x, _ = inverse_pack_mems(x)
 
-        x = inverse_segment(x
+        x = inverse_segment(x)
 
        x = x[:, :seq_len]
 
titans_pytorch/{titans.py → neural_memory.py}
RENAMED

@@ -3,6 +3,7 @@ from typing import Callable
 
 import math
 from functools import partial
+from collections import namedtuple
 
 import torch
 from torch import nn, cat, Tensor
@@ -33,13 +34,18 @@ w - num memory network weight parameters
 
 LinearNoBias = partial(Linear, bias = False)
 
+NeuralMemCache = namedtuple('NeuralMemCache', ['seq', 'cache_store_segment', 'states', 'updates'])
+
 # functions
 
 def exists(v):
     return v is not None
 
-def default(
-
+def default(*args):
+    for arg in args:
+        if exists(arg):
+            return arg
+    return None
 
 def xnor(x, y):
     return not (x ^ y)
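Two small but load-bearing additions appear in this hunk: `default` becomes variadic and now returns the first non-`None` argument among any number of fallbacks, and the inference cache that used to be a bare 4-tuple gets named fields via `NeuralMemCache`. A short sketch, lifted from the hunk above with a couple of illustrative asserts:

```python
from collections import namedtuple

NeuralMemCache = namedtuple('NeuralMemCache', ['seq', 'cache_store_segment', 'states', 'updates'])

def exists(v):
    return v is not None

def default(*args):
    # return the first argument that is not None, else None
    for arg in args:
        if exists(arg):
            return arg
    return None

# callers can now chain several fallbacks in one call,
# e.g. default(store_chunk_size, chunk_size, self.store_chunk_size)
assert default(None, None, 64) == 64

cache = NeuralMemCache(seq = 4, cache_store_segment = None, states = None, updates = None)
assert cache.seq == 4 and cache.updates is None
```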
@@ -468,7 +474,12 @@ class NeuralMemory(Module):
            weighted_loss = loss * loss_weights
            return weighted_loss.sum(), weighted_loss.mean()
 
-
+        # two functions
+
+        grad_fn = grad(forward_and_loss, has_aux = True)
+
+        self.per_sample_grad_fn = vmap(grad_fn, in_dims = (None, 0, 0, 0))
+        self.per_sample_grad_fn_expanded_weights = vmap(grad_fn, in_dims = (0,) * 4)
 
        # queries for retrieving from the model
 
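This hunk splits the per-sample gradient function into two `vmap` variants of the same `grad` closure: one where the memory-network weights are shared across the mapped dimension (`in_dims = (None, 0, 0, 0)`) and one where the weights themselves carry a leading batch-times-chunk dimension (`in_dims = (0,) * 4`). The toy sketch below uses `torch.func` to show the difference; the one-matrix "memory model", its loss, and every shape here are stand-ins, not the library's actual memory MLP.

```python
import torch
from torch.func import grad, vmap

def forward_and_loss(weights, keys, lr, values):
    # toy surrogate for the memory model: one linear map with a weighted mse loss
    preds = keys @ weights['w']
    weighted_loss = ((preds - values) ** 2).mean(dim = -1) * lr
    return weighted_loss.sum(), weighted_loss.mean()

grad_fn = grad(forward_and_loss, has_aux = True)

per_sample_grad_fn = vmap(grad_fn, in_dims = (None, 0, 0, 0))      # weights shared
per_sample_grad_fn_expanded = vmap(grad_fn, in_dims = (0,) * 4)    # weights batched too

keys   = torch.randn(8, 16, 32)      # (batch * chunks, chunk_len, dim)
values = torch.randn(8, 16, 32)
lr     = torch.rand(8, 16)

shared_weights   = {'w': torch.randn(32, 32)}
expanded_weights = {'w': torch.randn(8, 32, 32)}   # one weight matrix per chunk

grads, aux = per_sample_grad_fn(shared_weights, keys, lr, values)
assert grads['w'].shape == (8, 32, 32)             # per-sample grads of the shared weights

grads, aux = per_sample_grad_fn_expanded(expanded_weights, keys, lr, values)
assert grads['w'].shape == (8, 32, 32)             # grads of each chunk's own weights
```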
@@ -561,6 +572,7 @@ class NeuralMemory(Module):
        seq,
        weights: dict[str, Tensor],
        past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
+        prev_layer_updates: dict[str, Tensor] | None = None,
        return_aux_kv_loss = False,
        chunk_size = None,
        value_residual = None
@@ -583,10 +595,25 @@ class NeuralMemory(Module):
 
        seq = seq[:, :round_down_seq_len]
 
+        # per sample grad function
+
+        per_sample_grad_fn = self.per_sample_grad_fn
+
        # weights of the memory network
 
        weights = TensorDict(weights)
 
+        # allow for neural memory of a previous layer and the past to produce gradients that become the weights of the current one generating the surprise
+        # think this is necessary otherwise the memory model is static (unless if paper is misunderstood)
+        # improvise (or perhaps correcting to) a solution
+
+        if exists(prev_layer_updates):
+            prev_layer_updates = TensorDict(prev_layer_updates)
+
+            weights = weights + prev_layer_updates
+
+            per_sample_grad_fn = self.per_sample_grad_fn_expanded_weights # the weights will now have a batch * chunk dimension
+
        # derive learned hparams for optimization of memory network
 
        adaptive_lr = self.to_adaptive_step(seq)
@@ -633,9 +660,14 @@ class NeuralMemory(Module):
 
        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
 
+        # flatten batch and time if surprise depends on previous layer memory model
+
+        if exists(prev_layer_updates):
+            weights = weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
+
        # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads, aux_kv_recon_loss =
+        grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
 
        grads = TensorDict(grads)
 
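Taken together, the last two hunks handle the case where `prev_layer_updates` is supplied: the base weights absorb the per-chunk updates of a previous layer and pick up a leading `(batch, chunks)` prefix, which is then flattened so the expanded-weights grad function can map over it. A conceptual sketch with plain dicts standing in for `TensorDict` (all shapes are illustrative):

```python
import torch
from einops import rearrange

weights = {'w1': torch.randn(64, 64)}                      # base memory-network weights
prev_layer_updates = {'w1': torch.randn(2, 4, 64, 64)}     # (batch, chunks, *weight_shape)

# adding the updates broadcasts the base weights up to a per-(batch, chunk) copy
weights = {k: weights[k] + prev_layer_updates[k] for k in weights}

# flatten batch and chunks into one mapped dimension for the expanded-weights grad fn
weights = {k: rearrange(t, 'b n ... -> (b n) ...') for k, t in weights.items()}

assert weights['w1'].shape == (8, 64, 64)
```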
|
@@ -713,7 +745,8 @@ class NeuralMemory(Module):
|
|
713
745
|
self,
|
714
746
|
seq,
|
715
747
|
past_weights: dict[str, Tensor],
|
716
|
-
chunk_size = None
|
748
|
+
chunk_size = None,
|
749
|
+
prev_layer_updates: dict[str, Tensor] | None = None
|
717
750
|
):
|
718
751
|
chunk_size = default(chunk_size, self.retrieve_chunk_size)
|
719
752
|
batch, seq_len = seq.shape[:2]
|
@@ -736,6 +769,9 @@ class NeuralMemory(Module):
 
        curr_weights = TensorDict(past_weights)
 
+        if exists(prev_layer_updates):
+            curr_weights = curr_weights + TensorDict(prev_layer_updates)
+
        # sequence Float['b n d'] to queries
 
        queries = self.to_queries(seq)
@@ -781,6 +817,7 @@ class NeuralMemory(Module):
 
        return values[:, :seq_len]
 
+    @torch.no_grad()
    def forward_inference(
        self,
        token: Tensor,
@@ -813,7 +850,7 @@ class NeuralMemory(Module):
        if curr_seq_len < self.chunk_size:
            empty_mem = self.init_empty_memory_embed(batch, 1)
 
-            return empty_mem, (curr_seq_len, cache_store_seq, past_states, updates)
+            return empty_mem, NeuralMemCache(curr_seq_len, cache_store_seq, past_states, updates)
 
        # store if storage sequence cache hits the chunk size
 
|
@@ -823,6 +860,8 @@ class NeuralMemory(Module):
|
|
823
860
|
if not exists(updates):
|
824
861
|
updates = weights.clone().zero_()
|
825
862
|
updates = updates.apply(lambda t: repeat(t, '... -> b 1 ...', b = batch))
|
863
|
+
else:
|
864
|
+
updates = updates.apply(lambda t: t[:, -1:])
|
826
865
|
|
827
866
|
if store_seq_cache_len == self.chunk_size:
|
828
867
|
|
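During incremental decoding the cached updates keep a `(batch, chunks so far, ...)` layout; the new `else` branch trims that cache to the most recent chunk, presumably because only the latest update state is needed to seed the next step. Illustrative shapes, with a single tensor standing in for one entry of the `TensorDict`:

```python
import torch

updates_w = torch.randn(2, 5, 64, 64)   # (batch, chunks so far, *weight_shape)
updates_w = updates_w[:, -1:]           # keep only the most recent chunk's update

assert updates_w.shape == (2, 1, 64, 64)
```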
@@ -841,7 +880,7 @@ class NeuralMemory(Module):
 
        # next state tuple
 
-        next_state = (curr_seq_len, cache_store_seq, next_states, updates)
+        next_state = NeuralMemCache(curr_seq_len, cache_store_seq, next_states, updates)
 
        return retrieved, next_state
 
@@ -854,13 +893,19 @@ class NeuralMemory(Module):
        return_aux_kv_loss = False,
        chunk_size = None,
        store_chunk_size = None,
-        return_values = False
+        return_values = False,
+        return_next_state = False,
+        prev_layer_updates: dict[str, Tensor] | None = None
    ):
        batch, seq_len = seq.shape[:2]
 
        if seq_len < self.retrieve_chunk_size:
            out = self.init_empty_memory_embed(batch, seq_len)
 
+            next_store_state = (seq_len, seq, None, None)
+
+            out = (out, next_store_state)
+
            if not return_aux_kv_loss:
                return out
 
|
@@ -869,17 +914,45 @@ class NeuralMemory(Module):
|
|
869
914
|
if not exists(mem_model_weights):
|
870
915
|
mem_model_weights = self.init_weights()
|
871
916
|
|
917
|
+
# store
|
918
|
+
|
872
919
|
store_seq = default(store_seq, seq)
|
873
|
-
store_chunk_size = default(store_chunk_size, chunk_size)
|
874
920
|
|
875
|
-
|
921
|
+
store_seq_len = store_seq.shape[-2]
|
922
|
+
store_chunk_size = default(store_chunk_size, chunk_size, self.store_chunk_size)
|
923
|
+
remainder = store_seq_len % store_chunk_size
|
924
|
+
|
925
|
+
(updates, next_state, values), aux_kv_recon_loss = self.store_memories(
|
926
|
+
store_seq,
|
927
|
+
mem_model_weights,
|
928
|
+
chunk_size = store_chunk_size,
|
929
|
+
prev_layer_updates = prev_layer_updates,
|
930
|
+
return_aux_kv_loss = True
|
931
|
+
)
|
932
|
+
|
933
|
+
# retrieve
|
934
|
+
|
935
|
+
retrieved = self.retrieve_memories(
|
936
|
+
seq,
|
937
|
+
mem_model_weights + updates,
|
938
|
+
chunk_size = chunk_size,
|
939
|
+
prev_layer_updates = prev_layer_updates
|
940
|
+
)
|
941
|
+
|
942
|
+
# determine state for the storing of memories
|
943
|
+
# for transformer-xl like training with neural memory as well as inferencing with initial prompt
|
944
|
+
|
945
|
+
cache_store_seq = None
|
946
|
+
|
947
|
+
if remainder > 0:
|
948
|
+
cache_store_seq = store_seq[:, -remainder:]
|
876
949
|
|
877
|
-
|
950
|
+
next_store_state = NeuralMemCache(seq_len, cache_store_seq, next_state, updates)
|
878
951
|
|
879
|
-
output = retrieved
|
952
|
+
output = (retrieved, next_store_state)
|
880
953
|
|
881
954
|
if return_values:
|
882
|
-
output = (
|
955
|
+
output = (*output, values)
|
883
956
|
|
884
957
|
if not return_aux_kv_loss:
|
885
958
|
return output
|
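The reworked `forward` makes the chunk bookkeeping explicit: the stored sequence is split into full chunks plus a remainder, memories are stored and then retrieved, and the remainder is carried forward as `cache_store_seq` inside the returned `NeuralMemCache` so that transformer-xl style training, or inference primed with a prompt, can complete the partial chunk later. A worked example of the remainder arithmetic, with illustrative numbers:

```python
store_seq_len    = 200   # tokens handed to the memory for storage
store_chunk_size = 64    # chunk size used when storing memories

remainder = store_seq_len % store_chunk_size    # 200 % 64 = 8
full_chunk_tokens = store_seq_len - remainder   # 192 tokens fall into complete chunks

assert (remainder, full_chunk_tokens) == (8, 192)
# the trailing 8 tokens, store_seq[:, -remainder:], become cache_store_seq
```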
{titans_pytorch-0.1.31.dist-info → titans_pytorch-0.1.33.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.31
+Version: 0.1.33
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -82,7 +82,7 @@ mem = NeuralMemory(
 ).cuda()
 
 seq = torch.randn(2, 1024, 384).cuda()
-retrieved = mem(seq)
+retrieved, mem_state = mem(seq)
 
 assert seq.shape == retrieved.shape
 ```
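The README snippet in the metadata reflects the new return signature: calling the module now yields the retrieved memories together with the updated memory state rather than a bare tensor. A CPU-friendly sketch of the updated usage; the constructor arguments mirror the README example, but their exact values are assumptions here.

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,        # matches the 384-dim sequence below
    chunk_size = 64   # assumed value; the README uses a fixed chunk size
)

seq = torch.randn(2, 1024, 384)

# as of 0.1.33, forward returns (retrieved, mem_state), where mem_state is the
# NeuralMemCache namedtuple introduced in neural_memory.py
retrieved, mem_state = mem(seq)

assert seq.shape == retrieved.shape
```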
titans_pytorch-0.1.33.dist-info/RECORD
ADDED

@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=rMT99CPQFH4Gudp0FmVPWGKfhBf6xksGEaFEcOVdqjs,230
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=Cui-hCl6X4UVGmuyoKCSKWbag9Yrc-a2MrfVkHM-z0A,24828
+titans_pytorch/neural_memory.py,sha256=9dXpSaQYomc-ur-nEwej1nG9M5NqS0c3LBBP9jUIMPU,28352
+titans_pytorch-0.1.33.dist-info/METADATA,sha256=A9BBoe0Sas2kxUcUi7w_RFl8-SIF1TLzPIRGuZlauFM,6826
+titans_pytorch-0.1.33.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.33.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.33.dist-info/RECORD,,
titans_pytorch-0.1.31.dist-info/RECORD
REMOVED

@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=pKFRL_ISoHEKUyfssKwfBfwFO2eQN9objJmxLrNsYrU,24838
-titans_pytorch/titans.py,sha256=6B8ioP26RTja5kVFMsorAnM9CcxIUySJS9RZBlDPI2s,25825
-titans_pytorch-0.1.31.dist-info/METADATA,sha256=9ejOFuH2B2-yCRFK4x_C1DONPxecW8VcjEUeRh9OzXg,6815
-titans_pytorch-0.1.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.31.dist-info/RECORD,,
{titans_pytorch-0.1.31.dist-info → titans_pytorch-0.1.33.dist-info}/WHEEL
File without changes

{titans_pytorch-0.1.31.dist-info → titans_pytorch-0.1.33.dist-info}/licenses/LICENSE
File without changes