titans-pytorch 0.0.26__tar.gz → 0.0.29__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/PKG-INFO +3 -1
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/pyproject.toml +3 -1
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/tests/test_titans.py +16 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/titans_pytorch/mac_transformer.py +52 -4
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/.gitignore +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/LICENSE +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/README.md +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/data/README.md +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/fig1.png +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/fig2.png +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/requirements.txt +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/train.py +0 -0
{titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.26
+Version: 0.0.29
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -35,10 +35,12 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
+Requires-Dist: axial-positional-embedding>=0.3.5
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
+Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
 Provides-Extra: examples
{titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.26"
+version = "0.0.29"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -26,10 +26,12 @@ classifiers=[
 
 dependencies = [
     "accelerated-scan>=0.2.0",
+    "axial_positional_embedding>=0.3.5",
     "einx>=0.3.0",
     "einops>=0.8.0",
     "hyper-connections>=0.1.8",
     "Ninja",
+    "rotary-embedding-torch",
     "tensordict",
     "torch>=2.2",
 ]
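Both the sdist metadata and pyproject.toml declare the same two new runtime dependencies, axial-positional-embedding (>=0.3.5) and rotary-embedding-torch, alongside the 0.0.26 → 0.0.29 version bump. A quick, hedged sanity check that an upgraded install resolves them (nothing here beyond the two import names used later in mac_transformer.py):

from axial_positional_embedding import ContinuousAxialPositionalEmbedding
from rotary_embedding_torch import RotaryEmbedding

# if both imports succeed, the new dependency pins are satisfied
print(ContinuousAxialPositionalEmbedding.__name__, RotaryEmbedding.__name__)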
{titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/tests/test_titans.py

@@ -33,3 +33,19 @@ def test_titans_attn_memory():
     retrieved = mem(seq)
 
     assert seq.shape == retrieved.shape
+
+def test_mac():
+    from titans_pytorch.mac_transformer import MemoryAsContextTransformer
+
+    transformer = MemoryAsContextTransformer(
+        num_tokens = 256,
+        dim = 256,
+        depth = 2,
+        num_persist_mem_tokens = 16,
+        segment_len = 128,
+    )
+
+    x = torch.randint(0, 256, (1, 1023))
+
+    logits = transformer(x)
+    assert logits.shape == (1, 1023, 256)
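The new test builds the memory-as-context transformer on a 1023-token batch, a length that is not a multiple of segment_len = 128, and checks only the output shape. A hedged follow-on sketch (not part of the package) of how those logits would plug into a standard next-token objective:

import torch
import torch.nn.functional as F

from titans_pytorch.mac_transformer import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    num_persist_mem_tokens = 16,
    segment_len = 128,
)

seq = torch.randint(0, 256, (1, 1024))
inp, target = seq[:, :-1], seq[:, 1:]      # shift by one token

logits = transformer(inp)                  # (1, 1023, 256), as in the test
loss = F.cross_entropy(logits.transpose(1, 2), target)
loss.backward()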
{titans_pytorch-0.0.26 → titans_pytorch-0.0.29}/titans_pytorch/mac_transformer.py

@@ -1,16 +1,23 @@
 from __future__ import annotations
-import math
+from math import ceil
 from functools import partial
 
 import torch
-from torch import nn
+from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
+from einops import repeat
 from einops.layers.torch import Rearrange
 
+
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
+# absolute and relative positions
+
+from axial_positional_embedding import ContinuousAxialPositionalEmbedding
+from rotary_embedding_torch import RotaryEmbedding
+
 # constants
 
 LinearNoBias = partial(Linear, bias = False)
@@ -24,7 +31,7 @@ def default(v, d):
     return v if exists(v) else d
 
 def round_up_multiple(seq, mult):
-    return math.ceil(seq / mult) * mult
+    return ceil(seq / mult) * mult
 
 # feedforward and attention
 
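round_up_multiple switches from the math module to ceil imported directly. A standalone copy of the helper, using the lengths from the new test as a worked check:

from math import ceil

def round_up_multiple(seq, mult):
    # same one-liner as in mac_transformer.py after this change
    return ceil(seq / mult) * mult

assert round_up_multiple(1023, 128) == 1024   # partial final segment rounds up
assert round_up_multiple(1024, 128) == 1024   # exact multiples are unchanged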
@@ -48,6 +55,7 @@ class SegmentedAttention(Module):
         self,
         dim,
         segment_len,
+        num_persist_mem_tokens,
         dim_head = 64,
         heads = 8,
     ):
@@ -56,6 +64,8 @@ class SegmentedAttention(Module):
 
         dim_inner = dim_head * heads
 
+        self.rotary_emb = RotaryEmbedding(dim_head)
+
         self.to_qkv = LinearNoBias(dim, dim_inner * 3)
         self.to_out = LinearNoBias(dim_inner, dim)
 
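RotaryEmbedding comes from the newly declared rotary-embedding-torch dependency and is created once per attention layer with the head dimension. A minimal standalone sketch of the call used later in the forward pass (batch, sequence and head sizes here are illustrative, and the tensor layout is assumed):

import torch
from rotary_embedding_torch import RotaryEmbedding

heads, dim_head = 8, 64
rotary_emb = RotaryEmbedding(dim_head)

# assumed (batch, heads, seq, dim_head) layout for the per-head tensors
q = torch.randn(1, heads, 128, dim_head)
k = torch.randn(1, heads, 128, dim_head)

# rotates queries and keys by position so attention scores become
# a function of relative offsets
q, k = rotary_emb.rotate_queries_with_cached_keys(q, k)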
@@ -67,6 +77,7 @@ class SegmentedAttention(Module):
         self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
         self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
 
+        self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
         batch, seq_len = seq.shape[:2]
@@ -92,6 +103,21 @@ class SegmentedAttention(Module):
         q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
         q, k, v = map(self.split_heads, (q, k, v))
 
+        # take care of persistent memory key / values
+
+        pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = seq.shape[0]) for t in self.persistent_memory)
+
+        # relative positions
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # persistent memory
+
+        k = cat((pmk, k), dim = -2)
+        v = cat((pmv, v), dim = -2)
+
+        # sdpa
+
         out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
 
         out = self.merge_heads(out)
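The persistent memory parameter added in __init__ packs learned keys and values into a single (2, heads, num_persist_mem_tokens, dim_head) tensor; in the forward pass they are broadcast over the batch and prepended to the projected keys and values before scaled dot product attention. A condensed, self-contained sketch of just that step, with made-up sizes:

import torch
import torch.nn.functional as F
from torch import nn, cat
from einops import repeat

batch, heads, seq, dim_head = 2, 8, 128, 64
num_persist_mem_tokens = 16

persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))

q = torch.randn(batch, heads, seq, dim_head)
k = torch.randn(batch, heads, seq, dim_head)
v = torch.randn(batch, heads, seq, dim_head)

# broadcast the learned key / value tokens over the batch dimension
pmk, pmv = tuple(repeat(t, 'h n d -> b h n d', b = batch) for t in persistent_memory)

# prepend along the key / value sequence axis
k = cat((pmk, k), dim = -2)
v = cat((pmv, v), dim = -2)

out = F.scaled_dot_product_attention(q, k, v, is_causal = True)
print(out.shape)   # (2, 8, 128, 64); output length follows the queries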
@@ -113,6 +139,7 @@ class MemoryAsContextTransformer(Module):
         dim,
         depth,
         segment_len,
+        num_persist_mem_tokens,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -120,6 +147,9 @@ class MemoryAsContextTransformer(Module):
     ):
         super().__init__()
 
+        self.segment_len = segment_len
+        self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
+
         self.token_emb = nn.Embedding(num_tokens, dim)
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
@@ -127,7 +157,14 @@ class MemoryAsContextTransformer(Module):
         self.layers = ModuleList([])
 
         for _ in range(depth):
-            attn = SegmentedAttention(dim = dim, dim_head = dim_head, heads = heads, segment_len = segment_len)
+            attn = SegmentedAttention(
+                dim = dim,
+                dim_head = dim_head,
+                heads = heads,
+                segment_len = segment_len,
+                num_persist_mem_tokens = num_persist_mem_tokens
+            )
+
             ff = FeedForward(dim = dim, mult = ff_mult)
 
             self.layers.append(ModuleList([
@@ -140,9 +177,19 @@ class MemoryAsContextTransformer(Module):
         self.to_logits = LinearNoBias(dim, num_tokens)
 
     def forward(self, x):
+        seq_len, segment_len = x.shape[-1], self.segment_len
+        windows = ceil(seq_len / segment_len)
 
         x = self.token_emb(x)
 
+        # apply axial positional embedding
+        # so intra and inter segment can be more easily discerned by the network
+
+        pos_emb = self.axial_pos_emb((windows, segment_len), flatten = True)
+        x = x + pos_emb[:seq_len]
+
+        # expand and reduce streams for hyper connections
+
         x = self.expand_streams(x)
 
         for attn, ff in self.layers:
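The absolute positions are laid out on a (windows, segment_len) grid and flattened, so a token's embedding reflects both which segment it sits in and its position within that segment; the flattened table is then truncated to the actual sequence length. A small isolated sketch of that call, assuming (as the forward pass above implies) that the flattened output has shape (windows * segment_len, dim):

import torch
from math import ceil
from axial_positional_embedding import ContinuousAxialPositionalEmbedding

dim, segment_len = 256, 128
axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)

seq_len = 1023
windows = ceil(seq_len / segment_len)        # 8 windows for 1023 tokens

# one embedding per (window index, position within window), flattened to 1024 rows
pos_emb = axial_pos_emb((windows, segment_len), flatten = True)

x = torch.randn(1, seq_len, dim)             # stand-in for token embeddings
x = x + pos_emb[:seq_len]                    # drop the padded tail positions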
@@ -162,6 +209,7 @@ if __name__ == '__main__':
         num_tokens = 256,
         dim = 256,
         depth = 2,
+        num_persist_mem_tokens = 16,
         segment_len = 128,
     )
 