titans-pytorch 0.0.26__tar.gz → 0.0.29__tar.gz
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/PKG-INFO +3 -1
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/pyproject.toml +3 -1
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/tests/test_titans.py +16 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/titans_pytorch/mac_transformer.py +52 -4
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/.gitignore +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/LICENSE +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/README.md +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/data/README.md +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/fig1.png +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/fig2.png +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/requirements.txt +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/train.py +0 -0
{titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.26
+Version: 0.0.29
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,10 +35,12 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
+Requires-Dist: axial-positional-embedding>=0.3.5
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
+Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
 Provides-Extra: examples
{titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.26"
+version = "0.0.29"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -26,10 +26,12 @@ classifiers=[
 
 dependencies = [
     "accelerated-scan>=0.2.0",
+    "axial_positional_embedding>=0.3.5",
     "einx>=0.3.0",
     "einops>=0.8.0",
     "hyper-connections>=0.1.8",
     "Ninja",
+    "rotary-embedding-torch",
     "tensordict",
     "torch>=2.2",
 ]
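Both metadata files add the same two runtime dependencies, axial-positional-embedding (>=0.3.5) and rotary-embedding-torch, which the updated mac_transformer.py imports further down. A minimal sanity check after upgrading, assuming only that the two packages expose the classes the diff imports (this snippet is not part of the release):

    # verify the new 0.0.29 dependencies resolve to the classes used by mac_transformer.py
    from axial_positional_embedding import ContinuousAxialPositionalEmbedding
    from rotary_embedding_torch import RotaryEmbedding

    print(ContinuousAxialPositionalEmbedding, RotaryEmbedding)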
{titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/tests/test_titans.py

@@ -33,3 +33,19 @@ def test_titans_attn_memory():
     retrieved = mem(seq)
 
     assert seq.shape == retrieved.shape
+
+def test_mac():
+    from titans_pytorch.mac_transformer import MemoryAsContextTransformer
+
+    transformer = MemoryAsContextTransformer(
+        num_tokens = 256,
+        dim = 256,
+        depth = 2,
+        num_persist_mem_tokens = 16,
+        segment_len = 128,
+    )
+
+    x = torch.randint(0, 256, (1, 1023))
+
+    logits = transformer(x)
+    assert logits.shape == (1, 1023, 256)
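The new test_mac doubles as a usage example for the changed constructor signature (note the required num_persist_mem_tokens argument). Below is a minimal sketch of a next-token training step built on the same arguments; the loss wiring is illustrative only and not part of the repository:

    import torch
    import torch.nn.functional as F
    from titans_pytorch.mac_transformer import MemoryAsContextTransformer

    # same hyperparameters as test_mac above
    model = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        num_persist_mem_tokens = 16,
        segment_len = 128,
    )

    ids = torch.randint(0, 256, (1, 1024))
    inp, labels = ids[:, :-1], ids[:, 1:]      # 1023-token input, as in the test

    logits = model(inp)                        # (1, 1023, 256), matching the test's assertion
    loss = F.cross_entropy(logits.transpose(1, 2), labels)
    loss.backward()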
{titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/titans_pytorch/mac_transformer.py

@@ -1,16 +1,23 @@
 from __future__ import annotations
-import
+from math import ceil
 from functools import partial
 
 import torch
-from torch import nn
+from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
+from einops import repeat
 from einops.layers.torch import Rearrange
 
+
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
+# absolute and relative positions
+
+from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+from rotary_embedding_torch import RotaryEmbedding
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
@@ -24,7 +31,7 @@ def default(v, d):
     return v if exists(v) else d
 
 def round_up_multiple(seq, mult):
-    return
+    return ceil(seq / mult) * mult
 
 # feedforward and attention
 
@@ -48,6 +55,7 @@ class SegmentedAttention(Module):
         self,
         dim,
         segment_len,
+        num_persist_mem_tokens,
         dim_head = 64,
         heads = 8,
     ):
@@ -56,6 +64,8 @@ class SegmentedAttention(Module):
 
         dim_inner = dim_head * heads
 
+        self.rotary_emb = RotaryEmbedding(dim_head)
+
         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
         self.to_out = LinearNoBias(dim_inner, dim)
 
@@ -67,6 +77,7 @@ class SegmentedAttention(Module):
         self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
         self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
 
+        self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
         batch, seq_len = seq.shape[:2]
@@ -92,6 +103,21 @@ class SegmentedAttention(Module):
         q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
         q, k, v = map(self.split_heads, (q, k, v))
 
+        # take care of persistent memory key / values
+
+        pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
+
+        # relative positions
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # persistent memory
+
+        k = cat((pmk, k), dim = -2)
+        v = cat((pmv, v), dim = -2)
+
+        # sdpa
+
         out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
 
         out = self.merge_heads(out)
@@ -113,6 +139,7 @@ class MemoryAsContextTransformer(Module):
         dim,
         depth,
         segment_len,
+        num_persist_mem_tokens,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -120,6 +147,9 @@ class MemoryAsContextTransformer(Module):
     ):
         super().__init__()
 
+        self.segment_len = segment_len
+        self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
+
         self.token_emb = nn.Embedding(num_tokens, dim)
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
@@ -127,7 +157,14 @@ class MemoryAsContextTransformer(Module):
         self.layers = ModuleList([])
 
         for _ in range(depth):
-            attn = SegmentedAttention(
+            attn = SegmentedAttention(
+                dim = dim,
+                dim_head = dim_head,
+                heads = heads,
+                segment_len = segment_len,
+                num_persist_mem_tokens = num_persist_mem_tokens
+            )
+
             ff = FeedForward(dim = dim, mult = ff_mult)
 
             self.layers.append(ModuleList([
@@ -140,9 +177,19 @@ class MemoryAsContextTransformer(Module):
         self.to_logits = LinearNoBias(dim, num_tokens)
 
     def forward(self, x):
+        seq_len, segment_len = x.shape[-1], self.segment_len
+        windows = ceil(seq_len / segment_len)
 
         x = self.token_emb(x)
 
+        # apply axial positional embedding
+        # so intra and inter segment can be more easily discerned by the network
+
+        pos_emb = self.axial_pos_emb((windows, segment_len), flatten = True)
+        x = x + pos_emb[:seq_len]
+
+        # expand and reduce streams for hyper connections
+
         x = self.expand_streams(x)
 
         for attn, ff in self.layers:
@@ -162,6 +209,7 @@ if __name__ == '__main__':
         num_tokens = 256,
         dim = 256,
         depth = 2,
+        num_persist_mem_tokens = 16,
         segment_len = 128,
     )
 
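The core of the change sits in SegmentedAttention.forward: queries and keys are first rotated for relative position, then learned persistent-memory keys and values are prepended along the sequence axis before scaled dot product attention. The following is a self-contained sketch of that path with made-up toy shapes, mirroring the added lines rather than reusing the module itself:

    import torch
    import torch.nn.functional as F
    from einops import repeat
    from rotary_embedding_torch import RotaryEmbedding

    # toy sizes, chosen only for illustration
    batch, heads, seq_len, dim_head, num_persist_mem_tokens = 2, 8, 128, 64, 16

    q = torch.randn(batch, heads, seq_len, dim_head)
    k = torch.randn(batch, heads, seq_len, dim_head)
    v = torch.randn(batch, heads, seq_len, dim_head)

    # stand-in for the (2, heads, n, dim_head) parameter SegmentedAttention now registers
    persistent_memory = torch.zeros(2, heads, num_persist_mem_tokens, dim_head)
    pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = batch) for t in persistent_memory)

    # relative positions, applied before the memory tokens are prepended
    rotary_emb = RotaryEmbedding(dim_head)
    q, k = rotary_emb.rotate_queries_with_cached_keys(q, k)

    # prepend persistent memory to keys and values along the sequence axis
    k = torch.cat((pmk, k), dim = -2)
    v = torch.cat((pmv, v), dim = -2)

    out = F.scaled_dot_product_attention(q, k, v, is_causal = True)   # (batch, heads, seq_len, dim_head)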
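The other notable addition is the absolute position scheme in MemoryAsContextTransformer.forward: positions are parameterized over two axes, the window index and the offset inside a window, then flattened and trimmed to the sequence length, so intra- and inter-segment structure is easier for the network to tell apart. A short standalone sketch of that call, with toy sizes and the assumption (taken from the diff's own usage) that the flattened embedding has shape (windows * segment_len, dim):

    import torch
    from math import ceil
    from axial_positional_embedding import ContinuousAxialPositionalEmbedding

    dim, segment_len, seq_len = 256, 128, 1023
    windows = ceil(seq_len / segment_len)      # 8 windows cover 1023 tokens

    axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)

    # one embedding per (window, intra-window position), flattened along the sequence
    pos_emb = axial_pos_emb((windows, segment_len), flatten = True)

    x = torch.randn(1, seq_len, dim)           # stand-in for token embeddings
    x = x + pos_emb[:seq_len]                  # drop the padding of the final partial window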