titans-pytorch 0.0.37.tar.gz → 0.0.39.tar.gz
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/PKG-INFO +2 -1
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/pyproject.toml +2 -1
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/titans_pytorch/mac_transformer.py +10 -5
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/titans_pytorch/titans.py +8 -5
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/.gitignore +0 -0
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/LICENSE +0 -0
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/README.md +0 -0
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/data/README.md +0 -0
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/fig1.png +0 -0
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/fig2.png +0 -0
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/requirements.txt +0 -0
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/train.py +0 -0
- {titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/train_mac.py +0 -0
{titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.37
+Version: 0.0.39
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -43,6 +43,7 @@ Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
+Requires-Dist: x-transformers
 Provides-Extra: examples
 Requires-Dist: local-attention>=1.10.1; extra == 'examples'
 Requires-Dist: taylor-series-linear-attention; extra == 'examples'
{titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.37"
+version = "0.0.39"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -34,6 +34,7 @@ dependencies = [
     "rotary-embedding-torch",
     "tensordict",
     "torch>=2.2",
+    "x-transformers"
 ]

 [project.urls]
{titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/titans_pytorch/mac_transformer.py

@@ -7,7 +7,7 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear

-from einops import repeat, rearrange, pack, unpack
+from einops import einsum, repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange

 from hyper_connections import get_init_and_expand_reduce_stream_functions
@@ -16,6 +16,7 @@ from hyper_connections import get_init_and_expand_reduce_stream_functions

 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
+from x_transformers.attend import Attend

 # proposed neural memory

@@ -93,6 +94,7 @@ class SegmentedAttention(Module):
         num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
+        attend_kwargs: dict = dict()
     ):
         super().__init__()
         self.norm = nn.RMSNorm(dim)
@@ -101,6 +103,8 @@ class SegmentedAttention(Module):

         self.rotary_emb = RotaryEmbedding(dim_head)

+        self.attend = Attend(causal = True, **attend_kwargs)
+
         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
         self.to_out = LinearNoBias(dim_inner, dim)

@@ -145,9 +149,9 @@ class SegmentedAttention(Module):
         k = cat((pmk, k), dim = -2)
         v = cat((pmv, v), dim = -2)

-        #
+        # attention

-        out =
+        out, _ = self.attend(q, k, v)

         out = self.merge_heads(out)

@@ -288,7 +292,8 @@ class MemoryAsContextTransformer(Module):
         for (attn, ff), maybe_neural_mem in zip(self.layers, self.neural_mem_layers):

             if exists(maybe_neural_mem):
-
+                x = maybe_neural_mem(x)
+

             x = attn(x)

@@ -300,7 +305,7 @@ class MemoryAsContextTransformer(Module):

         x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)

-        x,
+        x, _ = unpack(x, mem_ps, 'b * d')

         x = inverse_segment(x)
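Note on the attention change above: SegmentedAttention now delegates to the Attend module that the new x-transformers dependency provides, constructed once with causal = True plus any attend_kwargs, and its forward returns a tuple, which is why the diff unpacks out, _. The following is a minimal standalone sketch of that call pattern, not code from the package, assuming (batch, heads, seq, dim_head) shaped tensors and comparing against plain torch scaled dot product attention for the default causal case:

# minimal sketch (not package code) of the Attend call pattern used in SegmentedAttention
import torch
import torch.nn.functional as F
from x_transformers.attend import Attend

q = torch.randn(2, 8, 16, 64)   # (batch, heads, seq, dim_head) - assumed shapes
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

attend = Attend(causal = True)  # same construction as self.attend above, without extra attend_kwargs
out, _ = attend(q, k, v)        # Attend returns (output, intermediates), hence `out, _ = ...` in the diff

# for the default causal case this should agree with plain scaled dot product attention,
# up to floating point differences
reference = F.scaled_dot_product_attention(q, k, v, is_causal = True)
print((out - reference).abs().max())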
{titans_pytorch-0.0.37 → titans_pytorch-0.0.39}/titans_pytorch/titans.py

@@ -27,9 +27,7 @@ n - sequence
 d - feature dimension
 c - intra-chunk
 """
-
-# constants
-
+
 LinearNoBias = partial(Linear, bias = False)

 # functions
@@ -390,7 +388,10 @@ class NeuralMemory(Module):

         padding = next_seq_len - curtailed_seq_len

-        seq = pad_at_dim(seq, (0, padding), dim = 1)
+        needs_pad = padding > 0
+
+        if needs_pad:
+            seq = pad_at_dim(seq, (0, padding), dim = 1)

         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
@@ -442,7 +443,9 @@ class NeuralMemory(Module):
         empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
         values = torch.cat((empty_memory_embeds, values), dim = -2)

-        values = values[:, :-padding]
+        if needs_pad:
+            values = values[:, :-padding]
+
         return values

     def forward(
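Note on the needs_pad guard added above: the retrieval path trims the padded positions with values[:, :-padding], and when padding is 0 that slice becomes values[:, :0], which drops the entire sequence. The guard skips both the padding and the trim in that case. Below is a small standalone illustration of the slicing edge case, using a made-up tensor rather than package code:

# standalone illustration (made-up tensor) of why the needs_pad check matters
import torch

values = torch.randn(1, 6, 4)            # (batch, seq, dim)

padding = 2
print(values[:, :-padding].shape)        # torch.Size([1, 4, 4]) - trims the padded positions

padding = 0
print(values[:, :-padding].shape)        # torch.Size([1, 0, 4]) - [:-0] keeps nothing at all

needs_pad = padding > 0                  # the guard added in NeuralMemory
if needs_pad:
    values = values[:, :-padding]
print(values.shape)                      # torch.Size([1, 6, 4]) - untouched when nothing was padded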