titans-pytorch 0.0.29__tar.gz → 0.0.31__tar.gz
Potentially problematic release: this version of titans-pytorch might be problematic.
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/PKG-INFO +1 -1
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/pyproject.toml +1 -1
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/tests/test_titans.py +8 -2
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/titans_pytorch/mac_transformer.py +76 -38
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/.gitignore +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/LICENSE +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/README.md +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/data/README.md +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/fig1.png +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/fig2.png +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/requirements.txt +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/train.py +0 -0
{titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/tests/test_titans.py

@@ -34,14 +34,20 @@ def test_titans_attn_memory():
 
     assert seq.shape == retrieved.shape
 
-def test_mac():
+@pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
+@pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
+def test_mac(
+    num_persist_mem_tokens,
+    num_longterm_mem_tokens
+):
     from titans_pytorch.mac_transformer import MemoryAsContextTransformer
 
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
         dim = 256,
         depth = 2,
-        num_persist_mem_tokens =
+        num_persist_mem_tokens = num_persist_mem_tokens,
+        num_longterm_mem_tokens = num_longterm_mem_tokens,
         segment_len = 128,
     )
 
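Outside of pytest, the same constructor can be exercised directly. The sketch below mirrors the parametrized test; the vocabulary size, depth, and sequence length are illustrative choices, and the assumption that the forward pass takes a batch of token ids and returns per-token logits comes from the `to_logits` head in the mac_transformer diff below rather than from this test hunk.

```python
# minimal usage sketch mirroring the parametrized test above
# (dims, depth and sequence length are arbitrary illustration values)
import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    num_persist_mem_tokens = 16,
    num_longterm_mem_tokens = 16,   # new in 0.0.31, defaults to 0
    segment_len = 128,
)

ids = torch.randint(0, 256, (1, 1023))  # length deliberately not a multiple of segment_len
logits = transformer(ids)               # logits over the 256-token vocabulary
```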
{titans_pytorch-0.0.29 → titans_pytorch-0.0.31}/titans_pytorch/mac_transformer.py

@@ -7,10 +7,9 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
-from einops import repeat
+from einops import repeat, rearrange
 from einops.layers.torch import Rearrange
 
-
 from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # absolute and relative positions
@@ -30,9 +29,33 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def identity(t):
+    return t
+
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult
 
+def pad_and_segment_with_inverse(seq, segment_len):
+    batch, seq_len = seq.shape[:2]
+
+    need_segment = seq_len >= segment_len
+
+    if not need_segment:
+        return seq, identity
+
+    next_seq_len_mult = round_up_multiple(seq_len, segment_len)
+
+    padding = next_seq_len_mult - seq_len
+    seq = F.pad(seq, (0, 0, 0, padding))
+
+    seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
+
+    def inverse(out):
+        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+        return out[:, :-padding]
+
+    return seq, inverse
+
 # feedforward and attention
 
 class GEGLU(Module):
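The new `pad_and_segment_with_inverse` helper returns both the windowed tensor and a closure that undoes the padding and windowing. A minimal round-trip sketch, with illustrative shapes and a sequence length deliberately chosen not to be an exact multiple of the window size:

```python
# round-trip sketch for the new helper: fold windows into the batch, then undo it
import torch
from titans_pytorch.mac_transformer import pad_and_segment_with_inverse

x = torch.randn(2, 300, 64)                       # (batch, seq, dim); 300 is not a multiple of 128
segments, inverse = pad_and_segment_with_inverse(x, 128)

print(segments.shape)                             # (2 * 3, 128, 64): windows folded into the batch
print(inverse(segments).shape)                    # (2, 300, 64): padding stripped, windows merged back
```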
@@ -55,7 +78,8 @@ class SegmentedAttention(Module):
         self,
         dim,
         segment_len,
-        num_persist_mem_tokens,
+        num_persist_mem_tokens = 0,
+        num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
     ):
@@ -70,31 +94,25 @@ class SegmentedAttention(Module):
         self.to_out = LinearNoBias(dim_inner, dim)
 
         self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        total_segment_len = segment_len + num_longterm_mem_tokens
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
 
-        self.segment_seq = Rearrange('b (n w) d -> (b n) w d', n = segment_len)
-        self.merge_seq_back = Rearrange('(b n) w d -> b (n w) d', n = segment_len)
-
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
     def forward(self, seq):
+        segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
         batch, seq_len = seq.shape[:2]
 
         # auto pad to multiple
         # todo - get rid of logic with flex attention
 
-
-
-        if need_segment:
-            next_seq_len = round_up_multiple(seq_len, self.segment_len)
-            padding = next_seq_len - seq_len
-
-            if padding > 0:
-                seq = F.pad(seq, (0, 0, 0, padding))
-
-            seq = self.segment_seq(seq)
+        seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len)
 
         # attention
 
@@ -124,10 +142,9 @@ class SegmentedAttention(Module):
 
         out = self.to_out(out)
 
-
-        out = self.merge_seq_back(out)
+        out = inverse_segment(out)
 
-        return out
+        return out
 
 # MAC transformer
 
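With both memory-token counts now defaulting to 0, `SegmentedAttention` can be constructed with or without them. A construction-only sketch (the dimensions are illustrative and the forward pass is not exercised here):

```python
# construction sketch: both memory-token counts are now optional keyword arguments
from titans_pytorch.mac_transformer import SegmentedAttention

plain = SegmentedAttention(dim = 512, segment_len = 128)   # no persistent or long-term memory tokens

with_mem = SegmentedAttention(
    dim = 512,
    segment_len = 128,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,   # attention now windows over segment_len + 16 = 144 tokens
)
```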
@@ -139,7 +156,8 @@ class MemoryAsContextTransformer(Module):
         dim,
         depth,
         segment_len,
-
+        num_longterm_mem_tokens = 0,
+        num_persist_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
@@ -147,10 +165,18 @@ class MemoryAsContextTransformer(Module):
     ):
         super().__init__()
 
-        self.
+        self.token_emb = nn.Embedding(num_tokens, dim)
+
         self.axial_pos_emb = ContinuousAxialPositionalEmbedding(dim = dim, num_axial_dims = 2)
 
-
+        # long term mem tokens
+
+        self.segment_len = segment_len
+        self.num_longterm_mem_tokens = num_longterm_mem_tokens
+
+        self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
+
+        # hyper conection
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
@@ -162,6 +188,7 @@ class MemoryAsContextTransformer(Module):
                 dim_head = dim_head,
                 heads = heads,
                 segment_len = segment_len,
+                num_longterm_mem_tokens = num_longterm_mem_tokens,
                 num_persist_mem_tokens = num_persist_mem_tokens
             )
 
@@ -177,16 +204,32 @@ class MemoryAsContextTransformer(Module):
         self.to_logits = LinearNoBias(dim, num_tokens)
 
     def forward(self, x):
-
+
+        # math
+
+        batch, seq_len, segment_len, num_longterm_mem_tokens = *x.shape, self.segment_len, self.num_longterm_mem_tokens
+
         windows = ceil(seq_len / segment_len)
+        total_segment_len = segment_len + num_longterm_mem_tokens
+
+        # token embedding
 
         x = self.token_emb(x)
 
+        # intersperse longterm memory
+
+        x, inverse_segment = pad_and_segment_with_inverse(x, segment_len)
+
+        mems = repeat(self.longterm_mems, 'n d -> b n d', b = x.shape[0])
+        x = torch.cat((mems, x), dim = -2)
+
+        x = inverse_segment(x)
+
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
-        pos_emb = self.axial_pos_emb((windows,
-        x = x + pos_emb[:
+        pos_emb = self.axial_pos_emb((windows, total_segment_len), flatten = True)
+        x = x + pos_emb[:x.shape[-2]]
 
         # expand and reduce streams for hyper connections
 
@@ -198,21 +241,16 @@ class MemoryAsContextTransformer(Module):
 
         x = self.reduce_streams(x)
 
-
+        # excise out the memories
 
-
+        x, inverse_segment = pad_and_segment_with_inverse(x, total_segment_len)
 
-
+        x = x[:, self.num_longterm_mem_tokens:]
 
-
-
-
-        dim = 256,
-        depth = 2,
-        num_persist_mem_tokens = 16,
-        segment_len = 128,
-    )
+        x = inverse_segment(x)
+
+        # to logits
 
-
+        x = self.norm(x)
 
-
+        return self.to_logits(x)
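Taken together, the two new forward-pass steps prepend the same learned memory tokens to every window before the layers and strip them out again afterwards. The standalone einops sketch below illustrates that round trip; all names and shapes are illustrative, not the library's API.

```python
# standalone sketch of the intersperse / excise round trip done in forward()
import torch
from einops import rearrange, repeat

batch, seq_len, dim   = 2, 256, 64      # seq_len kept an exact multiple of segment_len for brevity
segment_len, num_mems = 128, 4

x    = torch.randn(batch, seq_len, dim)
mems = torch.randn(num_mems, dim)        # stands in for the learned self.longterm_mems parameter

# intersperse: prepend the memory tokens to every window
w = rearrange(x, 'b (w n) d -> (b w) n d', n = segment_len)
w = torch.cat((repeat(mems, 'n d -> b n d', b = w.shape[0]), w), dim = -2)
x = rearrange(w, '(b w) n d -> b (w n) d', b = batch)                 # (2, 2 * (4 + 128), 64)

# ... attention layers would run here, on windows of num_mems + segment_len tokens ...

# excise: segment on the enlarged window and drop its leading memory tokens
w = rearrange(x, 'b (w n) d -> (b w) n d', n = num_mems + segment_len)
x = rearrange(w[:, num_mems:], '(b w) n d -> b (w n) d', b = batch)

print(x.shape)   # back to (2, 256, 64)
```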
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|