titans-pytorch 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- titans_pytorch/__init__.py +0 -1
- titans_pytorch/titans.py +0 -29
- {titans_pytorch-0.0.15.dist-info → titans_pytorch-0.0.16.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.16.dist-info/RECORD +7 -0
- titans_pytorch-0.0.15.dist-info/RECORD +0 -7
- {titans_pytorch-0.0.15.dist-info → titans_pytorch-0.0.16.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.15.dist-info → titans_pytorch-0.0.16.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/__init__.py
CHANGED
titans_pytorch/titans.py
CHANGED
@@ -80,35 +80,6 @@ class MemoryMLP(Module):
|
|
80
80
|
|
81
81
|
return x
|
82
82
|
|
83
|
-
# improvised attention as memory module
|
84
|
-
# todo - expand if see signal in experiments
|
85
|
-
|
86
|
-
class MemoryAttention(Module):
|
87
|
-
def __init__(
|
88
|
-
self,
|
89
|
-
dim
|
90
|
-
):
|
91
|
-
super().__init__()
|
92
|
-
self.weights = nn.ParameterList([
|
93
|
-
nn.Parameter(torch.randn(dim, dim)), # queries
|
94
|
-
nn.Parameter(torch.randn(dim, dim)), # keys
|
95
|
-
nn.Parameter(torch.randn(dim, dim)), # values
|
96
|
-
])
|
97
|
-
|
98
|
-
def forward(self, x):
|
99
|
-
wq, wk, wv = self.weights
|
100
|
-
|
101
|
-
q = x @ wq
|
102
|
-
k = x @ wk
|
103
|
-
v = x @ wv
|
104
|
-
|
105
|
-
sim = q @ k.transpose(-1, -2)
|
106
|
-
|
107
|
-
attn = sim.softmax(dim = -1)
|
108
|
-
|
109
|
-
out = attn @ v
|
110
|
-
return out
|
111
|
-
|
112
83
|
# main neural memory
|
113
84
|
|
114
85
|
def default_loss_fn(pred, target):
|
@@ -0,0 +1,7 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
|
2
|
+
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
+
titans_pytorch/titans.py,sha256=HYm0R_1w3s8MNPsyE2qAVpHGqTBX_AoWtjzxRfF1Ams,12269
|
4
|
+
titans_pytorch-0.0.16.dist-info/METADATA,sha256=mzSgA4okWWSh97fncN2UKJaTVa3PWOHYVaFkQ1Ker0w,3811
|
5
|
+
titans_pytorch-0.0.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
6
|
+
titans_pytorch-0.0.16.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
7
|
+
titans_pytorch-0.0.16.dist-info/RECORD,,
|
@@ -1,7 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=bu8p8kUA24EVrTz-ojixHTwV-6KTY9Y0cNJkaMW4Whw,91
|
2
|
-
titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
|
3
|
-
titans_pytorch/titans.py,sha256=ArFYKgI0p7N3mmv8b4ncxkl3gkKAXWrFrnl2quh2RqE,12930
|
4
|
-
titans_pytorch-0.0.15.dist-info/METADATA,sha256=3IC7BT7J3BYx23wUOYeuGgtz769dsIczBnWbm6oi0Tw,3811
|
5
|
-
titans_pytorch-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
6
|
-
titans_pytorch-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
7
|
-
titans_pytorch-0.0.15.dist-info/RECORD,,
|
File without changes
|
File without changes
|