titans-pytorch 0.3.6__py3-none-any.whl → 0.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/__init__.py +1 -0
- titans_pytorch/memory_models.py +44 -6
- {titans_pytorch-0.3.6.dist-info → titans_pytorch-0.3.7.dist-info}/METADATA +1 -1
- titans_pytorch-0.3.7.dist-info/RECORD +9 -0
- titans_pytorch-0.3.6.dist-info/RECORD +0 -9
- {titans_pytorch-0.3.6.dist-info → titans_pytorch-0.3.7.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.3.6.dist-info → titans_pytorch-0.3.7.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/__init__.py
CHANGED
titans_pytorch/memory_models.py
CHANGED
@@ -162,6 +162,44 @@ class FactorizedMemoryMLP(Module):
|
|
162
162
|
|
163
163
|
return x
|
164
164
|
|
165
|
+
# an MLP modelled after the popular swiglu ff in modern transformers
|
166
|
+
|
167
|
+
class MemorySwiGluMLP(Module):
    """SwiGLU-style gated feedforward used as a memory module.

    Stacks ``depth`` gated-MLP layers, each wrapped in a residual
    connection, mirroring the SwiGLU feedforward popular in modern
    transformers. The hidden width carries a 2/3 factor so the gated
    variant keeps roughly the parameter count of a plain MLP with the
    same expansion factor.
    """

    def __init__(
        self,
        dim,
        depth = 1, # default to 2 layer MLP from TTT, depth of 2 would be 4 layer MLP, but done as 2 feedforwards with residual
        expansion_factor = 4.
    ):
        super().__init__()

        # 2/3 rescale keeps parameters comparable to an ungated MLP
        dim_inner = int(dim * expansion_factor * 2 / 3)

        self.weights = ParameterList([
            ParameterList([
                Parameter(torch.randn(dim, dim_inner * 2)),  # fused project-in and gate
                Parameter(torch.randn(dim_inner, dim)),      # project-out
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        # each layer: gated projection in, gelu gate, projection out, residual
        for proj_in, proj_out in self.weights:
            skip = x

            hidden, gates = (x @ proj_in).chunk(2, dim = -1)
            hidden = hidden * F.gelu(gates)

            x = hidden @ proj_out + skip

        return x
|
202
|
+
|
165
203
|
# improvised attention as memory module
|
166
204
|
|
167
205
|
class MemoryAttention(Module):
|
@@ -175,12 +213,12 @@ class MemoryAttention(Module):
|
|
175
213
|
self.scale = scale
|
176
214
|
dim_ff_hidden = int(dim * expansion_factor)
|
177
215
|
|
178
|
-
self.weights =
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
216
|
+
self.weights = ParameterList([
|
217
|
+
Parameter(torch.randn(dim, dim)), # queries
|
218
|
+
Parameter(torch.randn(dim, dim)), # keys
|
219
|
+
Parameter(torch.randn(dim, dim)), # values
|
220
|
+
Parameter(torch.randn(dim, dim_ff_hidden)), # ff w1
|
221
|
+
Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
|
184
222
|
])
|
185
223
|
|
186
224
|
for weight in self.weights:
|
@@ -0,0 +1,9 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
|
2
|
+
titans_pytorch/associative_scan.py,sha256=CEPXaZ2fEPWF8ZBe5wihCqPSGi8PNyL0uVSgvY7eV-s,5147
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
|
4
|
+
titans_pytorch/memory_models.py,sha256=fC84MZNVfQxWqRx1rZKliqiTLsyT3eIH1iCxGRw8bpI,5789
|
5
|
+
titans_pytorch/neural_memory.py,sha256=2Ffq5fob6_vJVUs1jIai3BJhfsfypKffJEb0QwRRdMk,27325
|
6
|
+
titans_pytorch-0.3.7.dist-info/METADATA,sha256=hNbS9PgPGNFrNdZwG8nTgExvH1OQTHlOWGQNj4eb0ZU,6815
|
7
|
+
titans_pytorch-0.3.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.3.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.3.7.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
|
2
|
-
titans_pytorch/associative_scan.py,sha256=CEPXaZ2fEPWF8ZBe5wihCqPSGi8PNyL0uVSgvY7eV-s,5147
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
|
4
|
-
titans_pytorch/memory_models.py,sha256=u4DjLU9gnUu5TeT6YEnqeMIuPDfdCyMDocPOHNQBo5U,4887
|
5
|
-
titans_pytorch/neural_memory.py,sha256=2Ffq5fob6_vJVUs1jIai3BJhfsfypKffJEb0QwRRdMk,27325
|
6
|
-
titans_pytorch-0.3.6.dist-info/METADATA,sha256=CFbbMESkeeScixULPFLoWrbWuIBk_0vHyB95uPKsGdU,6815
|
7
|
-
titans_pytorch-0.3.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
-
titans_pytorch-0.3.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
-
titans_pytorch-0.3.6.dist-info/RECORD,,
|
File without changes
|
File without changes
|