titans-pytorch 0.1.31__py3-none-any.whl → 0.1.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/__init__.py +1 -1
- titans_pytorch/mac_transformer.py +15 -8
- titans_pytorch/{titans.py → neural_memory.py} +53 -8
- {titans_pytorch-0.1.31.dist-info → titans_pytorch-0.1.32.dist-info}/METADATA +2 -2
- titans_pytorch-0.1.32.dist-info/RECORD +8 -0
- titans_pytorch-0.1.31.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.31.dist-info → titans_pytorch-0.1.32.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.31.dist-info → titans_pytorch-0.1.32.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py
CHANGED
@@ -65,7 +65,7 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions
|
|
65
65
|
|
66
66
|
# proposed neural memory
|
67
67
|
|
68
|
-
from titans_pytorch.titans import NeuralMemory
|
68
|
+
from titans_pytorch.neural_memory import NeuralMemory
|
69
69
|
|
70
70
|
# constants
|
71
71
|
|
@@ -106,7 +106,11 @@ def pad_at_dim(t, pad, dim = -1, value = 0.):
|
|
106
106
|
zeros = ((0, 0) * dims_from_right)
|
107
107
|
return F.pad(t, (*zeros, *pad), value = value)
|
108
108
|
|
109
|
-
def pad_and_segment_with_inverse(
|
109
|
+
def pad_and_segment_with_inverse(
|
110
|
+
seq,
|
111
|
+
segment_len,
|
112
|
+
fold_into_batch = True,
|
113
|
+
):
|
110
114
|
batch, seq_len = seq.shape[:2]
|
111
115
|
next_seq_len_mult = round_up_multiple(seq_len, segment_len)
|
112
116
|
|
@@ -119,11 +123,15 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
|
|
119
123
|
if fold_into_batch:
|
120
124
|
seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
|
121
125
|
|
122
|
-
|
126
|
+
shape = seq.shape
|
127
|
+
|
128
|
+
def inverse(out):
|
129
|
+
unchanged_shape = out.shape == shape
|
130
|
+
|
123
131
|
if fold_into_batch:
|
124
132
|
out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)
|
125
133
|
|
126
|
-
if needs_pad and
|
134
|
+
if needs_pad and unchanged_shape:
|
127
135
|
out = out[..., :-padding, :]
|
128
136
|
|
129
137
|
return out
|
@@ -690,7 +698,7 @@ class MemoryAsContextTransformer(Module):
|
|
690
698
|
mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
|
691
699
|
x, inverse_pack_mems = pack_with_inverse((x, mems), 'b * d')
|
692
700
|
|
693
|
-
x = inverse_segment(x
|
701
|
+
x = inverse_segment(x)
|
694
702
|
|
695
703
|
# splice out unneeded tokens from padding for longterm mems
|
696
704
|
|
@@ -759,14 +767,13 @@ class MemoryAsContextTransformer(Module):
|
|
759
767
|
mem_input, add_residual = mem_hyper_conn(x)
|
760
768
|
|
761
769
|
if not is_inferencing:
|
762
|
-
retrieved, mem_kv_aux_loss = mem(
|
770
|
+
(retrieved, next_neural_mem_cache), mem_kv_aux_loss = mem(
|
763
771
|
mem_input,
|
764
772
|
return_aux_kv_loss = True
|
765
773
|
)
|
766
774
|
|
767
775
|
kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
|
768
776
|
|
769
|
-
next_neural_mem_cache = (seq_len, None, None, None)
|
770
777
|
else:
|
771
778
|
retrieved, next_neural_mem_cache = mem.forward_inference(
|
772
779
|
mem_input,
|
@@ -836,7 +843,7 @@ class MemoryAsContextTransformer(Module):
|
|
836
843
|
|
837
844
|
x, _ = inverse_pack_mems(x)
|
838
845
|
|
839
|
-
x = inverse_segment(x
|
846
|
+
x = inverse_segment(x)
|
840
847
|
|
841
848
|
x = x[:, :seq_len]
|
842
849
|
|
titans_pytorch/{titans.py → neural_memory.py}
@@ -38,8 +38,11 @@ LinearNoBias = partial(Linear, bias = False)
|
|
38
38
|
def exists(v):
|
39
39
|
return v is not None
|
40
40
|
|
41
|
-
def default(
|
42
|
-
|
41
|
+
def default(*args):
|
42
|
+
for arg in args:
|
43
|
+
if exists(arg):
|
44
|
+
return arg
|
45
|
+
return None
|
43
46
|
|
44
47
|
def xnor(x, y):
|
45
48
|
return not (x ^ y)
|
@@ -468,7 +471,12 @@ class NeuralMemory(Module):
|
|
468
471
|
weighted_loss = loss * loss_weights
|
469
472
|
return weighted_loss.sum(), weighted_loss.mean()
|
470
473
|
|
471
|
-
|
474
|
+
# two functions
|
475
|
+
|
476
|
+
grad_fn = grad(forward_and_loss, has_aux = True)
|
477
|
+
|
478
|
+
self.per_sample_grad_fn = vmap(grad_fn, in_dims = (None, 0, 0, 0))
|
479
|
+
self.per_sample_grad_fn_expanded_weights = vmap(grad_fn, in_dims = (0,) * 4)
|
472
480
|
|
473
481
|
# queries for retrieving from the model
|
474
482
|
|
@@ -561,6 +569,7 @@ class NeuralMemory(Module):
|
|
561
569
|
seq,
|
562
570
|
weights: dict[str, Tensor],
|
563
571
|
past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
|
572
|
+
prev_layer_updates: dict[str, Tensor] | None = None,
|
564
573
|
return_aux_kv_loss = False,
|
565
574
|
chunk_size = None,
|
566
575
|
value_residual = None
|
@@ -583,10 +592,25 @@ class NeuralMemory(Module):
|
|
583
592
|
|
584
593
|
seq = seq[:, :round_down_seq_len]
|
585
594
|
|
595
|
+
# per sample grad function
|
596
|
+
|
597
|
+
per_sample_grad_fn = self.per_sample_grad_fn
|
598
|
+
|
586
599
|
# weights of the memory network
|
587
600
|
|
588
601
|
weights = TensorDict(weights)
|
589
602
|
|
603
|
+
# allow for neural memory of a previous layer and the past to produce gradients that become the weights of the current one generating the surprise
|
604
|
+
# think this is necessary otherwise the memory model is static (unless if paper is misunderstood)
|
605
|
+
# improvise (or perhaps correcting to) a solution
|
606
|
+
|
607
|
+
if exists(prev_layer_updates):
|
608
|
+
prev_layer_updates = TensorDict(weights)
|
609
|
+
|
610
|
+
weights = weights + prev_layer_updates
|
611
|
+
|
612
|
+
per_sample_grad_fn = self.per_sample_grad_fn_expanded_weights # the weights will now have a batch * chunk dimension
|
613
|
+
|
590
614
|
# derive learned hparams for optimization of memory network
|
591
615
|
|
592
616
|
adaptive_lr = self.to_adaptive_step(seq)
|
@@ -635,7 +659,7 @@ class NeuralMemory(Module):
|
|
635
659
|
|
636
660
|
# get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
|
637
661
|
|
638
|
-
grads, aux_kv_recon_loss =
|
662
|
+
grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights), keys, adaptive_lr, values)
|
639
663
|
|
640
664
|
grads = TensorDict(grads)
|
641
665
|
|
@@ -781,6 +805,7 @@ class NeuralMemory(Module):
|
|
781
805
|
|
782
806
|
return values[:, :seq_len]
|
783
807
|
|
808
|
+
@torch.no_grad()
|
784
809
|
def forward_inference(
|
785
810
|
self,
|
786
811
|
token: Tensor,
|
@@ -854,13 +879,18 @@ class NeuralMemory(Module):
|
|
854
879
|
return_aux_kv_loss = False,
|
855
880
|
chunk_size = None,
|
856
881
|
store_chunk_size = None,
|
857
|
-
return_values = False
|
882
|
+
return_values = False,
|
883
|
+
return_next_state = False
|
858
884
|
):
|
859
885
|
batch, seq_len = seq.shape[:2]
|
860
886
|
|
861
887
|
if seq_len < self.retrieve_chunk_size:
|
862
888
|
out = self.init_empty_memory_embed(batch, seq_len)
|
863
889
|
|
890
|
+
next_store_state = (seq_len, seq, None, None)
|
891
|
+
|
892
|
+
out = (out, next_store_state)
|
893
|
+
|
864
894
|
if not return_aux_kv_loss:
|
865
895
|
return out
|
866
896
|
|
@@ -870,16 +900,31 @@ class NeuralMemory(Module):
|
|
870
900
|
mem_model_weights = self.init_weights()
|
871
901
|
|
872
902
|
store_seq = default(store_seq, seq)
|
873
|
-
|
903
|
+
|
904
|
+
store_seq_len = store_seq.shape[-2]
|
905
|
+
store_chunk_size = default(store_chunk_size, chunk_size, self.store_chunk_size)
|
906
|
+
remainder = store_seq_len % store_chunk_size
|
874
907
|
|
875
908
|
(updates, next_state, values), aux_kv_recon_loss = self.store_memories(store_seq, mem_model_weights, chunk_size = store_chunk_size, return_aux_kv_loss = True)
|
876
909
|
|
877
910
|
retrieved = self.retrieve_memories(seq, mem_model_weights + updates, chunk_size = chunk_size)
|
878
911
|
|
879
|
-
|
912
|
+
# determine state for the storing of memories
|
913
|
+
# for transformer-xl like training with neural memory as well as inferencing with initial prompt
|
914
|
+
|
915
|
+
cache_store_seq = None
|
916
|
+
|
917
|
+
if remainder > 0:
|
918
|
+
cache_store_seq = store_seq[:, -remainder:]
|
919
|
+
|
920
|
+
updates = updates.apply(lambda t: t[:, -1:])
|
921
|
+
|
922
|
+
next_store_state = (seq_len, cache_store_seq, next_state, updates)
|
923
|
+
|
924
|
+
output = (retrieved, next_store_state)
|
880
925
|
|
881
926
|
if return_values:
|
882
|
-
output = (
|
927
|
+
output = (*output, values)
|
883
928
|
|
884
929
|
if not return_aux_kv_loss:
|
885
930
|
return output
|
{titans_pytorch-0.1.31.dist-info → titans_pytorch-0.1.32.dist-info}/METADATA
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: titans-pytorch
|
3
|
-
Version: 0.1.31
|
3
|
+
Version: 0.1.32
|
4
4
|
Summary: Titans
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
|
@@ -82,7 +82,7 @@ mem = NeuralMemory(
|
|
82
82
|
).cuda()
|
83
83
|
|
84
84
|
seq = torch.randn(2, 1024, 384).cuda()
|
85
|
-
retrieved = mem(seq)
|
85
|
+
retrieved, mem_state = mem(seq)
|
86
86
|
|
87
87
|
assert seq.shape == retrieved.shape
|
88
88
|
```
|
titans_pytorch-0.1.32.dist-info/RECORD
@@ -0,0 +1,8 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=rMT99CPQFH4Gudp0FmVPWGKfhBf6xksGEaFEcOVdqjs,230
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=Cui-hCl6X4UVGmuyoKCSKWbag9Yrc-a2MrfVkHM-z0A,24828
|
4
|
+
titans_pytorch/neural_memory.py,sha256=Vfo1z1VztPDDXgFxjkiyOP29daDE7KTdnZeWXifvCJI,27456
|
5
|
+
titans_pytorch-0.1.32.dist-info/METADATA,sha256=_HPPht8nhLwH9GzLyZI-fh8JBSEoSxkENCSU2xuU_6A,6826
|
6
|
+
titans_pytorch-0.1.32.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
+
titans_pytorch-0.1.32.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
+
titans_pytorch-0.1.32.dist-info/RECORD,,
|
titans_pytorch-0.1.31.dist-info/RECORD
@@ -1,8 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=pKFRL_ISoHEKUyfssKwfBfwFO2eQN9objJmxLrNsYrU,24838
|
4
|
-
titans_pytorch/titans.py,sha256=6B8ioP26RTja5kVFMsorAnM9CcxIUySJS9RZBlDPI2s,25825
|
5
|
-
titans_pytorch-0.1.31.dist-info/METADATA,sha256=9ejOFuH2B2-yCRFK4x_C1DONPxecW8VcjEUeRh9OzXg,6815
|
6
|
-
titans_pytorch-0.1.31.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
7
|
-
titans_pytorch-0.1.31.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
8
|
-
titans_pytorch-0.1.31.dist-info/RECORD,,
|
{titans_pytorch-0.1.31.dist-info → titans_pytorch-0.1.32.dist-info}/WHEEL
File without changes
|
{titans_pytorch-0.1.31.dist-info → titans_pytorch-0.1.32.dist-info}/licenses/LICENSE
File without changes
|