titans-pytorch 0.0.15__tar.gz → 0.0.16__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: titans-pytorch
3
- Version: 0.0.15
3
+ Version: 0.0.16
4
4
  Summary: Titans
5
5
  Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
6
6
  Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "titans-pytorch"
3
- version = "0.0.15"
3
+ version = "0.0.16"
4
4
  description = "Titans"
5
5
  authors = [
6
6
  { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -0,0 +1,28 @@
1
+ import torch
2
+ import pytest
3
+
4
+ def test_titans():
5
+ from titans_pytorch import NeuralMemory
6
+
7
+ mem = NeuralMemory(
8
+ dim = 384,
9
+ chunk_size = 64,
10
+ )
11
+
12
+ seq = torch.randn(2, 1024, 384)
13
+ retrieved = mem(seq)
14
+
15
+ assert seq.shape == retrieved.shape
16
+
17
+ def test_titans_attn_memory():
18
+ from titans_pytorch.titans_attn_memory import NeuralMemory
19
+
20
+ mem = NeuralMemory(
21
+ dim = 384,
22
+ chunk_size = 64,
23
+ )
24
+
25
+ seq = torch.randn(2, 1024, 384)
26
+ retrieved = mem(seq)
27
+
28
+ assert seq.shape == retrieved.shape
@@ -1,5 +1,4 @@
1
1
  from titans_pytorch.titans import (
2
2
  NeuralMemory,
3
3
  MemoryMLP,
4
- MemoryAttention
5
4
  )
@@ -80,35 +80,6 @@ class MemoryMLP(Module):
80
80
 
81
81
  return x
82
82
 
83
- # improvised attention as memory module
84
- # todo - expand if see signal in experiments
85
-
86
- class MemoryAttention(Module):
87
- def __init__(
88
- self,
89
- dim
90
- ):
91
- super().__init__()
92
- self.weights = nn.ParameterList([
93
- nn.Parameter(torch.randn(dim, dim)), # queries
94
- nn.Parameter(torch.randn(dim, dim)), # keys
95
- nn.Parameter(torch.randn(dim, dim)), # values
96
- ])
97
-
98
- def forward(self, x):
99
- wq, wk, wv = self.weights
100
-
101
- q = x @ wq
102
- k = x @ wk
103
- v = x @ wv
104
-
105
- sim = q @ k.transpose(-1, -2)
106
-
107
- attn = sim.softmax(dim = -1)
108
-
109
- out = attn @ v
110
- return out
111
-
112
83
  # main neural memory
113
84
 
114
85
  def default_loss_fn(pred, target):
@@ -13,7 +13,11 @@ from local_attention import LocalTransformer
13
13
 
14
14
  from taylor_series_linear_attention import TaylorSeriesLinearAttn
15
15
 
16
- from titans_pytorch.titans import NeuralMemory
16
+ from titans_pytorch.titans import (
17
+ NeuralMemory,
18
+ MemoryAttention,
19
+ MemoryMLP
20
+ )
17
21
 
18
22
  # constants
19
23
 
@@ -1,15 +0,0 @@
1
- import pytest
2
-
3
- import torch
4
- from titans_pytorch import NeuralMemory
5
-
6
- def test_titans():
7
- mem = NeuralMemory(
8
- dim = 384,
9
- chunk_size = 64,
10
- )
11
-
12
- seq = torch.randn(2, 1024, 384)
13
- retrieved = mem(seq)
14
-
15
- assert seq.shape == retrieved.shape
File without changes