titans-pytorch 0.0.51__tar.gz → 0.0.52__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.52}/PKG-INFO +1 -1
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.52}/pyproject.toml +1 -1
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.52}/titans_pytorch/mac_transformer.py +71 -2
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.52}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.52}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.52}/.gitignore +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.52}/LICENSE +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.52}/README.md +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.52}/data/README.md +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.52}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.52}/fig1.png +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.52}/fig2.png +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.52}/tests/test_titans.py +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.52}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.52}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.52}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.0.51 → titans_pytorch-0.0.52}/train_mac.py +0 -0
@@ -128,6 +128,7 @@ class SegmentedAttention(Module):
         heads = 8,
         accept_value_residual = False,
         attend_kwargs: dict = dict(),
+        use_flex_attn = False
     ):
         super().__init__()
         self.norm = nn.RMSNorm(dim)
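The new `use_flex_attn` flag is only valid when `flex_attention` was importable, as the assertion in `__init__` below shows. For context, here is a minimal sketch of how such a guarded import is commonly written; it is not copied from `mac_transformer.py`, and the actual import logic in the package may differ (the `torch.nn.attention.flex_attention` module ships with recent PyTorch releases).

import torch

# assumed guard, not taken from the diff:
# flex_attention stays None unless a new-enough PyTorch with CUDA is available
flex_attention = None

if torch.cuda.is_available():
    try:
        from torch.nn.attention.flex_attention import flex_attention
    except ImportError:
        pass

def exists(v):
    return v is not None

# the constructor can then validate the flag exactly as the diff does:
# assert not (use_flex_attn and not exists(flex_attention)), '...'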
@@ -157,11 +158,79 @@ class SegmentedAttention(Module):
 
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
+        # flex attn related
+
+        assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
+        self.use_flex_attn = use_flex_attn
+
+        self.segment_len = segment_len
+        self.num_persist_mem_tokens = num_persist_mem_tokens
+
+    def forward_flex(
+        self,
+        seq,
+        value_residual = None,
+        flex_attn_fn: Callable | None = None
+    ):
+
+        assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
+
+        batch, seq_len = seq.shape[:2]
+
+        # attention
+
+        seq = self.norm(seq)
+
+        q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
+        q, k, v = map(self.split_heads, (q, k, v))
+
+        # value residual
+
+        orig_v = v
+
+        if exists(self.to_learned_v_mix):
+            mix = self.to_learned_v_mix(seq)
+            v = v.lerp(value_residual, mix)
+
+        # take care of persistent memory key / values
+
+        pmk, pmv = repeat(self.persistent_memory, 'kv h n d -> kv b h n d', b = batch)
+
+        # relative positions
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # persistent memory
+
+        k = cat((pmk, k), dim = -2)
+        v = cat((pmv, v), dim = -2)
+
+        # prep flex attention
+
+        if not exists(flex_attn_fn):
+            block_mask = create_mac_block_mask(seq_len, self.segment_len, self.num_persist_mem_tokens)
+
+            flex_attn_fn = partial(flex_attention, block_mask = block_mask)
+
+        # attention
+
+        out = flex_attn_fn(q, k, v)
+
+        out = self.merge_heads(out)
+
+        out = self.to_out(out)
+
+        return out, orig_v
+
     def forward(
         self,
         seq,
-        value_residual = None
+        value_residual = None,
+        flex_attn_fn: Callable | None = None
     ):
+        if seq.is_cuda and self.use_flex_attn:
+            return self.forward_flex(seq, value_residual, flex_attn_fn)
+
         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
 
         segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
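As a usage sketch of the new flag: `forward()` now routes to `forward_flex()` whenever the input tensor lives on CUDA and `use_flex_attn` was enabled at construction. The example below is not part of the diff; constructor keywords beyond those visible in the hunks above (`dim`, `dim_head`, `heads`, `segment_len`, `num_persist_mem_tokens`, `use_flex_attn`) and the concrete values are assumptions.

import torch
from titans_pytorch.mac_transformer import SegmentedAttention

# hypothetical hyperparameters; only the keyword names shown in the diff are known
attn = SegmentedAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    segment_len = 32,
    num_persist_mem_tokens = 4,
    use_flex_attn = True   # new in 0.0.52
).cuda()

seq = torch.randn(1, 1024, 512).cuda()

# dispatches to forward_flex() because the input is on CUDA and the flag is set;
# both code paths return (out, values) per the diff
out, values = attn(seq)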
@@ -191,7 +260,7 @@ class SegmentedAttention(Module):
 
         # take care of persistent memory key / values
 
-        pmk, pmv =
+        pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = seq.shape[0])
 
         # relative positions
 
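The replacement einops pattern 'kv ... -> kv b ...' broadcasts the persistent memory over the batch without hard-coding the number of trailing dimensions. A small standalone illustration (not from the package):

import torch
from einops import repeat

persistent_memory = torch.zeros(2, 8, 4, 64)       # (kv, heads, mem tokens, dim_head)

# insert a batch axis of size 3 right after the leading 'kv' axis,
# leaving all trailing dimensions untouched
batched = repeat(persistent_memory, 'kv ... -> kv b ...', b = 3)

pmk, pmv = batched                                  # unpack along the kv axis
print(pmk.shape)                                    # torch.Size([3, 8, 4, 64])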