titans-pytorch 0.0.29__tar.gz → 0.0.30__tar.gz
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/PKG-INFO +1 -1
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/pyproject.toml +1 -1
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/tests/test_titans.py +8 -2
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/titans_pytorch/mac_transformer.py +57 -27
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/.gitignore +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/LICENSE +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/README.md +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/data/README.md +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/fig1.png +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/fig2.png +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/requirements.txt +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.30}/train.py +0 -0
--- titans_pytorch-0.0.29/tests/test_titans.py
+++ titans_pytorch-0.0.30/tests/test_titans.py
@@ -34,14 +34,20 @@ def test_titans_attn_memory():
 
     assert seq.shape == retrieved.shape
 
-
+@pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
+@pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
+def test_mac(
+    num_persist_mem_tokens,
+    num_longterm_mem_tokens
+):
     from titans_pytorch.mac_transformer import MemoryAsContextTransformer
 
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
         dim = 256,
         depth = 2,
-        num_persist_mem_tokens =
+        num_persist_mem_tokens = num_persist_mem_tokens,
+        num_longterm_mem_tokens = num_longterm_mem_tokens,
         segment_len = 128,
     )
 
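The two stacked `@pytest.mark.parametrize` decorators make pytest run `test_mac` once per combination of the two settings, so the memory-as-context path is now exercised with and without each kind of memory token. A minimal sketch of the resulting test matrix (the `itertools` helper here is for illustration only and is not part of the test suite):

```python
# illustration only: stacked parametrize decorators take the cartesian product,
# so pytest generates four test cases for test_mac
from itertools import product

cases = list(product((0, 16), (0, 16)))   # (num_persist_mem_tokens, num_longterm_mem_tokens)
print(cases)                              # [(0, 0), (0, 16), (16, 0), (16, 16)]
```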
--- titans_pytorch-0.0.29/titans_pytorch/mac_transformer.py
+++ titans_pytorch-0.0.30/titans_pytorch/mac_transformer.py
@@ -7,10 +7,9 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
-from einops import repeat
+from einops import repeat, rearrange
 from einops.layers.torch import Rearrange
 
-
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # absolute and relative positions
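`rearrange` is added to this import because the new forward pass reshapes the token sequence into per-window rows and back (see the later hunks). A tiny standalone sketch of the two einops helpers as they are used here, with made-up toy shapes:

```python
# assumed toy shapes, purely to show the two einops operations used below
import torch
from einops import rearrange, repeat

t = torch.randn(2, 6, 8)                                    # (batch, seq, dim)
print(rearrange(t, 'b (w n) d -> (b w) n d', n = 3).shape)  # torch.Size([4, 3, 8])

m = torch.randn(3, 8)                                       # (memory tokens, dim)
print(repeat(m, 'n d -> b n d', b = 4).shape)               # torch.Size([4, 3, 8])
```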
@@ -55,7 +54,8 @@ class SegmentedAttention(Module):
         self,
         dim,
         segment_len,
-        num_persist_mem_tokens,
+        num_persist_mem_tokens = 0,
+        num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
     ):
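With both memory-token counts now defaulting to 0, `SegmentedAttention` can be constructed as plain windowed attention. A hedged sketch, assuming the arguments visible in this signature are the only required ones:

```python
# sketch under the assumption that dim and segment_len are the only required
# constructor arguments, as the updated signature suggests; both memory-token
# counts fall back to 0
from titans_pytorch.mac_transformer import SegmentedAttention

attn = SegmentedAttention(dim = 256, segment_len = 128)   # no memory tokens

attn_with_mem = SegmentedAttention(
    dim = 256,
    segment_len = 128,
    num_persist_mem_tokens = 16,
    num_longterm_mem_tokens = 16,
)
```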
@@ -70,25 +70,31 @@ class SegmentedAttention(Module):
         self.to_out = LinearNoBias(dim_inner, dim)
 
         self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        total_segment_len = segment_len + num_longterm_mem_tokens
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n =
-        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n =
+        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = total_segment_len)
+        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = total_segment_len)
 
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
+        segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
         batch, seq_len = seq.shape[:2]
 
         # auto pad to multiple
         # todo - get rid of logic with flex attention
 
-        need_segment = seq_len >=
+        need_segment = seq_len >= total_segment_len
 
         if need_segment:
-            next_seq_len = round_up_multiple(seq_len,
+            next_seq_len = round_up_multiple(seq_len, total_segment_len)
             padding = next_seq_len - seq_len
 
             if padding > 0:
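The window length used for auto-padding and segmentation is now `segment_len + num_longterm_mem_tokens` rather than `segment_len` alone. A self-contained sketch of that arithmetic (the local `round_up_multiple` is a stand-in for the helper defined elsewhere in `mac_transformer.py`):

```python
# standalone sketch of the auto-padding arithmetic; round_up_multiple here is a
# stand-in for the module's own helper
import torch
import torch.nn.functional as F

def round_up_multiple(n, mult):
    return ((n + mult - 1) // mult) * mult

segment_len = 4
num_longterm_mem_tokens = 2
total_segment_len = segment_len + num_longterm_mem_tokens            # 6

seq = torch.randn(1, 10, 64)                                         # (batch, seq_len, dim)
next_seq_len = round_up_multiple(seq.shape[1], total_segment_len)    # 12
padding = next_seq_len - seq.shape[1]                                # 2

if padding > 0:
    seq = F.pad(seq, (0, 0, 0, padding))                             # pad the sequence dimension on the right

print(seq.shape)                                                     # torch.Size([1, 12, 64])
```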
@@ -139,7 +145,8 @@ class MemoryAsContextTransformer(Module):
         dim,
         depth,
         segment_len,
-
+        num_longterm_mem_tokens = 0,
+        num_persist_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -147,10 +154,18 @@ class MemoryAsContextTransformer(Module):
     ):
         super().__init__()
 
-        self.
+        self.token_emb = nn.Embedding(num_tokens, dim)
+
         self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
 
-
+        # long term mem tokens
+
+        self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
+
+        # hyper conection
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
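The long-term memory tokens are stored once as a small learned parameter of shape `(num_longterm_mem_tokens, dim)` and broadcast across the batch at use time. A minimal sketch of that pattern, mirroring the `nn.Parameter` above and the `repeat` call in the new forward:

```python
# sketch of the learned memory-token parameter and its per-batch broadcast
import torch
from torch import nn
from einops import repeat

dim, num_longterm_mem_tokens, batch = 256, 4, 2

longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)  # small init, as in the diff
mems = repeat(longterm_mems, 'n d -> b n d', b = batch)
print(mems.shape)   # torch.Size([2, 4, 256])
```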
@@ -162,6 +177,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
             )
 
@@ -177,15 +193,42 @@ class MemoryAsContextTransformer(Module):
         self.to_logits = LinearNoBias(dim, num_tokens)
 
     def forward(self, x):
-
+
+        # math
+
+        batch, seq_len, segment_len, num_longterm_mem_tokens= *x.shape, self.segment_len, self.num_longterm_mem_tokens
+
         windows = ceil(seq_len / segment_len)
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
+        # token embedding
 
         x = self.token_emb(x)
 
+        # intersperse longterm memory
+
+        need_segment = seq_len >= segment_len
+
+        if need_segment:
+            next_seq_len = round_up_multiple(seq_len, segment_len)
+            padding = next_seq_len - seq_len
+
+            if padding > 0:
+                x = F.pad(x, (0, 0, 0, padding))
+
+            x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
+
+        mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
+        x = torch.cat((mems, x), dim = -2)
+
+        if need_segment:
+            x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)
+            x = x[:, :seq_len]
+
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
-        pos_emb = self.axial_pos_emb((windows,
+        pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
         x = x + pos_emb[:seq_len]
 
         # expand and reduce streams for hyper connections
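The new "intersperse longterm memory" block is the core of this release: the embedded sequence is padded up to a multiple of `segment_len`, split into one row per window, the same learned memory tokens are prepended to every window, and the rows are flattened back into a single sequence. A self-contained sketch of that reshaping with assumed toy sizes (it reproduces the pattern, not the library's exact code):

```python
# toy-sized sketch of split -> prepend memory tokens per window -> merge back
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

batch, seq_len, dim = 2, 10, 8
segment_len, num_longterm_mem_tokens = 4, 2

x = torch.randn(batch, seq_len, dim)
longterm_mems = torch.randn(num_longterm_mem_tokens, dim)   # a learned parameter in the real model

# pad so the sequence divides evenly into windows of segment_len
next_seq_len = -(-seq_len // segment_len) * segment_len     # ceil to 12
x = F.pad(x, (0, 0, 0, next_seq_len - seq_len))

# one row per (batch, window) pair: (2, 12, 8) -> (6, 4, 8)
x = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)

# prepend the memory tokens to every window: (6, 6, 8)
mems = repeat(longterm_mems, 'n d -> b n d', b = x.shape[0])
x = torch.cat((mems, x), dim = -2)

# merge the windows back into one sequence per batch element: (2, 18, 8)
x = rearrange(x, '(b w) n d -> b (w n) d', b = batch)
print(x.shape)   # torch.Size([2, 18, 8])
```

Each window now carries `segment_len + num_longterm_mem_tokens` positions, which is why the axial positional embedding in the hunk above is built over `(windows, total_segment_len)`.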
@@ -198,21 +241,8 @@ class MemoryAsContextTransformer(Module):
 
         x = self.reduce_streams(x)
 
+        # to logits
+
         x = self.norm(x)
 
         return self.to_logits(x)
-
-# main
-
-if __name__ == '__main__':
-    transformer = MemoryAsContextTransformer(
-        num_tokens = 256,
-        dim = 256,
-        depth = 2,
-        num_persist_mem_tokens = 16,
-        segment_len = 128,
-    )
-
-    x = torch.randint(0, 256, (1, 1023))
-
-    logits = transformer(x)
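The quick `__main__` smoke test is removed from the module. An equivalent standalone script for 0.0.30, updated for the new `num_longterm_mem_tokens` argument, would look roughly like this (a sketch, not shipped with the package):

```python
# rough standalone equivalent of the removed __main__ block, updated for 0.0.30
import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    num_persist_mem_tokens = 16,
    num_longterm_mem_tokens = 16,   # new in this release
    segment_len = 128,
)

x = torch.randint(0, 256, (1, 1023))
logits = transformer(x)             # expected: (1, 1023, 256) logits over the vocabulary
```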