titans-pytorch 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
--- a/titans_pytorch/__init__.py
+++ b/titans_pytorch/__init__.py
@@ -6,6 +6,7 @@ from titans_pytorch.memory_models import (
     MemoryMLP,
     MemoryAttention,
     FactorizedMemoryMLP,
+    MemorySwiGluMLP,
     GatedResidualMemoryMLP
 )
 
--- a/titans_pytorch/memory_models.py
+++ b/titans_pytorch/memory_models.py
@@ -103,8 +103,6 @@ class GatedResidualMemoryMLP(Module):
 
         self.final_proj = Parameter(torch.randn(dim, dim))
 
-        self.ln = LayerNorm(dim)
-
         for param in self.parameters():
             nn.init.xavier_uniform_(param)
 
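A note on the `self.ln` deletions in this hunk and the two below: the removed `LayerNorm(dim)` modules were registered in `__init__` but never referenced in any `forward`. In `GatedResidualMemoryMLP` the dangling module was plausibly also a latent bug, because the init loop above applies `nn.init.xavier_uniform_` to everything in `self.parameters()`, and Xavier initialization raises on the 1-D `weight` and `bias` that a default `nn.LayerNorm` registers. A standalone sketch of that failure mode (not code from the package):

from torch import nn

ln = nn.LayerNorm(4)  # registers 1-D `weight` and `bias` parameters

for param in ln.parameters():
    # ValueError: Fan in and fan out can not be computed for tensor
    # with fewer than 2 dimensions
    nn.init.xavier_uniform_(param)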
@@ -145,8 +143,6 @@ class FactorizedMemoryMLP(Module):
             ]) for _ in range(depth)
         ])
 
-        self.ln = LayerNorm(dim)
-
         for weight1, weight2 in self.weights:
             nn.init.xavier_uniform_(weight1)
             nn.init.xavier_uniform_(weight2)
@@ -166,6 +162,44 @@ class FactorizedMemoryMLP(Module):
 
         return x
 
+# an MLP modelled after the popular swiglu ff in modern transformers
+
+class MemorySwiGluMLP(Module):
+    def __init__(
+        self,
+        dim,
+        depth = 1, # default to 2 layer MLP from TTT, depth of 2 would be 4 layer MLP, but done as 2 feedforwards with residual
+        expansion_factor = 4.
+    ):
+        super().__init__()
+
+        dim_inner = int(dim * expansion_factor * 2 / 3)
+
+        weights = []
+
+        for _ in range(depth):
+            weights.append(ParameterList([
+                Parameter(torch.randn(dim, dim_inner * 2)),
+                Parameter(torch.randn(dim_inner, dim)),
+            ]))
+
+        self.weights = ParameterList(weights)
+
+    def forward(self, x):
+
+        for w1, w2 in self.weights:
+            residual = x
+
+            x, gates = (x @ w1).chunk(2, dim = -1)
+
+            x = x * F.gelu(gates)
+
+            x = x @ w2
+
+            x = x + residual
+
+        return x
+
 # improvised attention as memory module
 
 class MemoryAttention(Module):
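The new `MemorySwiGluMLP` above keeps its weights as raw `Parameter` matrices, like the other memory models in this file (note the gate actually uses `F.gelu`, i.e. a GEGLU-style gate, despite the SwiGLU name). The `dim_inner = int(dim * expansion_factor * 2 / 3)` line follows the usual GLU-variant convention of scaling the hidden width by 2/3 so the gated feedforward's parameter count stays close to a plain `expansion_factor` MLP: at `dim = 512`, `dim_inner = int(512 * 4 * 2 / 3) = 1365`, giving a `512 × 2730` first projection and a `1365 × 512` second one. A minimal usage sketch, assuming the package is installed (shapes illustrative):

import torch
from titans_pytorch.memory_models import MemorySwiGluMLP

mlp = MemorySwiGluMLP(dim = 512, depth = 1, expansion_factor = 4.)

x = torch.randn(2, 1024, 512)  # (batch, seq, dim)
out = mlp(x)                   # per layer: gate, project back, add residual

assert out.shape == x.shape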
@@ -179,16 +213,14 @@ class MemoryAttention(Module):
         self.scale = scale
         dim_ff_hidden = int(dim * expansion_factor)
 
-        self.weights = nn.ParameterList([
-            nn.Parameter(torch.randn(dim, dim)), # queries
-            nn.Parameter(torch.randn(dim, dim)), # keys
-            nn.Parameter(torch.randn(dim, dim)), # values
-            nn.Parameter(torch.randn(dim, dim_ff_hidden)), # ff w1
-            nn.Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
+        self.weights = ParameterList([
+            Parameter(torch.randn(dim, dim)), # queries
+            Parameter(torch.randn(dim, dim)), # keys
+            Parameter(torch.randn(dim, dim)), # values
+            Parameter(torch.randn(dim, dim_ff_hidden)), # ff w1
+            Parameter(torch.randn(dim_ff_hidden, dim)), # ff w2
         ])
 
-        self.ln = LayerNorm(dim)
-
         for weight in self.weights:
             nn.init.xavier_uniform_(weight)
 
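The change above is cosmetic (dropping the `nn.` qualification in favor of the `Parameter` / `ParameterList` names the module already imports) plus the same unused-`LayerNorm` removal. For orientation only, the inline comments (queries / keys / values / ff w1 / ff w2) suggest a forward pass roughly like the hypothetical sketch below; this illustrates the weights-as-raw-matrices style, not the package's actual `MemoryAttention.forward`:

import torch
import torch.nn.functional as F

def attention_memory_sketch(x, wq, wk, wv, w1, w2, scale):
    # project into queries / keys / values with the three square matrices
    q, k, v = x @ wq, x @ wk, x @ wv

    # scaled dot-product attention over the sequence dimension
    attn = (q @ k.transpose(-2, -1) * scale).softmax(dim = -1)
    out = attn @ v

    # two-matrix feedforward (ff w1, ff w2)
    return F.gelu(out @ w1) @ w2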
--- titans_pytorch-0.3.5.dist-info/METADATA
+++ titans_pytorch-0.3.7.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.3.5
+Version: 0.3.7
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
--- /dev/null
+++ titans_pytorch-0.3.7.dist-info/RECORD
@@ -0,0 +1,9 @@
+titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
+titans_pytorch/associative_scan.py,sha256=CEPXaZ2fEPWF8ZBe5wihCqPSGi8PNyL0uVSgvY7eV-s,5147
+titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
+titans_pytorch/memory_models.py,sha256=fC84MZNVfQxWqRx1rZKliqiTLsyT3eIH1iCxGRw8bpI,5789
+titans_pytorch/neural_memory.py,sha256=2Ffq5fob6_vJVUs1jIai3BJhfsfypKffJEb0QwRRdMk,27325
+titans_pytorch-0.3.7.dist-info/METADATA,sha256=hNbS9PgPGNFrNdZwG8nTgExvH1OQTHlOWGQNj4eb0ZU,6815
+titans_pytorch-0.3.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.3.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.3.7.dist-info/RECORD,,
--- titans_pytorch-0.3.5.dist-info/RECORD
+++ /dev/null
@@ -1,9 +0,0 @@
-titans_pytorch/__init__.py,sha256=Y3m_ZlpEqYwp-Md1ARhNGJxq8bQp8ty1o039nZOOJo0,276
-titans_pytorch/associative_scan.py,sha256=CEPXaZ2fEPWF8ZBe5wihCqPSGi8PNyL0uVSgvY7eV-s,5147
-titans_pytorch/mac_transformer.py,sha256=5rO4GQxSyFWWEc3pc3xNyG0sK5EXE7MmxKI-_kEMl2M,24941
-titans_pytorch/memory_models.py,sha256=2fma9u0NQmDabgbpG6CLDGBRYzX99yIDQCSYIB0etkU,4989
-titans_pytorch/neural_memory.py,sha256=2Ffq5fob6_vJVUs1jIai3BJhfsfypKffJEb0QwRRdMk,27325
-titans_pytorch-0.3.5.dist-info/METADATA,sha256=R6EL4q-zgW7DV5OyLzqz5XP2IvLNpJkaBylwH8GsyII,6815
-titans_pytorch-0.3.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.3.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.3.5.dist-info/RECORD,,
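The RECORD churn is standard wheel bookkeeping: each entry is `path,sha256=<digest>,<size>`, where the digest is the urlsafe-base64-encoded SHA-256 of the file with `=` padding stripped. Only `__init__.py` (276 → 297 bytes) and `memory_models.py` (4989 → 5789 bytes) change, consistent with the source hunks above. A small sketch of how an entry can be checked (hypothetical local path):

import base64
import hashlib
from pathlib import Path

def record_entry(path):
    # reproduce a wheel RECORD line: path, sha256 digest, size in bytes
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=')
    return f'{path},sha256={digest.decode()},{len(data)}'

print(record_entry('titans_pytorch/memory_models.py'))  # hypothetical path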