titans-pytorch 0.0.63__py3-none-any.whl → 0.0.65__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +77 -20
- {titans_pytorch-0.0.63.dist-info → titans_pytorch-0.0.65.dist-info}/METADATA +1 -1
- titans_pytorch-0.0.65.dist-info/RECORD +8 -0
- titans_pytorch-0.0.63.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.63.dist-info → titans_pytorch-0.0.65.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.63.dist-info → titans_pytorch-0.0.65.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py

@@ -20,13 +20,21 @@ try:
 except ImportError:
     pass
 
-def create_mac_block_mask(seq_len, window_size, persist_mem_len):
+def create_mac_block_mask(seq_len, window_size, persist_mem_len, sliding = False):
 
-    def create_mac_mask(
+    def create_mac_mask(_, __, q_idx, kv_idx):
         is_persist_mem = kv_idx < persist_mem_len
-
-
-
+        kv_without_mem = kv_idx - persist_mem_len
+        causal_mask = q_idx >= kv_without_mem
+
+        if not sliding:
+            block_diagonal = (q_idx // window_size) == (kv_without_mem // window_size)
+            causal_mask = causal_mask & block_diagonal
+        else:
+            sliding_mask = (q_idx - kv_without_mem) <= window_size
+            causal_mask = causal_mask & sliding_mask
+
+        return is_persist_mem | (~is_persist_mem & causal_mask)
 
     block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
     return block_mask
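
Note: the new mask predicate can be sanity-checked outside of flex attention by evaluating it densely. The sketch below is illustrative only, it does not call create_block_mask, and simply mirrors the logic added above with broadcasted index tensors.

import torch

def dense_mac_mask(seq_len, window_size, persist_mem_len, sliding = False):
    # queries index the sequence; keys additionally include persistent memory slots at the front
    q_idx = torch.arange(seq_len)[:, None]
    kv_idx = torch.arange(seq_len + persist_mem_len)[None, :]

    is_persist_mem = kv_idx < persist_mem_len
    kv_without_mem = kv_idx - persist_mem_len
    causal_mask = q_idx >= kv_without_mem

    if not sliding:
        # block-diagonal: each query only sees keys within its own window
        block_diagonal = (q_idx // window_size) == (kv_without_mem // window_size)
        causal_mask = causal_mask & block_diagonal
    else:
        # sliding: each query sees at most `window_size` preceding keys
        sliding_mask = (q_idx - kv_without_mem) <= window_size
        causal_mask = causal_mask & sliding_mask

    return is_persist_mem | (~is_persist_mem & causal_mask)

# e.g. dense_mac_mask(8, 4, 2, sliding = True) returns a bool tensor of shape (8, 10)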
@@ -73,7 +81,12 @@ def identity(t):
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult
 
-def
+def pad_at_dim(t, pad, dim = -1, value = 0.):
+    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
+    zeros = ((0, 0) * dims_from_right)
+    return F.pad(t, (*zeros, *pad), value = value)
+
+def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
     batch, seq_len = seq.shape[:2]
 
     need_segment = seq_len >= segment_len
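
Note: pad_at_dim builds the flat pad tuple that F.pad expects (last dimension first), so a pad of (1, 0) at dim = 1 grows that dimension by one on the left. A quick illustrative check, assuming a 3-D tensor:

import torch
import torch.nn.functional as F

def pad_at_dim(t, pad, dim = -1, value = 0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

t = torch.randn(2, 3, 4)
assert pad_at_dim(t, (1, 0), dim = 1).shape == (2, 4, 4)  # dim 1 padded by one on the left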
@@ -89,13 +102,15 @@ def pad_and_segment_with_inverse(seq, segment_len):
     if needs_pad:
         seq = F.pad(seq, (0, 0, 0, padding))
 
-
+    if fold_into_batch:
+        seq = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)
 
     def inverse(out):
-
+        if fold_into_batch:
+            out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
 
         if needs_pad:
-            out = out[
+            out = out[..., :-padding, :]
 
         return out
 
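
Note: with fold_into_batch = True the segments are folded into the batch dimension and the returned closure undoes both the fold and the padding; with fold_into_batch = False (used by the sliding path further down) the sequence is only padded to a multiple of the segment length. A minimal sketch of the fold/unfold round trip, assuming einops is installed:

import torch
from einops import rearrange

batch, segment_len = 2, 4
seq = torch.randn(batch, 3 * segment_len, 8)

folded = rearrange(seq, 'b (w n) d -> (b w) n d', n = segment_len)   # (6, 4, 8)
unfolded = rearrange(folded, '(b w) n d -> b (w n) d', b = batch)    # (2, 12, 8)

assert torch.equal(seq, unfolded)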
@@ -127,6 +142,7 @@ class SegmentedAttention(Module):
         num_longterm_mem_tokens = 0,
         dim_head = 64,
         heads = 8,
+        sliding = False,
         accept_value_residual = False,
         attend_kwargs: dict = dict(),
         use_flex_attn = False
@@ -153,6 +169,9 @@ class SegmentedAttention(Module):
         self.num_longterm_mem_tokens = num_longterm_mem_tokens
 
         total_segment_len = segment_len + num_longterm_mem_tokens
+        self.total_segment_len = total_segment_len
+
+        self.sliding = sliding # sliding window attn - doubt their non-sliding results being the best. local attention with overlapping windows is very strong
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
         self.merge_heads = Rearrange('b h n d -> b n (h d)')
@@ -209,7 +228,7 @@ class SegmentedAttention(Module):
         # prep flex attention
 
         if not exists(flex_attn_fn):
-            block_mask = create_mac_block_mask(seq_len, self.
+            block_mask = create_mac_block_mask(seq_len, self.total_segment_len, self.num_persist_mem_tokens, self.sliding)
 
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)
 
@@ -241,9 +260,8 @@ class SegmentedAttention(Module):
         batch, seq_len = seq.shape[:2]
 
         # auto pad to multiple
-        # todo - get rid of logic with flex attention
 
-        seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len)
+        seq, inverse_segment = pad_and_segment_with_inverse(seq, total_segment_len, fold_into_batch = False)
 
         # attention
 
@@ -260,14 +278,45 @@ class SegmentedAttention(Module):
             mix = self.to_learned_v_mix(seq)
             v = v.lerp(value_residual, mix)
 
-        # take care of persistent memory key / values
-
-        pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = seq.shape[0])
-
         # relative positions
 
         q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
 
+        # fold
+
+        q, k, v = tuple(rearrange(t, 'b h (w n) d -> (b w) h n d', n = total_segment_len) for t in (q, k, v))
+
+        # maybe sliding for cpu
+
+        attend_kwargs = dict()
+
+        if self.sliding:
+            k, v = tuple(rearrange(t, '(b w) ... -> b w ...', b = batch) for t in (k, v))
+            k, v = tuple(pad_at_dim(t, (1, 0), value = 0., dim = 1) for t in (k, v))
+            k = cat((k[:, :-1], k[:, 1:]), dim = -2)
+            v = cat((v[:, :-1], v[:, 1:]), dim = -2)
+            k, v = tuple(rearrange(t, 'b w ... -> (b w) ...') for t in (k, v))
+
+            # take care of masking
+
+            idx = torch.arange(seq.shape[-2], device = seq.device)
+            q_idx = rearrange(idx, '(w n) -> w n', n = total_segment_len)
+            k_idx = pad_at_dim(q_idx, (1, 0), dim = 0, value = -1e4)
+            k_idx = cat((k_idx[:-1], k_idx[1:]), dim = -1)
+
+            q_idx = rearrange(q_idx, 'w i -> w i 1')
+            k_idx = rearrange(k_idx, 'w j -> w 1 j')
+
+            sliding_mask = (q_idx - k_idx) <= total_segment_len
+            sliding_mask = F.pad(sliding_mask, (self.num_persist_mem_tokens, 0), value = True)
+
+            sliding_mask = repeat(sliding_mask, 'w i j -> (b w) 1 i j', b = batch)
+            attend_kwargs.update(mask = sliding_mask)
+
+        # take care of persistent memory key / values
+
+        pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = k.shape[0])
+
         # persistent memory
 
         k = cat((pmk, k), dim = -2)
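
Note: in the non-flex path, sliding window attention lets every window attend to itself plus the previous window. Keys and values are regrouped per window, shifted by one with pad_at_dim, and concatenated along the sequence dimension, and the boolean mask then limits each query to at most total_segment_len previous positions. A shape-only sketch of the key widening, with hypothetical sizes:

import torch
import torch.nn.functional as F
from einops import rearrange

def pad_at_dim(t, pad, dim = -1, value = 0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

batch, windows, heads, n, d = 2, 3, 4, 8, 16          # hypothetical sizes
k = torch.randn(batch * windows, heads, n, d)          # keys already folded into the batch

k = rearrange(k, '(b w) ... -> b w ...', b = batch)    # (2, 3, 4, 8, 16)
k = pad_at_dim(k, (1, 0), value = 0., dim = 1)         # prepend an empty window -> (2, 4, 4, 8, 16)
k = torch.cat((k[:, :-1], k[:, 1:]), dim = -2)         # pair each window with its predecessor -> (2, 3, 4, 16, 16)
k = rearrange(k, 'b w ... -> (b w) ...')               # (6, 4, 16, 16)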
@@ -275,12 +324,14 @@ class SegmentedAttention(Module):
 
         # attention
 
-        out, _ = self.attend(q, k, v)
+        out, _ = self.attend(q, k, v, **attend_kwargs)
 
         out = self.merge_heads(out)
 
         out = self.to_out(out)
 
+        out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+
         out = inverse_segment(out)
 
         return out, orig_v
@@ -349,7 +400,8 @@ class MemoryAsContextTransformer(Module):
         neural_memory_kwargs: dict = dict(),
         neural_memory_layers: tuple[int, ...] | None = None,
         aux_kv_recon_loss_weight = 0.,
-        use_flex_attn = False
+        use_flex_attn = False,
+        sliding_window_attn = False
     ):
         super().__init__()
 
@@ -366,6 +418,10 @@ class MemoryAsContextTransformer(Module):
 
         self.longterm_mems = nn.Parameter(torch.randn(num_longterm_mem_tokens, dim) * 0.02)
 
+        # maybe sliding window attn
+
+        self.sliding_window_attn = sliding_window_attn
+
         # hyper conection
 
         init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
@@ -396,7 +452,8 @@ class MemoryAsContextTransformer(Module):
                 use_flex_attn = use_flex_attn,
                 accept_value_residual = not is_first,
                 num_longterm_mem_tokens = num_longterm_mem_tokens,
-                num_persist_mem_tokens = num_persist_mem_tokens
+                num_persist_mem_tokens = num_persist_mem_tokens,
+                sliding = sliding_window_attn
             )
 
             mem = None
@@ -489,7 +546,7 @@ class MemoryAsContextTransformer(Module):
         flex_attn_fn = None
 
         if use_flex_attn:
-            block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens)
+            block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens, self.sliding_window_attn)
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)
 
         # value residual
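
Note: taken together, the new flag threads from MemoryAsContextTransformer down to each SegmentedAttention block and, when flex attention is used, into create_mac_block_mask. A hedged usage sketch follows; constructor arguments other than the ones visible in this diff (num_tokens, dim, depth, segment_len and the value of return_loss) are assumptions about the package API, not something this diff shows.

import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,                  # assumed
    dim = 256,                         # assumed
    depth = 2,                         # assumed
    segment_len = 32,                  # assumed
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 4,
    sliding_window_attn = True         # new in 0.0.65: overlapping local windows instead of block-diagonal attention
)

ids = torch.randint(0, 256, (1, 512))
loss = transformer(ids, return_loss = True)   # return_loss is also an assumption
loss.backward()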
titans_pytorch-0.0.65.dist-info/RECORD

@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=7PHBCbeB1LhHY5s3zAyYF0L3Mm7CNy4TOBbcpLX6bNE,17686
+titans_pytorch/titans.py,sha256=y6lJRErIoM6T2aTVFlf1GxSB0cpsmBZdSIj1DCHUCQ8,17486
+titans_pytorch-0.0.65.dist-info/METADATA,sha256=oDjEiufwOninsFDoCGbu691LXc1mey2OT7j6PNzkz0Q,4457
+titans_pytorch-0.0.65.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.65.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.65.dist-info/RECORD,,
titans_pytorch-0.0.63.dist-info/RECORD

@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=I-4oF1JPEmUvhLszEocM5cDgHYNFcNp0Q9nbDYSPFqU,195
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=R9Xu-BjqAe9ZY60IGk4aNXBx_L8THsjJ4QrkbTnRNHo,15346
-titans_pytorch/titans.py,sha256=y6lJRErIoM6T2aTVFlf1GxSB0cpsmBZdSIj1DCHUCQ8,17486
-titans_pytorch-0.0.63.dist-info/METADATA,sha256=-CImQ-4hVNDFWczTb0V1dWL0QkHS-1c6XyntI1ULrms,4457
-titans_pytorch-0.0.63.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.63.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.63.dist-info/RECORD,,

{titans_pytorch-0.0.63.dist-info → titans_pytorch-0.0.65.dist-info}/WHEEL: File without changes
{titans_pytorch-0.0.63.dist-info → titans_pytorch-0.0.65.dist-info}/licenses/LICENSE: File without changes