titans-pytorch 0.0.15__tar.gz → 0.0.16__tar.gz
- {titans_pytorch-0.0.15 → titans_pytorch-0.0.16}/PKG-INFO +1 -1
- {titans_pytorch-0.0.15 → titans_pytorch-0.0.16}/pyproject.toml +1 -1
- titans_pytorch-0.0.16/tests/test_titans.py +28 -0
- {titans_pytorch-0.0.15 → titans_pytorch-0.0.16}/titans_pytorch/__init__.py +0 -1
- {titans_pytorch-0.0.15 → titans_pytorch-0.0.16}/titans_pytorch/titans.py +0 -29
- {titans_pytorch-0.0.15 → titans_pytorch-0.0.16}/train.py +5 -1
- titans_pytorch-0.0.15/tests/test_titans.py +0 -15
- {titans_pytorch-0.0.15 → titans_pytorch-0.0.16}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.15 → titans_pytorch-0.0.16}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.15 → titans_pytorch-0.0.16}/.gitignore +0 -0
- {titans_pytorch-0.0.15 → titans_pytorch-0.0.16}/LICENSE +0 -0
- {titans_pytorch-0.0.15 → titans_pytorch-0.0.16}/README.md +0 -0
- {titans_pytorch-0.0.15 → titans_pytorch-0.0.16}/data/README.md +0 -0
- {titans_pytorch-0.0.15 → titans_pytorch-0.0.16}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.15 → titans_pytorch-0.0.16}/fig1.png +0 -0
- {titans_pytorch-0.0.15 → titans_pytorch-0.0.16}/fig2.png +0 -0
- {titans_pytorch-0.0.15 → titans_pytorch-0.0.16}/requirements.txt +0 -0
- {titans_pytorch-0.0.15 → titans_pytorch-0.0.16}/titans_pytorch/associative_scan.py +0 -0
titans_pytorch-0.0.16/tests/test_titans.py
@@ -0,0 +1,28 @@
+import torch
+import pytest
+
+def test_titans():
+    from titans_pytorch import NeuralMemory
+
+    mem = NeuralMemory(
+        dim = 384,
+        chunk_size = 64,
+    )
+
+    seq = torch.randn(2, 1024, 384)
+    retrieved = mem(seq)
+
+    assert seq.shape == retrieved.shape
+
+def test_titans_attn_memory():
+    from titans_pytorch.titans_attn_memory import NeuralMemory
+
+    mem = NeuralMemory(
+        dim = 384,
+        chunk_size = 64,
+    )
+
+    seq = torch.randn(2, 1024, 384)
+    retrieved = mem(seq)
+
+    assert seq.shape == retrieved.shape
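Both tests pin down the same contract: retrieval returns a tensor of the input's shape. For readers skimming the diff, here is a minimal standalone sketch of that usage, with the arguments and shapes lifted directly from the tests above; the inline comments are informal glosses, not taken from this diff:

import torch
from titans_pytorch import NeuralMemory

# configuration mirrored from the new tests
mem = NeuralMemory(
    dim = 384,        # feature dimension of the incoming sequence
    chunk_size = 64,  # granularity of memory updates (informal gloss)
)

seq = torch.randn(2, 1024, 384)  # (batch, sequence length, dim)
retrieved = mem(seq)

assert retrieved.shape == seq.shape  # retrieval preserves the input shape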
{titans_pytorch-0.0.15 → titans_pytorch-0.0.16}/titans_pytorch/titans.py
@@ -80,35 +80,6 @@ class MemoryMLP(Module):
 
         return x
 
-# improvised attention as memory module
-# todo - expand if see signal in experiments
-
-class MemoryAttention(Module):
-    def __init__(
-        self,
-        dim
-    ):
-        super().__init__()
-        self.weights = nn.ParameterList([
-            nn.Parameter(torch.randn(dim, dim)), # queries
-            nn.Parameter(torch.randn(dim, dim)), # keys
-            nn.Parameter(torch.randn(dim, dim)), # values
-        ])
-
-    def forward(self, x):
-        wq, wk, wv = self.weights
-
-        q = x @ wq
-        k = x @ wk
-        v = x @ wv
-
-        sim = q @ k.transpose(-1, -2)
-
-        attn = sim.softmax(dim = -1)
-
-        out = attn @ v
-        return out
-
 # main neural memory
 
 def default_loss_fn(pred, target):
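The removed MemoryAttention computed plain dot-product attention over learned projection weights and, notably, applied softmax without the usual 1/sqrt(dim) scaling; the attention-as-memory idea appears to live on in titans_pytorch/titans_attn_memory.py, which the new test exercises. As a hedged aside (not part of this release), the same forward pass can be reproduced with PyTorch's fused kernel by passing scale = 1.0 explicitly, assuming a PyTorch version whose F.scaled_dot_product_attention accepts the scale argument:

import torch
import torch.nn.functional as F

dim = 64
x = torch.randn(2, 128, dim)

# stand-ins for the removed module's nn.Parameter weights; scaled down here
# only so the softmax in this demo stays numerically well-conditioned
wq, wk, wv = (torch.randn(dim, dim) * dim ** -0.5 for _ in range(3))

q, k, v = x @ wq, x @ wk, x @ wv

# exactly the removed forward pass: softmax(q @ k^T) @ v, no 1/sqrt(dim)
manual = (q @ k.transpose(-1, -2)).softmax(dim = -1) @ v

# fused equivalent; scale = 1.0 disables the default 1/sqrt(dim) scaling
fused = F.scaled_dot_product_attention(q, k, v, scale = 1.0)

assert torch.allclose(manual, fused, atol = 1e-4)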
{titans_pytorch-0.0.15 → titans_pytorch-0.0.16}/train.py
@@ -13,7 +13,11 @@ from local_attention import LocalTransformer
 
 from taylor_series_linear_attention import TaylorSeriesLinearAttn
 
-from titans_pytorch.titans import
+from titans_pytorch.titans import (
+    NeuralMemory,
+    MemoryAttention,
+    MemoryMLP
+)
 
 # constants
 
titans_pytorch-0.0.15/tests/test_titans.py
@@ -1,15 +0,0 @@
-import pytest
-
-import torch
-from titans_pytorch import NeuralMemory
-
-def test_titans():
-    mem = NeuralMemory(
-        dim = 384,
-        chunk_size = 64,
-    )
-
-    seq = torch.randn(2, 1024, 384)
-    retrieved = mem(seq)
-
-    assert seq.shape == retrieved.shape
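To reproduce the rewritten suite against an unpacked 0.0.16 sdist, one option (assuming pytest is installed and the working directory is the sdist root shown in the file list) is to drive pytest from Python:

import sys
import pytest

# run only the rewritten test module; -q keeps the output terse
sys.exit(pytest.main(["-q", "tests/test_titans.py"]))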