titans-pytorch 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/memory_models.py +31 -13
- titans_pytorch/neural_memory.py +27 -15
- {titans_pytorch-0.2.5.dist-info → titans_pytorch-0.2.7.dist-info}/METADATA +3 -3
- titans_pytorch-0.2.7.dist-info/RECORD +9 -0
- titans_pytorch-0.2.5.dist-info/RECORD +0 -9
- {titans_pytorch-0.2.5.dist-info → titans_pytorch-0.2.7.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.2.5.dist-info → titans_pytorch-0.2.7.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/memory_models.py
CHANGED
@@ -3,6 +3,13 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Parameter, ParameterList

+# functions
+
+def l2norm(t):
+    return F.normalize(t, dim = -1)
+
+# memory mlp proposed in TTT
+
 class MemoryMLP(Module):
     def __init__(
         self,
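The new l2norm helper is a thin wrapper over F.normalize; a minimal illustration of its effect (the tensor shape below is arbitrary, chosen only for the example):

```python
import torch
import torch.nn.functional as F

def l2norm(t):
    return F.normalize(t, dim = -1)

t = torch.randn(2, 8)
print(l2norm(t).norm(dim = -1))  # every row is rescaled to unit L2 norm -> tensor of ones
```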
@@ -19,15 +26,17 @@ class MemoryMLP(Module):
         self,
         x
     ):
+        residual = x
+
         for ind, weight in enumerate(self.weights):
             is_first = ind == 0

             if not is_first:
-                x = F.
+                x = F.gelu(x)

             x = x @ weight

-        return x
+        return x + residual

 # memory mlp, but with gated residual + final projection

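Read together, the two MemoryMLP hunks switch the inter-layer activation to GELU and wrap the whole stack in an outer residual connection. A minimal sketch of the resulting forward pass, assuming the weights are a ParameterList of square dim × dim matrices (the constructor is not part of this diff, so that detail is an assumption):

```python
import torch
import torch.nn.functional as F
from torch.nn import Module, Parameter, ParameterList

class MemoryMLPSketch(Module):
    # illustrative stand-in for the updated MemoryMLP, not the packaged class
    def __init__(self, dim, depth):
        super().__init__()
        self.weights = ParameterList([Parameter(torch.randn(dim, dim)) for _ in range(depth)])

    def forward(self, x):
        residual = x                 # new in 0.2.7: residual around the whole MLP

        for ind, weight in enumerate(self.weights):
            if ind != 0:
                x = F.gelu(x)        # GELU between layers, skipped before the first weight

            x = x @ weight

        return x + residual
```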
@@ -36,7 +45,7 @@ class GatedResidualMemoryMLP(Module):
         self,
         dim,
         depth,
-        expansion_factor =
+        expansion_factor = 4.
     ):
         super().__init__()
         dim_hidden = int(dim * expansion_factor)
@@ -58,11 +67,13 @@ class GatedResidualMemoryMLP(Module):
         self,
         x
     ):
+        residual = x
+
         for weight1, weight2, to_gates in self.weights:
             res = x

             hidden = x @ weight1
-            hidden = F.
+            hidden = F.gelu(hidden)
             branch_out = hidden @ weight2

             # gated residual
@@ -70,7 +81,9 @@ class GatedResidualMemoryMLP(Module):
             gates = cat((branch_out, res), dim = -1) @ to_gates
             x = res.lerp(branch_out, gates.sigmoid())

-
+        out = x @ self.final_proj
+
+        return out + residual

 # memory mlp with factorized weights
 # so can tradeoff capacity for smaller chunk sizes
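For reference on the gated residual kept above: res.lerp(branch_out, g) is res + g * (branch_out - res), so a sigmoid gate near 0 keeps the running residual and a gate near 1 takes the branch output. A quick check of that identity:

```python
import torch

res, branch_out = torch.randn(4, 8), torch.randn(4, 8)
gates = torch.randn(4, 8)

out = res.lerp(branch_out, gates.sigmoid())
manual = res + gates.sigmoid() * (branch_out - res)
assert torch.allclose(out, manual, atol = 1e-6)
```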
@@ -98,15 +111,17 @@ class FactorizedMemoryMLP(Module):
         self,
         x
     ):
+        residual = x
+
         for ind, (weight1, weight2) in enumerate(self.weights):
             is_first = ind == 0

             if not is_first:
-                x = F.
+                x = F.gelu(x)

             x = x @ weight1 @ weight2

-        return x
+        return x + residual

 # improvised attention as memory module

@@ -133,10 +148,12 @@ class MemoryAttention(Module):
             nn.init.xavier_uniform_(weight)

     def forward(self, x):
+        residual = x
+
         wq, wk, wv, ffw1, ffw2 = self.weights

-        q =
-        k =
+        q = l2norm(x @ wq)
+        k = l2norm(x @ wk)
         v = x @ wv

         attn_out = F.scaled_dot_product_attention(
@@ -145,9 +162,10 @@ class MemoryAttention(Module):
             is_causal = True
         )

-
+        # parallel attention + feedforward block
+        # as in PaLM + Gpt-J

-        h = F.
-
+        h = F.gelu(x @ ffw1)
+        ff_out = h @ ffw2

-        return
+        return attn_out + ff_out + residual
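Assembling the two MemoryAttention hunks: the updated forward pass l2-normalizes the query and key projections, runs causal scaled-dot-product attention, computes a parallel feedforward branch (PaLM / GPT-J style), and returns both plus the input residual. A self-contained sketch; the middle of the attention call is not shown in the diff context, so its arguments and the weight shapes are assumptions:

```python
import torch
import torch.nn.functional as F

def l2norm(t):
    return F.normalize(t, dim = -1)

def memory_attention_forward(x, wq, wk, wv, ffw1, ffw2):
    # x: (batch, seq, dim); weight shapes assumed (dim, dim) for wq/wk/wv,
    # (dim, dim_ff) for ffw1 and (dim_ff, dim) for ffw2
    residual = x

    q = l2norm(x @ wq)
    k = l2norm(x @ wk)
    v = x @ wv

    # causal attention over the normalized queries / keys (call arguments assumed)
    attn_out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

    # parallel attention + feedforward branch, as in PaLM / GPT-J
    h = F.gelu(x @ ffw1)
    ff_out = h @ ffw2

    return attn_out + ff_out + residual
```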
titans_pytorch/neural_memory.py
CHANGED
@@ -51,6 +51,9 @@ def default(*args):
             return arg
     return None

+def identity(t):
+    return t
+
 def xnor(x, y):
     return not (x ^ y)

@@ -64,9 +67,6 @@ def safe_cat(inputs, dim = -2):

    return cat(inputs, dim = dim)

-def identity(t):
-    return t
-
 def dict_get_shape(td):
     return {k: v.shape for k, v in td.items()}

@@ -454,14 +454,14 @@ class NeuralMemory(Module):

         weights = TensorDict(weights)

-        # allow for neural memory of a previous layer
-
-
+        # allow for neural memory of a previous layer to influence surprise of current layer
+
+        weights_for_surprise = weights

         if exists(prev_layer_updates):
             prev_layer_updates = TensorDict(prev_layer_updates)

-
+            weights_for_surprise = weights_for_surprise + prev_layer_updates

         per_sample_grad_fn = self.per_sample_grad_fn_expanded_weights # the weights will now have a batch * chunk dimension

@@ -506,11 +506,11 @@ class NeuralMemory(Module):
         # flatten batch and time if surprise depends on previous layer memory model

         if exists(prev_layer_updates):
-
+            weights_for_surprise = weights_for_surprise.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))

         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)

-        grads, aux_kv_recon_loss = per_sample_grad_fn(dict(
+        grads, aux_kv_recon_loss = per_sample_grad_fn(dict(weights_for_surprise), keys, adaptive_lr, values)

         grads = TensorDict(grads)

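per_sample_grad_fn itself is constructed elsewhere in the module and is not part of this diff. As background only, a generic sketch of the underlying technique (per-sample gradients via torch.func), with a plain MSE "surprise" loss standing in for the package's actual objective and auxiliary loss:

```python
import torch
from torch.func import functional_call, grad, vmap

def make_per_sample_grad_fn(memory_model):
    # loss for a single sample, evaluated with one copy of the weights
    def loss_fn(weights, keys, values):
        pred = functional_call(memory_model, weights, (keys,))
        return torch.nn.functional.mse_loss(pred, values)

    # vmap over the leading (batch * chunk) dimension of weights, keys and values,
    # returning one gradient dict per sample
    return vmap(grad(loss_fn))

# hypothetical usage, assuming every tensor carries a leading batch * chunk dimension:
# grads = make_per_sample_grad_fn(model)(batched_weights_dict, keys, values)
```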
@@ -536,7 +536,15 @@ class NeuralMemory(Module):

         if not exists(past_state):
             empty_dict = {key: None for key in weights.keys()}
-
+
+            # minibatch_init_weight corresponds to W0 in figure 7 of TTT paper
+
+            minibatch_init_weight = weights
+
+            if dict_get_shape(weights) == self.init_weight_shape:
+                minibatch_init_weight = weights.apply(lambda t: repeat(t, '... -> b 1 (...)', b = batch * heads))
+
+            past_state = (minibatch_init_weight, empty_dict)

         past_last_update, past_last_momentum = past_state

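A small, self-contained illustration of the einops repeat pattern used for minibatch_init_weight above: each shared initial weight is broadcast to one flattened copy per batch-head, with a singleton chunk dimension. The batch, head, and weight sizes below are made up for the example:

```python
import torch
from einops import repeat

batch, heads = 2, 4
w0 = torch.randn(16, 16)                                    # one shared initial weight (W0)

expanded = repeat(w0, '... -> b 1 (...)', b = batch * heads)
print(expanded.shape)                                       # torch.Size([8, 1, 256])
```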
@@ -734,7 +742,7 @@ class NeuralMemory(Module):

         # retrieve

-        retrieved = self.retrieve_memories(token,
+        retrieved = self.retrieve_memories(token, weights, chunk_size = 1)

         # next state tuple

|
|
785
793
|
|
786
794
|
# retrieve
|
787
795
|
|
788
|
-
|
789
|
-
|
796
|
+
retrieve_chunk_size = default(chunk_size, self.retrieve_chunk_size)
|
797
|
+
|
798
|
+
if retrieve_chunk_size != 1:
|
799
|
+
if exists(prev_layer_updates):
|
800
|
+
prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
|
801
|
+
|
802
|
+
updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
|
790
803
|
|
791
|
-
updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
|
792
804
|
|
793
805
|
retrieved = self.retrieve_memories(
|
794
806
|
seq,
|
795
|
-
mem_model_weights
|
807
|
+
mem_model_weights,
|
796
808
|
chunk_size = chunk_size,
|
797
809
|
prev_layer_updates = prev_layer_updates
|
798
810
|
)
|
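pad_at_dim is a helper defined elsewhere in the package and not shown in this diff. Assuming it zero-pads the given dimension, the (1, 0) pad prepends one step along the chunk dimension, so retrieval for chunk t uses the weights as updated through chunk t - 1. A hypothetical stand-in, for illustration only:

```python
import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # hypothetical reimplementation of the repo helper, not the packaged code
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

updates = torch.randn(2, 3, 16)                  # (batch, chunks, flattened weight)
padded = pad_at_dim(updates, (1, 0), dim = 1)
print(padded.shape)                              # torch.Size([2, 4, 16])
```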
{titans_pytorch-0.2.5.dist-info → titans_pytorch-0.2.7.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.5
+Version: 0.2.7
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -27,7 +27,7 @@ License: MIT License
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 License-File: LICENSE
-Keywords: artificial intelligence,deep learning,linear attention,
+Keywords: artificial intelligence,deep learning,linear attention,memory,test time training
 Classifier: Development Status :: 4 - Beta
 Classifier: Intended Audience :: Developers
 Classifier: License :: OSI Approved :: MIT License
@@ -56,7 +56,7 @@ Description-Content-Type: text/markdown

 <img src="./fig1.png" width="400px"></img>

-## Titans - Pytorch
+## Titans - Pytorch (wip)

 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.

titans_pytorch-0.2.7.dist-info/RECORD
ADDED
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
+titans_pytorch/memory_models.py,sha256=CD8pQ-IUfTDvPmekuPTsZHE3Vy265QtbiUn_siJhA78,4064
+titans_pytorch/neural_memory.py,sha256=WAeR-nOpy1XbBP590By1-tCgirulqPbFGut4H1B77-g,24910
+titans_pytorch-0.2.7.dist-info/METADATA,sha256=ndFb28pAe8xWmNU6oncV8VJDDPImo3aCuBv0d0JylIs,6811
+titans_pytorch-0.2.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.2.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.2.7.dist-info/RECORD,,
titans_pytorch-0.2.5.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
-titans_pytorch/memory_models.py,sha256=Ew28waD9gf1wn-5Nkdc676u1I92IqzaOAw-tv0JXMwc,3777
-titans_pytorch/neural_memory.py,sha256=YiBsMiqYn-Hva4yhxfaqkGV857vZIASxi5Z0TT0FC10,24606
-titans_pytorch-0.2.5.dist-info/METADATA,sha256=x3RePuTDf3rUT3vtvge1X3Ry18Y3tV_swCgycbtSCjQ,6819
-titans_pytorch-0.2.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.2.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.2.5.dist-info/RECORD,,
{titans_pytorch-0.2.5.dist-info → titans_pytorch-0.2.7.dist-info}/WHEEL
File without changes
{titans_pytorch-0.2.5.dist-info → titans_pytorch-0.2.7.dist-info}/licenses/LICENSE
File without changes