titans-pytorch 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- titans_pytorch/__init__.py +1 -0
- titans_pytorch/memory_models.py +44 -12
- {titans_pytorch-0.3.5.dist-info → titans_pytorch-0.3.7.dist-info}/METADATA +1 -1
- titans_pytorch-0.3.7.dist-info/RECORD +9 -0
- titans_pytorch-0.3.5.dist-info/RECORD +0 -9
- {titans_pytorch-0.3.5.dist-info → titans_pytorch-0.3.7.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.3.5.dist-info → titans_pytorch-0.3.7.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/__init__.py
CHANGED
titans_pytorch/memory_models.py
CHANGED
@@ -103,8 +103,6 @@ class GatedResidualMemoryMLP(Module):
 
         self.final_proj = Parameter(torch.randn(dim, dim))
 
-        self.ln = LayerNorm(dim)
-
         for param in self.parameters():
             nn.init.xavier_uniform_(param)
 
@@ -145,8 +143,6 @@ class FactorizedMemoryMLP(Module):
             ]) for _ in range(depth)
         ])
 
-        self.ln = LayerNorm(dim)
-
         for weight1, weight2 in self.weights:
             nn.init.xavier_uniform_(weight1)
             nn.init.xavier_uniform_(weight2)
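The two hunks above remove the trailing `self.ln = LayerNorm(dim)` from both GatedResidualMemoryMLP and FactorizedMemoryMLP, so these memory models no longer normalize their own output. A minimal sketch of the practical consequence, assuming the constructor still accepts `dim` and `depth` as suggested by the surrounding context (check memory_models.py for the exact signature): a caller that relied on the old internal normalization would now apply it externally.

import torch
from torch import nn
from titans_pytorch.memory_models import FactorizedMemoryMLP

dim = 64
memory = FactorizedMemoryMLP(dim, depth = 2)   # assumed call; only dim and depth are visible in this diff
post_norm = nn.LayerNorm(dim)                  # hypothetical: recreate the removed normalization outside the module

x = torch.randn(2, 128, dim)
out = post_norm(memory(x))                     # normalization now happens in caller code, not inside the memory model
print(out.shape)                               # torch.Size([2, 128, 64])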
@@ -166,6 +162,44 @@ class FactorizedMemoryMLP(Module):
 
         return x
 
+# an MLP modelled after the popular swiglu ff in modern transformers
+
+class MemorySwiGluMLP(Module):
+    def __init__(
+        self,
+        dim,
+        depth = 1, # default to 2 layer MLP from TTT, depth of 2 would be 4 layer MLP, but done as 2 feedforwards with residual
+        expansion_factor = 4.
+    ):
+        super().__init__()
+
+        dim_inner = int(dim * expansion_factor * 2 / 3)
+
+        weights = []
+
+        for _ in range(depth):
+            weights.append(ParameterList([
+                Parameter(torch.randn(dim, dim_inner * 2)),
+                Parameter(torch.randn(dim_inner, dim)),
+            ]))
+
+        self.weights = ParameterList(weights)
+
+    def forward(self, x):
+
+        for w1, w2 in self.weights:
+            residual = x
+
+            x, gates = (x @ w1).chunk(2, dim = -1)
+
+            x = x * F.gelu(gates)
+
+            x = x @ w2
+
+            x = x + residual
+
+        return x
+
 # improvised attention as memory module
 
 class MemoryAttention(Module):
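For orientation, here is a minimal usage sketch of the newly added MemorySwiGluMLP, based only on the definition in the hunk above. The import path assumes the class lives in titans_pytorch/memory_models.py (where this diff adds it); whether it is also re-exported from the package root is not visible here.

import torch
from titans_pytorch.memory_models import MemorySwiGluMLP

# each layer holds a (dim, dim_inner * 2) and a (dim_inner, dim) matrix, so the
# module maps (..., dim) -> (..., dim) and adds a residual per layer
mlp = MemorySwiGluMLP(dim = 64, depth = 1, expansion_factor = 4.)

x = torch.randn(2, 128, 64)
out = mlp(x)          # x @ w1 is split into value/gate halves, gated with gelu, projected back by w2
print(out.shape)      # torch.Size([2, 128, 64])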
@@ -179,16 +213,14 @@ class MemoryAttention(Module):
         self.scale = scale
         dim_ff_hidden = int(dim * expansion_factor)
 
-        self.weights =
-
-
-
-
-
+        self.weights = ParameterList([
+            Parameter(torch.randn(dim, dim)), # queries
+            Parameter(torch.randn(dim, dim)), # keys
+            Parameter(torch.randn(dim, dim)), # values
+            Parameter(torch.randn(dim, dim_ff_hidden)), # ff w1
+            Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
         ])
 
-        self.ln = LayerNorm(dim)
-
         for weight in self.weights:
             nn.init.xavier_uniform_(weight)
 
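The MemoryAttention hunk swaps in an explicit ParameterList of five weight matrices (query, key, value, and a two-layer feedforward) and, as in the other memory models, drops the internal LayerNorm; the exact old lines being replaced were not captured by this diff view. Purely as an illustration of how five matrices in that layout are typically consumed, and not necessarily how MemoryAttention.forward is implemented, a sketch:

import torch
import torch.nn.functional as F

def memory_attention_sketch(x, weights, scale = None):
    # hypothetical forward over the five weights declared above:
    # q/k/v projections, causal attention, plus a parallel two-layer feedforward
    wq, wk, wv, ffw1, ffw2 = weights

    q, k, v = x @ wq, x @ wk, x @ wv

    attn_out = F.scaled_dot_product_attention(q, k, v, scale = scale, is_causal = True)

    ff_out = F.gelu(x @ ffw1) @ ffw2

    return attn_out + ff_out

dim, dim_ff_hidden = 64, 256
weights = [torch.randn(dim, dim) for _ in range(3)] + [
    torch.randn(dim, dim_ff_hidden),
    torch.randn(dim_ff_hidden, dim),
]
out = memory_attention_sketch(torch.randn(2, 16, dim), weights)
print(out.shape)   # torch.Size([2, 16, 64])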
titans_pytorch-0.3.7.dist-info/RECORD
ADDED
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
+titans_pytorch/associative_scan.py,sha256=CEPXaZ2fEPWF8ZBe5wihCqPSGi8PNyL0uVSgvY7eV-s,5147
+titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
+titans_pytorch/memory_models.py,sha256=fC84MZNVfQxWqRx1rZKliqiTLsyT3eIH1iCxGRw8bpI,5789
+titans_pytorch/neural_memory.py,sha256=2Ffq5fob6_vJVUs1jIai3BJhfsfypKffJEb0QwRRdMk,27325
+titans_pytorch-0.3.7.dist-info/METADATA,sha256=hNbS9PgPGNFrNdZwG8nTgExvH1OQTHlOWGQNj4eb0ZU,6815
+titans_pytorch-0.3.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.3.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.3.7.dist-info/RECORD,,
titans_pytorch-0.3.5.dist-info/RECORD
DELETED
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=CEPXaZ2fEPWF8ZBe5wihCqPSGi8PNyL0uVSgvY7eV-s,5147
-titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
-titans_pytorch/memory_models.py,sha256=2fma9u0NQmDabgbpG6CLDGBRYzX99yIDQCSYIB0etkU,4989
-titans_pytorch/neural_memory.py,sha256=2Ffq5fob6_vJVUs1jIai3BJhfsfypKffJEb0QwRRdMk,27325
-titans_pytorch-0.3.5.dist-info/METADATA,sha256=R6EL4q-zgW7DV5OyLzqz5XP2IvLNpJkaBylwH8GsyII,6815
-titans_pytorch-0.3.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.3.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.3.5.dist-info/RECORD,,
{titans_pytorch-0.3.5.dist-info → titans_pytorch-0.3.7.dist-info}/WHEEL
File without changes
{titans_pytorch-0.3.5.dist-info → titans_pytorch-0.3.7.dist-info}/licenses/LICENSE
File without changes