titans-pytorch 0.3.6__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,6 +6,7 @@ from titans_pytorch.memory_models import (
6
6
  MemoryMLP,
7
7
  MemoryAttention,
8
8
  FactorizedMemoryMLP,
9
+ MemorySwiGluMLP,
9
10
  GatedResidualMemoryMLP
10
11
  )
11
12
 
@@ -162,6 +162,44 @@ class FactorizedMemoryMLP(Module):
162
162
 
163
163
  return x
164
164
 
165
+ # an MLP modelled after the popular swiglu ff in modern transformers
166
+
167
class MemorySwiGluMLP(Module):
    """Gated-feedforward memory model, modelled after the SwiGLU FF of modern transformers.

    Stacks ``depth`` residual gated feedforward blocks. Each block:

    1. projects the input from ``dim`` to ``2 * dim_inner``;
    2. splits that projection into a value half and a gate half;
    3. multiplies the value by the activated gate;
    4. projects back to ``dim`` and adds the block input as a residual.

    NOTE(review): despite the class name, the gate activation is GELU (a
    GEGLU-style gate), not SiLU/Swish as in the canonical SwiGLU. It is kept
    as-is to preserve the released model behavior.

    Args:
        dim: size of the last axis of the input and output.
        depth: number of residual gated feedforward blocks. Each block holds two
            weight matrices, so ``depth = 1`` (default) is a 2-layer MLP and
            ``depth = 2`` is a 4-layer MLP built from two residual feedforwards.
        expansion_factor: hidden-width multiplier, applied before the 2/3
            correction below.
    """

    def __init__(
        self,
        dim,
        depth = 1, # default to 2 layer MLP from TTT, depth of 2 would be 4 layer MLP, but done as 2 feedforwards with residual
        expansion_factor = 4.
    ):
        super().__init__()

        # The 2/3 factor keeps the parameter count equal to a plain 2-matrix MLP
        # of hidden width dim * expansion_factor:
        # dim * 2 * dim_inner + dim_inner * dim = 3 * dim * dim_inner.
        dim_inner = int(dim * expansion_factor * 2 / 3)

        weights = []

        for _ in range(depth):
            weights.append(ParameterList([
                Parameter(torch.randn(dim, dim_inner * 2)),  # fused value + gate projection
                Parameter(torch.randn(dim_inner, dim)),      # output projection
            ]))

        self.weights = ParameterList(weights)

        # The randn placeholders above have unit-variance entries, so each matmul
        # would scale activations by roughly sqrt(fan_in). Re-initialize every
        # weight with xavier uniform instead. self.parameters() recurses into
        # the nested per-block ParameterLists.
        for weight in self.parameters():
            nn.init.xavier_uniform_(weight)

    def forward(self, x):
        """Apply the residual gated feedforward blocks in sequence.

        Args:
            x: tensor whose last axis has size ``dim``. Any leading axes are
               carried through by the matmuls.

        Returns:
            Tensor with the same shape as ``x``. With ``depth = 0`` this is
            ``x`` itself.
        """
        for w1, w2 in self.weights:
            residual = x

            # One fused projection, split into value and gate halves.
            x, gates = (x @ w1).chunk(2, dim = -1)

            x = x * F.gelu(gates)

            x = x @ w2

            x = x + residual

        return x
202
+
165
203
  # improvised attention as memory module
166
204
 
167
205
  class MemoryAttention(Module):
@@ -175,12 +213,12 @@ class MemoryAttention(Module):
175
213
  self.scale = scale
176
214
  dim_ff_hidden = int(dim * expansion_factor)
177
215
 
178
- self.weights = nn.ParameterList([
179
- nn.Parameter(torch.randn(dim, dim)), # queries
180
- nn.Parameter(torch.randn(dim, dim)), # keys
181
- nn.Parameter(torch.randn(dim, dim)), # values
182
- nn.Parameter(torch.randn(dim, dim_ff_hidden)), # ff w1
183
- nn.Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
216
+ self.weights = ParameterList([
217
+ Parameter(torch.randn(dim, dim)), # queries
218
+ Parameter(torch.randn(dim, dim)), # keys
219
+ Parameter(torch.randn(dim, dim)), # values
220
+ Parameter(torch.randn(dim, dim_ff_hidden)), # ff w1
221
+ Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
184
222
  ])
185
223
 
186
224
  for weight in self.weights:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.3.6
3
+ Version: 0.3.7
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,9 @@
1
+ titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
2
+ titans_pytorch/associative_scan.py,sha256=CEPXaZ2fEPWF8ZBe5wihCqPSGi8PNyL0uVSgvY7eV-s,5147
3
+ titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
4
+ titans_pytorch/memory_models.py,sha256=fC84MZNVfQxWqRx1rZKliqiTLsyT3eIH1iCxGRw8bpI,5789
5
+ titans_pytorch/neural_memory.py,sha256=2Ffq5fob6_vJVUs1jIai3BJhfsfypKffJEb0QwRRdMk,27325
6
+ titans_pytorch-0.3.7.dist-info/METADATA,sha256=hNbS9PgPGNFrNdZwG8nTgExvH1OQTHlOWGQNj4eb0ZU,6815
7
+ titans_pytorch-0.3.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
+ titans_pytorch-0.3.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
+ titans_pytorch-0.3.7.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
2
- titans_pytorch/associative_scan.py,sha256=CEPXaZ2fEPWF8ZBe5wihCqPSGi8PNyL0uVSgvY7eV-s,5147
3
- titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
4
- titans_pytorch/memory_models.py,sha256=u4DjLU9gnUu5TeT6YEnqeMIuPDfdCyMDocPOHNQBo5U,4887
5
- titans_pytorch/neural_memory.py,sha256=2Ffq5fob6_vJVUs1jIai3BJhfsfypKffJEb0QwRRdMk,27325
6
- titans_pytorch-0.3.6.dist-info/METADATA,sha256=CFbbMESkeeScixULPFLoWrbWuIBk_0vHyB95uPKsGdU,6815
7
- titans_pytorch-0.3.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
8
- titans_pytorch-0.3.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
9
- titans_pytorch-0.3.6.dist-info/RECORD,,