titans-pytorch 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,5 +1,4 @@
1
1
  from titans_pytorch.titans import (
2
2
  NeuralMemory,
3
3
  MemoryMLP,
4
- MemoryAttention
5
4
  )
titans_pytorch/titans.py CHANGED
@@ -80,35 +80,6 @@ class MemoryMLP(Module):
80
80
 
81
81
  return x
82
82
 
83
- # improvised attention as memory module
84
- # todo - expand if see signal in experiments
85
-
86
- class MemoryAttention(Module):
87
- def __init__(
88
- self,
89
- dim
90
- ):
91
- super().__init__()
92
- self.weights = nn.ParameterList([
93
- nn.Parameter(torch.randn(dim, dim)), # queries
94
- nn.Parameter(torch.randn(dim, dim)), # keys
95
- nn.Parameter(torch.randn(dim, dim)), # values
96
- ])
97
-
98
- def forward(self, x):
99
- wq, wk, wv = self.weights
100
-
101
- q = x @ wq
102
- k = x @ wk
103
- v = x @ wv
104
-
105
- sim = q @ k.transpose(-1, -2)
106
-
107
- attn = sim.softmax(dim = -1)
108
-
109
- out = attn @ v
110
- return out
111
-
112
83
  # main neural memory
113
84
 
114
85
  def default_loss_fn(pred, target):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.15
3
+ Version: 0.0.16
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -0,0 +1,7 @@
1
+ titans_pytorch/__init__.py,sha256=nB0873FZ_OyCda3qFeWTdpO4LbbrXDEvtAefVLzh6g0,71
2
+ titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
+ titans_pytorch/titans.py,sha256=HYm0R_1w3s8MNPsyE2qAVpHGqTBX_AoWtjzxRfF1Ams,12269
4
+ titans_pytorch-0.0.16.dist-info/METADATA,sha256=mzSgA4okWWSh97fncN2UKJaTVa3PWOHYVaFkQ1Ker0w,3811
5
+ titans_pytorch-0.0.16.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
+ titans_pytorch-0.0.16.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
+ titans_pytorch-0.0.16.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- titans_pytorch/__init__.py,sha256=bu8p8kUA24EVrTz-ojixHTwV-6KTY9Y0cNJkaMW4Whw,91
2
- titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
3
- titans_pytorch/titans.py,sha256=ArFYKgI0p7N3mmv8b4ncxkl3gkKAXWrFrnl2quh2RqE,12930
4
- titans_pytorch-0.0.15.dist-info/METADATA,sha256=3IC7BT7J3BYx23wUOYeuGgtz769dsIczBnWbm6oi0Tw,3811
5
- titans_pytorch-0.0.15.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
6
- titans_pytorch-0.0.15.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
7
- titans_pytorch-0.0.15.dist-info/RECORD,,