titans-pytorch 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/memory_models.py +31 -13
- titans_pytorch/neural_memory.py +7 -3
- {titans_pytorch-0.2.5.dist-info → titans_pytorch-0.2.6.dist-info}/METADATA +2 -2
- titans_pytorch-0.2.6.dist-info/RECORD +9 -0
- titans_pytorch-0.2.5.dist-info/RECORD +0 -9
- {titans_pytorch-0.2.5.dist-info → titans_pytorch-0.2.6.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.2.5.dist-info → titans_pytorch-0.2.6.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/memory_models.py
CHANGED
@@ -3,6 +3,13 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Parameter, ParameterList
 
+# functions
+
+def l2norm(t):
+    return F.normalize(t, dim = -1)
+
+# memory mlp proposed in TTT
+
 class MemoryMLP(Module):
     def __init__(
         self,
@@ -19,15 +26,17 @@ class MemoryMLP(Module):
         self,
         x
     ):
+        residual = x
+
         for ind, weight in enumerate(self.weights):
            is_first = ind == 0
 
            if not is_first:
-                x = F.
+                x = F.gelu(x)
 
            x = x @ weight
 
-        return x
+        return x + residual
 
 # memory mlp, but with gated residual + final projection
 
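For context, a minimal runnable sketch of what MemoryMLP's forward computes after this change, reconstructed from the hunks above; the __init__ shown here (square dim x dim weights, one per depth, with an arbitrary 0.02 init scale) is an assumption for illustration, not copied from the wheel.

import torch
import torch.nn.functional as F
from torch.nn import Module, Parameter, ParameterList

class ResidualMemoryMLPSketch(Module):
    # hypothetical reconstruction for illustration only
    def __init__(self, dim, depth):
        super().__init__()
        # square weights, as implied by `x = x @ weight` in the diff (assumed init)
        self.weights = ParameterList([Parameter(torch.randn(dim, dim) * 0.02) for _ in range(depth)])

    def forward(self, x):
        residual = x                  # new in 0.2.6: keep the input around

        for ind, weight in enumerate(self.weights):
            if ind != 0:
                x = F.gelu(x)         # activation between layers

            x = x @ weight

        return x + residual           # new in 0.2.6: residual connection

x = torch.randn(2, 64, 128)
print(ResidualMemoryMLPSketch(dim = 128, depth = 2)(x).shape)  # torch.Size([2, 64, 128])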
@@ -36,7 +45,7 @@ class GatedResidualMemoryMLP(Module):
         self,
         dim,
         depth,
-        expansion_factor =
+        expansion_factor = 4.
     ):
         super().__init__()
         dim_hidden = int(dim * expansion_factor)
@@ -58,11 +67,13 @@
         self,
         x
     ):
+        residual = x
+
         for weight1, weight2, to_gates in self.weights:
            res = x
 
            hidden = x @ weight1
-            hidden = F.
+            hidden = F.gelu(hidden)
            branch_out = hidden @ weight2
 
            # gated residual
@@ -70,7 +81,9 @@
            gates = cat((branch_out, res), dim = -1) @ to_gates
            x = res.lerp(branch_out, gates.sigmoid())
 
-
+        out = x @ self.final_proj
+
+        return out + residual
 
 # memory mlp with factorized weights
 # so can tradeoff capacity for smaller chunk sizes
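The gated residual kept at new lines 81-82 interpolates each block's input toward its branch output, with the gate computed from both; the new lines 84-86 then add a whole-module residual after the final projection. A standalone sketch of the gating step, with the shape of to_gates assumed for illustration:

import torch
from torch import cat

dim = 8
res = torch.randn(2, 16, dim)          # block input (batch, seq, dim)
branch_out = torch.randn(2, 16, dim)   # output of the weight1 -> gelu -> weight2 branch
to_gates = torch.randn(dim * 2, dim)   # assumed: concat(branch, res) -> per-feature gate logits

gates = cat((branch_out, res), dim = -1) @ to_gates
x = res.lerp(branch_out, gates.sigmoid())   # gate near 0 keeps res, near 1 takes branch_out

print(x.shape)  # torch.Size([2, 16, 8])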
@@ -98,15 +111,17 @@ class FactorizedMemoryMLP(Module):
         self,
         x
     ):
+        residual = x
+
         for ind, (weight1, weight2) in enumerate(self.weights):
            is_first = ind == 0
 
            if not is_first:
-                x = F.
+                x = F.gelu(x)
 
            x = x @ weight1 @ weight2
 
-        return x
+        return x + residual
 
 # improvised attention as memory module
 
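As the comments carried through at new lines 88-89 say, the factorized variant trades capacity for parameter count so smaller chunk sizes stay affordable. A quick illustration of the core contraction; the rank k and the init scale are assumptions, only the `x @ weight1 @ weight2` step mirrors the diff:

import torch
import torch.nn.functional as F

dim, k = 512, 32                      # k << dim gives a low-rank factor pair
weight1 = torch.randn(dim, k) * 0.02  # dim*k parameters
weight2 = torch.randn(k, dim) * 0.02  # k*dim parameters, vs dim*dim for a full matrix

x = torch.randn(4, dim)
x = F.gelu(x)                         # activation between layers, as in the patched forward
x = x @ weight1 @ weight2             # same contraction as in the diff

print(x.shape)  # torch.Size([4, 512])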
@@ -133,10 +148,12 @@ class MemoryAttention(Module):
            nn.init.xavier_uniform_(weight)
 
     def forward(self, x):
+        residual = x
+
         wq, wk, wv, ffw1, ffw2 = self.weights
 
-        q =
-        k =
+        q = l2norm(x @ wq)
+        k = l2norm(x @ wk)
         v = x @ wv
 
         attn_out = F.scaled_dot_product_attention(
@@ -145,9 +162,10 @@
            is_causal = True
        )
 
-
+        # parallel attention + feedforward block
+        # as in PaLM + Gpt-J
 
-        h = F.
-
+        h = F.gelu(x @ ffw1)
+        ff_out = h @ ffw2
 
-        return
+        return attn_out + ff_out + residual
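After this change MemoryAttention.forward L2-normalizes queries and keys with the new l2norm helper and runs attention and the feedforward in parallel off the same input, summing both with a residual, as the PaLM / GPT-J comment notes. A self-contained sketch of that flow; the weight shapes and the absence of multiple heads here are assumptions for illustration:

import torch
import torch.nn.functional as F

def l2norm(t):
    return F.normalize(t, dim = -1)

dim, dim_ff = 64, 256
x = torch.randn(1, 32, dim)                       # (batch, seq, dim)

wq, wk, wv = (torch.randn(dim, dim) * 0.02 for _ in range(3))
ffw1 = torch.randn(dim, dim_ff) * 0.02
ffw2 = torch.randn(dim_ff, dim) * 0.02

residual = x

q = l2norm(x @ wq)                                # cosine-sim style queries and keys
k = l2norm(x @ wk)
v = x @ wv

attn_out = F.scaled_dot_product_attention(q, k, v, is_causal = True)

h = F.gelu(x @ ffw1)                              # feedforward branch reads x, not attn_out
ff_out = h @ ffw2

out = attn_out + ff_out + residual                # parallel block plus residual, as in the new return
print(out.shape)  # torch.Size([1, 32, 64])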
titans_pytorch/neural_memory.py
CHANGED
@@ -785,10 +785,14 @@ class NeuralMemory(Module):
 
         # retrieve
 
-
-
+        retrieve_chunk_size = default(chunk_size, self.retrieve_chunk_size)
+
+        if retrieve_chunk_size != 1:
+            if exists(prev_layer_updates):
+                prev_layer_updates = prev_layer_updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
+
+            updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
 
-        updates = updates.apply(lambda t: pad_at_dim(t, (1, 0), dim = 1))
 
         retrieved = self.retrieve_memories(
            seq,
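The neural_memory.py change makes the one-chunk shift of the accumulated weight updates conditional: the retrieval chunk size now falls back from the call argument to self.retrieve_chunk_size, and only when it is not 1 are updates (and any prev_layer_updates) padded by one position along the chunk dimension, presumably so retrieval at a given chunk uses the updates produced before it. A small sketch of what that pad does; pad_at_dim here is a stand-in written to match its call site, not imported from the package:

import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    # stand-in: pad `pad = (left, right)` positions along `dim`
    dims_from_right = (t.ndim - dim - 1) if dim >= 0 else (-dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value = value)

# per-chunk updates for one memory weight: (batch, num_chunks, ...)
updates = torch.arange(1, 4).float().view(1, 3, 1)   # chunks 1, 2, 3

shifted = pad_at_dim(updates, (1, 0), dim = 1)        # prepend a zero chunk
print(shifted.flatten().tolist())  # [0.0, 1.0, 2.0, 3.0] -> chunk t now reads updates up to t-1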
{titans_pytorch-0.2.5.dist-info → titans_pytorch-0.2.6.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.2.5
+Version: 0.2.6
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -56,7 +56,7 @@ Description-Content-Type: text/markdown
 
 <img src="./fig1.png" width="400px"></img>
 
-## Titans - Pytorch
+## Titans - Pytorch (wip)
 
 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
 
titans_pytorch-0.2.6.dist-info/RECORD
ADDED
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
+titans_pytorch/memory_models.py,sha256=CD8pQ-IUfTDvPmekuPTsZHE3Vy265QtbiUn_siJhA78,4064
+titans_pytorch/neural_memory.py,sha256=UNST32JulrDw3_dPljSFU3ZCLofDH-KBoFxx8j1Oii4,24733
+titans_pytorch-0.2.6.dist-info/METADATA,sha256=ZsjIkdKM2zanf74Q9EOqhCreVIQsdRNnOEIRBTISwAI,6825
+titans_pytorch-0.2.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.2.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.2.6.dist-info/RECORD,,
titans_pytorch-0.2.5.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=UOJAMv7nTgkefBB7M7K3U0NnFkz75tFRG5WLXRdfnLw,26039
-titans_pytorch/memory_models.py,sha256=Ew28waD9gf1wn-5Nkdc676u1I92IqzaOAw-tv0JXMwc,3777
-titans_pytorch/neural_memory.py,sha256=YiBsMiqYn-Hva4yhxfaqkGV857vZIASxi5Z0TT0FC10,24606
-titans_pytorch-0.2.5.dist-info/METADATA,sha256=x3RePuTDf3rUT3vtvge1X3Ry18Y3tV_swCgycbtSCjQ,6819
-titans_pytorch-0.2.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.2.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.2.5.dist-info/RECORD,,
{titans_pytorch-0.2.5.dist-info → titans_pytorch-0.2.6.dist-info}/WHEEL
File without changes

{titans_pytorch-0.2.5.dist-info → titans_pytorch-0.2.6.dist-info}/licenses/LICENSE
File without changes