titans-pytorch 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- titans_pytorch/memory_models.py +23 -19
- titans_pytorch/neural_memory.py +6 -1
- {titans_pytorch-0.3.4.dist-info → titans_pytorch-0.3.6.dist-info}/METADATA +1 -1
- titans_pytorch-0.3.6.dist-info/RECORD +9 -0
- titans_pytorch-0.3.4.dist-info/RECORD +0 -9
- {titans_pytorch-0.3.4.dist-info → titans_pytorch-0.3.6.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.3.4.dist-info → titans_pytorch-0.3.6.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/memory_models.py
CHANGED
@@ -30,6 +30,25 @@ class LayerNorm(Module):
 
         return self.ln(x) * (gamma + 1.)
 
+# norm + residual wrapper, as used in original TTT paper
+# but could be removed
+
+class ResidualNorm(Module):
+    def __init__(
+        self,
+        dim,
+        model: Module
+    ):
+        super().__init__()
+        self.norm = LayerNorm(dim)
+        self.model = model
+
+    def forward(self, x):
+
+        out = self.model(x)
+
+        return self.norm(out) + x
+
 # memory mlp proposed in TTT
 
 class MemoryMLP(Module):
@@ -45,8 +64,6 @@ class MemoryMLP(Module):
 
         self.weights = ParameterList([Parameter(torch.randn(dim_in, dim_out)) for dim_in, dim_out in zip(dims[:-1], dims[1:])])
 
-        self.ln = LayerNorm(dim)
-
         for weight in self.weights:
             nn.init.xavier_uniform_(weight)
 
@@ -54,8 +71,6 @@ class MemoryMLP(Module):
         self,
         x
     ):
-        residual = x
-
         for ind, weight in enumerate(self.weights):
             is_first = ind == 0
 
@@ -64,7 +79,7 @@ class MemoryMLP(Module):
 
             x = x @ weight
 
-        return self.ln(x) + residual
+        return x
 
 # memory mlp, but with gated residual + final projection
 
@@ -88,8 +103,6 @@ class GatedResidualMemoryMLP(Module):
 
         self.final_proj = Parameter(torch.randn(dim, dim))
 
-        self.ln = LayerNorm(dim)
-
         for param in self.parameters():
             nn.init.xavier_uniform_(param)
 
@@ -97,7 +110,6 @@ class GatedResidualMemoryMLP(Module):
         self,
         x
     ):
-        residual = x
 
         for weight1, weight2, to_gates in self.weights:
             res = x
@@ -111,9 +123,7 @@ class GatedResidualMemoryMLP(Module):
             gates = cat((branch_out, res), dim = -1) @ to_gates
             x = res.lerp(branch_out, gates.sigmoid())
 
-        out = x @ self.final_proj
-
-        return self.ln(out) + residual
+        return x @ self.final_proj
 
 # memory mlp with factorized weights
 # so can tradeoff capacity for smaller chunk sizes
@@ -133,8 +143,6 @@ class FactorizedMemoryMLP(Module):
             ]) for _ in range(depth)
         ])
 
-        self.ln = LayerNorm(dim)
-
         for weight1, weight2 in self.weights:
             nn.init.xavier_uniform_(weight1)
             nn.init.xavier_uniform_(weight2)
@@ -143,7 +151,6 @@ class FactorizedMemoryMLP(Module):
         self,
         x
     ):
-        residual = x
 
         for ind, (weight1, weight2) in enumerate(self.weights):
            is_first = ind == 0
@@ -153,7 +160,7 @@ class FactorizedMemoryMLP(Module):
 
            x = x @ weight1 @ weight2
 
-        return self.ln(x) + residual
+        return x
 
 # improvised attention as memory module
 
@@ -176,13 +183,10 @@ class MemoryAttention(Module):
             nn.Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
         ])
 
-        self.ln = LayerNorm(dim)
-
         for weight in self.weights:
             nn.init.xavier_uniform_(weight)
 
     def forward(self, x):
-        residual = x
 
         wq, wk, wv, ffw1, ffw2 = self.weights
 
@@ -202,4 +206,4 @@ class MemoryAttention(Module):
         h = F.gelu(x @ ffw1)
         ff_out = h @ ffw2
 
-        return self.ln(attn_out + ff_out) + residual
+        return attn_out + ff_out
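The net effect of these changes is that each memory model now returns its raw output, and the TTT-style layernorm + residual is applied by the new ResidualNorm wrapper instead of inside each model. A minimal sketch of how the wrapper composes with MemoryMLP follows; the dim/depth values and tensor shapes are illustrative assumptions, not taken from the diff.

```python
import torch
from titans_pytorch.memory_models import MemoryMLP, ResidualNorm

# MemoryMLP now returns only the MLP output; the 0.3.4 behaviour of
# layernorm + residual inside the model is recovered by wrapping it
mlp = MemoryMLP(dim = 64, depth = 2)
memory_model = ResidualNorm(dim = 64, model = mlp)

x = torch.randn(2, 16, 64)
out = memory_model(x)  # layernorm(mlp(x)) + x

assert out.shape == x.shape
```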
titans_pytorch/neural_memory.py
CHANGED
@@ -16,7 +16,8 @@ from tensordict import TensorDict
 from titans_pytorch.associative_scan import AssocScan
 
 from titans_pytorch.memory_models import(
-    MemoryMLP
+    MemoryMLP,
+    ResidualNorm
 )
 
 import einx
@@ -234,6 +235,7 @@ class NeuralMemory(Module):
         init_decay_bias = None,
         accept_weight_residual = False,
         gated_transition = False,
+        mem_model_norm_add_residual = True, # by default, layernorm output and add residual as proposed in TTT paper, but could be removed
         default_model_kwargs: dict = dict(
             depth = 2,
             expansion_factor = 4.
@@ -304,6 +306,9 @@ class NeuralMemory(Module):
 
         # the memory is the weights of the model
 
+        if mem_model_norm_add_residual:
+            model = ResidualNorm(dim = dim_head, model = model)
+
         self.memory_model = model
 
         mem_model_params = dict(model.named_parameters())
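On the NeuralMemory side, the wrapper is applied whenever the new mem_model_norm_add_residual flag (default True) is set, so existing code keeps the norm + residual behaviour unless it opts out. A rough usage sketch in the style of the project README; the dim, chunk_size, and sequence shape here are illustrative assumptions.

```python
import torch
from titans_pytorch import NeuralMemory

# default behaviour: the memory model is wrapped in ResidualNorm;
# pass mem_model_norm_add_residual = False to drop the wrapper
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    mem_model_norm_add_residual = True
)

seq = torch.randn(2, 1024, 384)
retrieved, state = mem(seq)

assert retrieved.shape == seq.shape
```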
titans_pytorch-0.3.6.dist-info/RECORD
ADDED
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
+titans_pytorch/associative_scan.py,sha256=CEPXaZ2fEPWF8ZBe5wihCqPSGi8PNyL0uVSgvY7eV-s,5147
+titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
+titans_pytorch/memory_models.py,sha256=u4DjLU9gnUu5TeT6YEnqeMIuPDfdCyMDocPOHNQBo5U,4887
+titans_pytorch/neural_memory.py,sha256=2Ffq5fob6_vJVUs1jIai3BJhfsfypKffJEb0QwRRdMk,27325
+titans_pytorch-0.3.6.dist-info/METADATA,sha256=CFbbMESkeeScixULPFLoWrbWuIBk_0vHyB95uPKsGdU,6815
+titans_pytorch-0.3.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.3.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.3.6.dist-info/RECORD,,
titans_pytorch-0.3.4.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=CEPXaZ2fEPWF8ZBe5wihCqPSGi8PNyL0uVSgvY7eV-s,5147
-titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
-titans_pytorch/memory_models.py,sha256=0KLHZN-y_7lwrhWSnFRaYJ3GiUV3tzVjxS9CxIx_eI8,4843
-titans_pytorch/neural_memory.py,sha256=9eyeEvYsP5OFlwLDRyVut99uVYGvXAElFPabVoZnGJw,27063
-titans_pytorch-0.3.4.dist-info/METADATA,sha256=2ZD_DovSYkVejsTWHq7_IOTN-Je0of1f-HOiojaQBhQ,6815
-titans_pytorch-0.3.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.3.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.3.4.dist-info/RECORD,,
{titans_pytorch-0.3.4.dist-info → titans_pytorch-0.3.6.dist-info}/WHEEL
RENAMED
File without changes

{titans_pytorch-0.3.4.dist-info → titans_pytorch-0.3.6.dist-info}/licenses/LICENSE
RENAMED
File without changes