titans-pytorch 0.0.26__tar.gz → 0.0.27__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/PKG-INFO +1 -1
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/pyproject.toml +1 -1
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/titans_pytorch/mac_transformer.py +23 -2
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/.gitignore +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/LICENSE +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/README.md +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/data/README.md +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/fig1.png +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/fig2.png +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/requirements.txt +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.27}/train.py +0 -0
@@ -3,10 +3,11 @@ import math
|
|
3
3
|
from functools import partial
|
4
4
|
|
5
5
|
import torch
|
6
|
-
from torch import nn
|
6
|
+
from torch import nn, cat
|
7
7
|
import torch.nn.functional as F
|
8
8
|
from torch.nn import Module, ModuleList, Linear
|
9
9
|
|
10
|
+
from einops import repeat
|
10
11
|
from einops.layers.torch import Rearrange
|
11
12
|
|
12
13
|
from hyper_connections import get_init_and_expand_reduce_stream_functions
|
@@ -48,6 +49,7 @@ class SegmentedAttention(Module):
|
|
48
49
|
self,
|
49
50
|
dim,
|
50
51
|
segment_len,
|
52
|
+
num_persist_mem_tokens,
|
51
53
|
dim_head = 64,
|
52
54
|
heads = 8,
|
53
55
|
):
|
@@ -67,6 +69,7 @@ class SegmentedAttention(Module):
|
|
67
69
|
self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
|
68
70
|
self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
|
69
71
|
|
72
|
+
self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
|
70
73
|
|
71
74
|
def forward(self, seq):
|
72
75
|
batch, seq_len = seq.shape[:2]
|
@@ -92,6 +95,15 @@ class SegmentedAttention(Module):
|
|
92
95
|
q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
|
93
96
|
q, k, v = map(self.split_heads, (q, k, v))
|
94
97
|
|
98
|
+
# take care of persistent memory key / values
|
99
|
+
|
100
|
+
pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
|
101
|
+
|
102
|
+
k = cat((pmk, k), dim = -2)
|
103
|
+
v = cat((pmv, v), dim = -2)
|
104
|
+
|
105
|
+
# sdpa
|
106
|
+
|
95
107
|
out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
|
96
108
|
|
97
109
|
out = self.merge_heads(out)
|
@@ -113,6 +125,7 @@ class MemoryAsContextTransformer(Module):
|
|
113
125
|
dim,
|
114
126
|
depth,
|
115
127
|
segment_len,
|
128
|
+
num_persist_mem_tokens,
|
116
129
|
dim_head = 64,
|
117
130
|
heads = 8,
|
118
131
|
ff_mult = 4,
|
@@ -127,7 +140,14 @@ class MemoryAsContextTransformer(Module):
|
|
127
140
|
self.layers = ModuleList([])
|
128
141
|
|
129
142
|
for _ in range(depth):
|
130
|
-
attn = SegmentedAttention(
|
143
|
+
attn = SegmentedAttention(
|
144
|
+
dim = dim,
|
145
|
+
dim_head = dim_head,
|
146
|
+
heads = heads,
|
147
|
+
segment_len = segment_len,
|
148
|
+
num_persist_mem_tokens = num_persist_mem_tokens
|
149
|
+
)
|
150
|
+
|
131
151
|
ff = FeedForward(dim = dim, mult = ff_mult)
|
132
152
|
|
133
153
|
self.layers.append(ModuleList([
|
@@ -162,6 +182,7 @@ if __name__ == '__main__':
|
|
162
182
|
num_tokens = 256,
|
163
183
|
dim = 256,
|
164
184
|
depth = 2,
|
185
|
+
num_persist_mem_tokens = 16,
|
165
186
|
segment_len = 128,
|
166
187
|
)
|
167
188
|
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|