titans-pytorch 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/__init__.py +2 -1
- titans_pytorch/titans.py +42 -1
- {titans_pytorch-0.1.5.dist-info → titans_pytorch-0.1.6.dist-info}/METADATA +2 -1
- titans_pytorch-0.1.6.dist-info/RECORD +8 -0
- titans_pytorch-0.1.5.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.5.dist-info → titans_pytorch-0.1.6.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.5.dist-info → titans_pytorch-0.1.6.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/__init__.py
CHANGED
titans_pytorch/titans.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
from typing import Callable
|
|
3
|
+
|
|
3
4
|
import math
|
|
4
5
|
from functools import partial
|
|
5
6
|
|
|
6
7
|
import torch
|
|
7
|
-
from torch import nn, Tensor
|
|
8
|
+
from torch import nn, cat, Tensor
|
|
8
9
|
import torch.nn.functional as F
|
|
9
10
|
from torch.nn import Linear, Module, Parameter, ParameterList
|
|
10
11
|
from torch.func import functional_call, vmap, grad
|
|
@@ -154,6 +155,46 @@ class MemoryMLP(Module):
|
|
|
154
155
|
|
|
155
156
|
return x
|
|
156
157
|
|
|
158
|
+
# memory mlp, but with gated residual + final projection
|
|
159
|
+
|
|
160
|
+
class GatedResidualMemoryMLP(Module):
|
|
161
|
+
def __init__(
|
|
162
|
+
self,
|
|
163
|
+
dim,
|
|
164
|
+
depth
|
|
165
|
+
):
|
|
166
|
+
super().__init__()
|
|
167
|
+
self.depth = depth
|
|
168
|
+
|
|
169
|
+
self.weights = ParameterList([
|
|
170
|
+
ParameterList([
|
|
171
|
+
Parameter(torch.randn(dim, dim)),
|
|
172
|
+
Parameter(torch.randn(dim * 2, dim)),
|
|
173
|
+
]) for _ in range(depth)
|
|
174
|
+
])
|
|
175
|
+
|
|
176
|
+
self.final_proj = Parameter(torch.randn(dim, dim))
|
|
177
|
+
|
|
178
|
+
for param in self.parameters():
|
|
179
|
+
nn.init.xavier_uniform_(param)
|
|
180
|
+
|
|
181
|
+
def forward(
|
|
182
|
+
self,
|
|
183
|
+
x
|
|
184
|
+
):
|
|
185
|
+
for weight, to_gates in self.weights:
|
|
186
|
+
res = x
|
|
187
|
+
|
|
188
|
+
x = x @ weight
|
|
189
|
+
x = F.silu(x)
|
|
190
|
+
|
|
191
|
+
# gated residual
|
|
192
|
+
|
|
193
|
+
gates = cat((x, res), dim = -1) @ to_gates
|
|
194
|
+
x = res.lerp(x, gates.sigmoid())
|
|
195
|
+
|
|
196
|
+
return x @ self.final_proj
|
|
197
|
+
|
|
157
198
|
# memory mlp with factorized weights
|
|
158
199
|
# so can tradeoff capacity for smaller chunk sizes
|
|
159
200
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: titans-pytorch
|
|
3
|
-
Version: 0.1.5
|
|
3
|
+
Version: 0.1.6
|
|
4
4
|
Summary: Titans
|
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
|
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
|
|
@@ -46,6 +46,7 @@ Requires-Dist: torch>=2.2
|
|
|
46
46
|
Requires-Dist: tqdm
|
|
47
47
|
Requires-Dist: x-transformers
|
|
48
48
|
Provides-Extra: examples
|
|
49
|
+
Requires-Dist: adam-atan2-pytorch>=0.1.18; extra == 'examples'
|
|
49
50
|
Requires-Dist: wandb; extra == 'examples'
|
|
50
51
|
Provides-Extra: test
|
|
51
52
|
Requires-Dist: pytest; extra == 'test'
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
|
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=YowJzQC2p3fWgzFCe9jXrw1b3wWPKN-dcLs5sX-3r8Y,19123
|
|
4
|
+
titans_pytorch/titans.py,sha256=VMcPcKsoR3G13Um62Aa1HbdwrrV60ljPhP-yF40x90I,20555
|
|
5
|
+
titans_pytorch-0.1.6.dist-info/METADATA,sha256=LJW26WfT9WB-0NfokLLHhcRpWnt76jwkXMt_FSTI3SM,4747
|
|
6
|
+
titans_pytorch-0.1.6.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
+
titans_pytorch-0.1.6.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
+
titans_pytorch-0.1.6.dist-info/RECORD,,
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
|
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=YowJzQC2p3fWgzFCe9jXrw1b3wWPKN-dcLs5sX-3r8Y,19123
|
|
4
|
-
titans_pytorch/titans.py,sha256=iF0tTTyLs3hPhJDvGVKD2PdXgpWo9xOggD_42szPwjg,19632
|
|
5
|
-
titans_pytorch-0.1.5.dist-info/METADATA,sha256=GrCMbvIDT9gdL8JJ-U55oxFeB8TVRI2PTuvFK2QQjbk,4684
|
|
6
|
-
titans_pytorch-0.1.5.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
7
|
-
titans_pytorch-0.1.5.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
|
8
|
-
titans_pytorch-0.1.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|