titans-pytorch 0.3.4__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/memory_models.py +23 -13
- titans_pytorch/neural_memory.py +6 -1
- {titans_pytorch-0.3.4.dist-info → titans_pytorch-0.3.5.dist-info}/METADATA +1 -1
- titans_pytorch-0.3.5.dist-info/RECORD +9 -0
- titans_pytorch-0.3.4.dist-info/RECORD +0 -9
- {titans_pytorch-0.3.4.dist-info → titans_pytorch-0.3.5.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.3.4.dist-info → titans_pytorch-0.3.5.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/memory_models.py
CHANGED
@@ -30,6 +30,25 @@ class LayerNorm(Module):
 
         return self.ln(x) * (gamma + 1.)
 
+# norm + residual wrapper, as used in original TTT paper
+# but could be removed
+
+class ResidualNorm(Module):
+    def __init__(
+        self,
+        dim,
+        model: Module
+    ):
+        super().__init__()
+        self.norm = LayerNorm(dim)
+        self.model = model
+
+    def forward(self, x):
+
+        out = self.model(x)
+
+        return self.norm(out) + x
+
 # memory mlp proposed in TTT
 
 class MemoryMLP(Module):
@@ -45,8 +64,6 @@ class MemoryMLP(Module):
 
         self.weights = ParameterList([Parameter(torch.randn(dim_in, dim_out)) for dim_in, dim_out in zip(dims[:-1], dims[1:])])
 
-        self.ln = LayerNorm(dim)
-
         for weight in self.weights:
             nn.init.xavier_uniform_(weight)
 
@@ -54,8 +71,6 @@ class MemoryMLP(Module):
         self,
         x
     ):
-        residual = x
-
         for ind, weight in enumerate(self.weights):
             is_first = ind == 0
 
@@ -64,7 +79,7 @@ class MemoryMLP(Module):
 
             x = x @ weight
 
-        return
+        return x
 
 # memory mlp, but with gated residual + final projection
 
@@ -97,7 +112,6 @@ class GatedResidualMemoryMLP(Module):
         self,
         x
     ):
-        residual = x
 
         for weight1, weight2, to_gates in self.weights:
             res = x
@@ -111,9 +125,7 @@ class GatedResidualMemoryMLP(Module):
             gates = cat((branch_out, res), dim = -1) @ to_gates
             x = res.lerp(branch_out, gates.sigmoid())
 
-
-
-        return self.ln(out) + residual
+        return x @ self.final_proj
 
 # memory mlp with factorized weights
 # so can tradeoff capacity for smaller chunk sizes
@@ -143,7 +155,6 @@ class FactorizedMemoryMLP(Module):
         self,
         x
     ):
-        residual = x
 
         for ind, (weight1, weight2) in enumerate(self.weights):
             is_first = ind == 0
@@ -153,7 +164,7 @@ class FactorizedMemoryMLP(Module):
 
             x = x @ weight1 @ weight2
 
-        return
+        return x
 
 # improvised attention as memory module
 
@@ -182,7 +193,6 @@ class MemoryAttention(Module):
             nn.init.xavier_uniform_(weight)
 
     def forward(self, x):
-        residual = x
 
         wq, wk, wv, ffw1, ffw2 = self.weights
 
@@ -202,4 +212,4 @@ class MemoryAttention(Module):
         h = F.gelu(x @ ffw1)
         ff_out = h @ ffw2
 
-        return
+        return attn_out + ff_out
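Net effect of the memory_models.py changes: 0.3.4 shipped these forward passes in a broken state (bare `return` statements returning None, plus dead `residual` and `self.ln` assignments), and 0.3.5 fixes the returns while factoring the layernorm-plus-residual logic out into the new `ResidualNorm` wrapper. A minimal sketch of how the wrapper composes with a memory model; the dimensions are illustrative, and the `MemoryMLP(dim, depth)` constructor arguments are assumed from the `default_model_kwargs` in neural_memory.py rather than confirmed by this diff:

import torch
from titans_pytorch.memory_models import MemoryMLP, ResidualNorm

dim = 64  # illustrative size, not taken from the diff

# as of 0.3.5, the memory model returns its raw output ("return x") ...
mlp = MemoryMLP(dim = dim, depth = 2)

# ... and the layernorm + residual is applied by the wrapper instead
memory_model = ResidualNorm(dim = dim, model = mlp)

x = torch.randn(2, 16, dim)
out = memory_model(x)  # LayerNorm(mlp(x)) + x, per ResidualNorm.forward
assert out.shape == x.shape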
titans_pytorch/neural_memory.py
CHANGED
@@ -16,7 +16,8 @@ from tensordict import TensorDict
 from titans_pytorch.associative_scan import AssocScan
 
 from titans_pytorch.memory_models import(
-    MemoryMLP
+    MemoryMLP,
+    ResidualNorm
 )
 
 import einx
@@ -234,6 +235,7 @@ class NeuralMemory(Module):
         init_decay_bias = None,
         accept_weight_residual = False,
         gated_transition = False,
+        mem_model_norm_add_residual = True, # by default, layernorm output and add residual as proposed in TTT paper, but could be removed
         default_model_kwargs: dict = dict(
             depth = 2,
             expansion_factor = 4.
@@ -304,6 +306,9 @@ class NeuralMemory(Module):
 
         # the memory is the weights of the model
 
+        if mem_model_norm_add_residual:
+            model = ResidualNorm(dim = dim_head, model = model)
+
         self.memory_model = model
 
         mem_model_params = dict(model.named_parameters())
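The wrapping is opt-out via the new `mem_model_norm_add_residual` flag (default True, matching the TTT paper). A usage sketch; the constructor arguments and the (retrieved, state) return signature follow the project README rather than this diff:

import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    mem_model_norm_add_residual = True  # new in 0.3.5; set False to skip the ResidualNorm wrapper
)

seq = torch.randn(2, 1024, 384)
retrieved, mem_state = mem(seq)
assert retrieved.shape == seq.shape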
titans_pytorch-0.3.5.dist-info/RECORD
ADDED
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=CEPXaZ2fEPWF8ZBe5wihCqPSGi8PNyL0uVSgvY7eV-s,5147
+titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
+titans_pytorch/memory_models.py,sha256=2fma9u0NQmDabgbpG6CLDGBRYzX99yIDQCSYIB0etkU,4989
+titans_pytorch/neural_memory.py,sha256=2Ffq5fob6_vJVUs1jIai3BJhfsfypKffJEb0QwRRdMk,27325
+titans_pytorch-0.3.5.dist-info/METADATA,sha256=R6EL4q-zgW7DV5OyLzqz5XP2IvLNpJkaBylwH8GsyII,6815
+titans_pytorch-0.3.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.3.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.3.5.dist-info/RECORD,,
titans_pytorch-0.3.4.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=CEPXaZ2fEPWF8ZBe5wihCqPSGi8PNyL0uVSgvY7eV-s,5147
-titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
-titans_pytorch/memory_models.py,sha256=0KLHZN-y_7lwrhWSnFRaYJ3GiUV3tzVjxS9CxIx_eI8,4843
-titans_pytorch/neural_memory.py,sha256=9eyeEvYsP5OFlwLDRyVut99uVYGvXAElFPabVoZnGJw,27063
-titans_pytorch-0.3.4.dist-info/METADATA,sha256=2ZD_DovSYkVejsTWHq7_IOTN-Je0of1f-HOiojaQBhQ,6815
-titans_pytorch-0.3.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.3.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.3.4.dist-info/RECORD,,
{titans_pytorch-0.3.4.dist-info → titans_pytorch-0.3.5.dist-info}/WHEEL
File without changes
{titans_pytorch-0.3.4.dist-info → titans_pytorch-0.3.5.dist-info}/licenses/LICENSE
File without changes