titans-pytorch 0.1.6__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff shows the differences between the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- titans_pytorch/titans.py +11 -8
- {titans_pytorch-0.1.6.dist-info → titans_pytorch-0.1.7.dist-info}/METADATA +1 -1
- titans_pytorch-0.1.7.dist-info/RECORD +8 -0
- titans_pytorch-0.1.6.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.6.dist-info → titans_pytorch-0.1.7.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.6.dist-info → titans_pytorch-0.1.7.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/titans.py
CHANGED
|
@@ -161,14 +161,16 @@ class GatedResidualMemoryMLP(Module):
|
|
|
161
161
|
def __init__(
|
|
162
162
|
self,
|
|
163
163
|
dim,
|
|
164
|
-
depth
|
|
164
|
+
depth,
|
|
165
|
+
expansion_factor = 2.
|
|
165
166
|
):
|
|
166
167
|
super().__init__()
|
|
167
|
-
|
|
168
|
+
dim_hidden = int(dim * expansion_factor)
|
|
168
169
|
|
|
169
170
|
self.weights = ParameterList([
|
|
170
171
|
ParameterList([
|
|
171
|
-
Parameter(torch.randn(dim,
|
|
172
|
+
Parameter(torch.randn(dim, dim_hidden)),
|
|
173
|
+
Parameter(torch.randn(dim_hidden, dim)),
|
|
172
174
|
Parameter(torch.randn(dim * 2, dim)),
|
|
173
175
|
]) for _ in range(depth)
|
|
174
176
|
])
|
|
@@ -182,16 +184,17 @@ class GatedResidualMemoryMLP(Module):
|
|
|
182
184
|
self,
|
|
183
185
|
x
|
|
184
186
|
):
|
|
185
|
-
for
|
|
187
|
+
for weight1, weight2, to_gates in self.weights:
|
|
186
188
|
res = x
|
|
187
189
|
|
|
188
|
-
|
|
189
|
-
|
|
190
|
+
hidden = x @ weight1
|
|
191
|
+
hidden = F.silu(hidden)
|
|
192
|
+
branch_out = hidden @ weight2
|
|
190
193
|
|
|
191
194
|
# gated residual
|
|
192
195
|
|
|
193
|
-
gates = cat((
|
|
194
|
-
x = res.lerp(
|
|
196
|
+
gates = cat((branch_out, res), dim = -1) @ to_gates
|
|
197
|
+
x = res.lerp(branch_out, gates.sigmoid())
|
|
195
198
|
|
|
196
199
|
return x @ self.final_proj
|
|
197
200
|
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
|
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=YowJzQC2p3fWgzFCe9jXrw1b3wWPKN-dcLs5sX-3r8Y,19123
|
|
4
|
+
titans_pytorch/titans.py,sha256=GFnORd9WRQBBavEUWz0DV0lxnHtLsrDEghX0RgWhZaQ,20758
|
|
5
|
+
titans_pytorch-0.1.7.dist-info/METADATA,sha256=ltCnRZ1MXbqxJ5-L41Z1pvfTDpLJNnk_XZ9hpfX6wrQ,4747
|
|
6
|
+
titans_pytorch-0.1.7.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
titans_pytorch-0.1.7.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
titans_pytorch-0.1.7.dist-info/RECORD,,
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
|
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=YowJzQC2p3fWgzFCe9jXrw1b3wWPKN-dcLs5sX-3r8Y,19123
|
|
4
|
-
titans_pytorch/titans.py,sha256=VMcPcKsoR3G13Um62Aa1HbdwrrV60ljPhP-yF40x90I,20555
|
|
5
|
-
titans_pytorch-0.1.6.dist-info/METADATA,sha256=LJW26WfT9WB-0NfokLLHhcRpWnt76jwkXMt_FSTI3SM,4747
|
|
6
|
-
titans_pytorch-0.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
-
titans_pytorch-0.1.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
-
titans_pytorch-0.1.6.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|