titans-pytorch 0.0.51__tar.gz → 0.0.53__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/PKG-INFO +1 -2
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/pyproject.toml +1 -2
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/titans_pytorch/mac_transformer.py +73 -3
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/titans_pytorch/titans.py +3 -3
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/.gitignore +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/LICENSE +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/README.md +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/data/README.md +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/fig1.png +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/fig2.png +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/train_mac.py +0 -0
{titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.51
+Version: 0.0.53
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -37,7 +37,6 @@ Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.5
 Requires-Dist: einops>=0.8.0
-Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
{titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.51"
+version = "0.0.53"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -27,7 +27,6 @@ classifiers=[
 dependencies = [
     "accelerated-scan>=0.2.0",
     "axial_positional_embedding>=0.3.5",
-    "einx>=0.3.0",
     "einops>=0.8.0",
     "hyper-connections>=0.1.8",
     "Ninja",
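Both the PKG-INFO and pyproject.toml hunks drop the `einx` dependency while keeping `einops`. The exact call sites that were removed are not part of this diff, but as a hedged sketch, the kind of einstein-notation broadcast `einx` is typically used for can be written with `einops.rearrange` plus ordinary torch broadcasting (the names and shapes below are hypothetical):

```python
import torch
from einops import rearrange

# hypothetical example: gate per-token features with a per-token scalar,
# the sort of broadcasted elementwise op einx expresses in one call
gate = torch.rand(2, 8)          # (batch, seq)
x = torch.randn(2, 8, 16)        # (batch, seq, dim)

out = x * rearrange(gate, 'b n -> b n 1')   # broadcast over the feature dimension

assert out.shape == (2, 8, 16)
```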
{titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/titans_pytorch/mac_transformer.py

@@ -1,4 +1,5 @@
 from __future__ import annotations
+from typing import Callable
 from math import ceil
 from functools import partial
 
@@ -32,7 +33,7 @@ def create_mac_block_mask(seq_len, window_size, persist_mem_len):
 
 # einstein notation related
 
-from einops import
+from einops import repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
 # b - batch
@@ -128,6 +129,7 @@ class SegmentedAttention(Module):
         heads = 8,
         accept_value_residual = False,
         attend_kwargs: dict = dict(),
+        use_flex_attn = False
     ):
         super().__init__()
         self.norm = nn.RMSNorm(dim)
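The new `use_flex_attn` flag (together with the `create_mac_block_mask(seq_len, window_size, persist_mem_len)` helper referenced in an earlier hunk header) ties this module to PyTorch's flex attention. The sketch below illustrates that API under stated assumptions: PyTorch ≥ 2.5 with a CUDA device, and a mask that mimics the memory-as-context pattern (persistent memory tokens prefixed to the keys are always visible, the remaining keys are attended causally within a fixed segment). It is not the package's exact `create_mac_block_mask`.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def mac_style_mask(window_size, persist_mem_len):
    # mask_mod receives (batch, head, query_index, key_index) index tensors
    # and returns True where attention is allowed
    def mask_mod(b, h, q_idx, kv_idx):
        is_persist = kv_idx < persist_mem_len            # persistent memory always visible
        kv_shifted = kv_idx - persist_mem_len            # position among the real keys
        causal = q_idx >= kv_shifted
        same_segment = (q_idx // window_size) == (kv_shifted // window_size)
        return is_persist | (causal & same_segment)
    return mask_mod

seq_len, window_size, persist_len = 512, 128, 4

block_mask = create_block_mask(
    mac_style_mask(window_size, persist_len),
    B = None, H = None,
    Q_LEN = seq_len, KV_LEN = seq_len + persist_len,
    device = 'cuda'
)

q = torch.randn(1, 8, seq_len, 64, device = 'cuda')
k = torch.randn(1, 8, seq_len + persist_len, 64, device = 'cuda')
v = torch.randn(1, 8, seq_len + persist_len, 64, device = 'cuda')

out = flex_attention(q, k, v, block_mask = block_mask)   # (1, 8, seq_len, 64)
```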
@@ -157,11 +159,79 @@ class SegmentedAttention(Module):
 
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
+        # flex attn related
+
+        assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
+        self.use_flex_attn = use_flex_attn
+
+        self.segment_len = segment_len
+        self.num_persist_mem_tokens = num_persist_mem_tokens
+
+    def forward_flex(
+        self,
+        seq,
+        value_residual = None,
+        flex_attn_fn: Callable | None = None
+    ):
+
+        assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
+
+        batch, seq_len = seq.shape[:2]
+
+        # attention
+
+        seq = self.norm(seq)
+
+        q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
+        q, k, v = map(self.split_heads, (q, k, v))
+
+        # value residual
+
+        orig_v = v
+
+        if exists(self.to_learned_v_mix):
+            mix = self.to_learned_v_mix(seq)
+            v = v.lerp(value_residual, mix)
+
+        # take care of persistent memory key / values
+
+        pmk, pmv = repeat(self.persistent_memory, 'kv h n d -> kv b h n d', b = batch)
+
+        # relative positions
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # persistent memory
+
+        k = cat((pmk, k), dim = -2)
+        v = cat((pmv, v), dim = -2)
+
+        # prep flex attention
+
+        if not exists(flex_attn_fn):
+            block_mask = create_mac_block_mask(seq_len, self.segment_len, self.num_persist_mem_tokens)
+
+            flex_attn_fn = partial(flex_attention, block_mask = block_mask)
+
+        # attention
+
+        out = flex_attn_fn(q, k, v)
+
+        out = self.merge_heads(out)
+
+        out = self.to_out(out)
+
+        return out, orig_v
+
     def forward(
         self,
         seq,
-        value_residual = None
+        value_residual = None,
+        flex_attn_fn: Callable | None = None
     ):
+        if seq.is_cuda and self.use_flex_attn:
+            return self.forward_flex(seq, value_residual, flex_attn_fn)
+
         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
 
         segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
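With `forward` now dispatching on `seq.is_cuda and self.use_flex_attn`, callers can opt into the flex path per layer and still fall back to the original einops implementation on CPU. A hedged usage sketch follows; the argument names are taken from this diff, and other required arguments or defaults of `SegmentedAttention` may differ from what is shown:

```python
import torch
from titans_pytorch.mac_transformer import SegmentedAttention

# hypothetical configuration, assuming these constructor arguments suffice
attn = SegmentedAttention(
    dim = 512,
    segment_len = 128,               # window used by the block mask
    num_persist_mem_tokens = 4,
    use_flex_attn = True             # new in 0.0.53
).cuda()

seq = torch.randn(1, 1024, 512, device = 'cuda')

# on a CUDA tensor with use_flex_attn = True this routes through forward_flex,
# which returns the attention output and the pre-mix values for value residuals
out, values = attn(seq)
```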
@@ -191,7 +261,7 @@ class SegmentedAttention(Module):
 
         # take care of persistent memory key / values
 
-        pmk, pmv =
+        pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = seq.shape[0])
 
         # relative positions
 
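Both attention paths now broadcast the learned persistent memory across the batch with `einops.repeat` before prepending it to the keys and values; the two patterns in the diff (`'kv h n d -> kv b h n d'` and `'kv ... -> kv b ...'`) are equivalent for this tensor. A small self-contained sketch of what the unpacked result looks like:

```python
import torch
from einops import repeat

heads, num_persist, dim_head, batch = 8, 4, 64, 2

# same layout as self.persistent_memory: stacked memory keys and values
persistent_memory = torch.zeros(2, heads, num_persist, dim_head)

# broadcast over the batch, then unpack along the leading kv dimension
pmk, pmv = repeat(persistent_memory, 'kv ... -> kv b ...', b = batch)

assert pmk.shape == (batch, heads, num_persist, dim_head)
assert pmv.shape == (batch, heads, num_persist, dim_head)

# prepended along the sequence dimension, so every query can attend to them
k = torch.randn(batch, heads, 16, dim_head)
k = torch.cat((pmk, k), dim = -2)    # (batch, heads, num_persist + 16, dim_head)
```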
{titans_pytorch-0.0.51 → titans_pytorch-0.0.53}/titans_pytorch/titans.py

@@ -1,4 +1,5 @@
 from __future__ import annotations
+from typing import Callable
 import math
 from functools import partial
 
@@ -16,7 +17,6 @@ from titans_pytorch.associative_scan import (
     pad_at_dim
 )
 
-import einx
 from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
@@ -338,9 +338,9 @@ class NeuralMemory(Module):
 
         # take care of chunking
 
-        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c =
+        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = chunk_size) for t in (keys, values))
 
-        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c =
+        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)
 
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
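The final hunk completes the two rearranges that fold the chunk dimension into the batch, so the per-chunk gradient step in `NeuralMemory` runs over all chunks at once. A minimal sketch of the reshape with hypothetical sizes:

```python
import torch
from einops import rearrange

batch, num_chunks, chunk_size, dim = 2, 3, 4, 8
seq_len = num_chunks * chunk_size

keys = torch.randn(batch, seq_len, dim)
adaptive_lr = torch.rand(batch, seq_len)

# fold chunks into the batch dimension: (b, n*c, d) -> (b*n, c, d)
keys = rearrange(keys, 'b (n c) d -> (b n) c d', c = chunk_size)
adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)

assert keys.shape == (batch * num_chunks, chunk_size, dim)
assert adaptive_lr.shape == (batch * num_chunks, chunk_size)
```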