titans-pytorch 0.1.17__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +152 -64
- {titans_pytorch-0.1.17.dist-info → titans_pytorch-0.1.20.dist-info}/METADATA +2 -2
- titans_pytorch-0.1.20.dist-info/RECORD +8 -0
- titans_pytorch-0.1.17.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.17.dist-info → titans_pytorch-0.1.20.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.17.dist-info → titans_pytorch-0.1.20.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py

```diff
@@ -1,12 +1,14 @@
 from __future__ import annotations
 from typing import Callable
+
 from math import ceil
 from functools import partial
+from collections import namedtuple

 import tqdm

 import torch
-from torch import nn, cat
+from torch import nn, stack, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear

@@ -69,6 +71,8 @@ from titans_pytorch.titans import NeuralMemory

 LinearNoBias = partial(Linear, bias = False)

+AttnIntermediates = namedtuple('AttnIntermediates', ('value_residual', 'cached_key_values'))
+
 # helpers

 def exists(v):
@@ -80,6 +84,9 @@ def default(v, d):
 def identity(t):
     return t

+def divisible_by(num, den):
+    return (num % den) == 0
+
 def round_up_multiple(seq, mult):
     return ceil(seq / mult) * mult

```
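The two new helpers above are small but load-bearing for the rest of the diff: `AttnIntermediates` is the bundle every attention forward now returns next to its output, and `divisible_by` is used later to decide when the kv cache can be dropped at a window boundary. A minimal sketch of the namedtuple's shape, using hypothetical tensor sizes:

```python
from collections import namedtuple

import torch

# same definition as in the diff above
AttnIntermediates = namedtuple('AttnIntermediates', ('value_residual', 'cached_key_values'))

# hypothetical shapes: (batch, heads, seq, dim_head)
values = torch.randn(1, 8, 4, 64)
kv_cache = (torch.randn(1, 8, 4, 64), torch.randn(1, 8, 4, 64))

intermediates = AttnIntermediates(values, kv_cache)

# callers can unpack positionally, as the transformer's layer loop later does ...
value_residual, cached_key_values = intermediates

# ... or read the fields by name
assert intermediates.value_residual is values
assert intermediates.cached_key_values is kv_cache
```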
```diff
@@ -111,7 +118,7 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):

     def inverse(out):
         if fold_into_batch:
-            out = rearrange(out, '(b w) n d -> b (w n) d', b = batch)
+            out = rearrange(out, '(b w) ... n d -> b ... (w n) d', b = batch)

         if needs_pad:
             out = out[..., :-padding, :]
```
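The only change to `pad_and_segment_with_inverse` is the `...` in the inverse rearrange pattern, which lets the un-folding step pass any extra leading dimensions (for example a head dimension) through untouched instead of assuming a plain `(batch * windows, n, d)` tensor. A small sketch of the difference, with made-up sizes:

```python
import torch
from einops import rearrange

b, w, n, d, h = 2, 3, 4, 8, 5

# old pattern: only handles (batch * windows, n, d)
folded = torch.randn(b * w, n, d)
unfolded = rearrange(folded, '(b w) n d -> b (w n) d', b = b)
assert unfolded.shape == (b, w * n, d)

# new pattern: the `...` carries any extra dims (e.g. a head dim) through unchanged
folded_with_heads = torch.randn(b * w, h, n, d)
unfolded_with_heads = rearrange(folded_with_heads, '(b w) ... n d -> b ... (w n) d', b = b)
assert unfolded_with_heads.shape == (b, h, w * n, d)
```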
```diff
@@ -213,11 +220,75 @@ class SegmentedAttention(Module):
         self.segment_len = segment_len
         self.num_persist_mem_tokens = num_persist_mem_tokens

+    def forward_inference(
+        self,
+        token,
+        cache,
+        value_residual = None,
+        output_gating = None,
+    ):
+        batch = token.shape[0]
+
+        # attention
+
+        token = self.norm(token)
+
+        q, k, v = self.to_qkv(token).chunk(3, dim = -1)
+        q, k, v = map(self.split_heads, (q, k, v))
+
+        # value residual
+
+        orig_v = v
+
+        if exists(self.to_learned_v_mix):
+            mix = self.to_learned_v_mix(token)
+            v = v.lerp(value_residual, mix)
+
+        # caching
+
+        ck, cv = cache
+        k = cat((ck, k), dim = -2)
+        v = cat((cv, v), dim = -2)
+
+        next_cache = (k, v)
+
+        # relative positions
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # fold
+
+        q, k, v = tuple(rearrange(t, 'b h n d -> b h n d') for t in (q, k, v))
+
+        # take care of persistent memory key / values
+
+        pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = k.shape[0])
+
+        # persistent memory
+
+        k = cat((pmk, k), dim = -2)
+        v = cat((pmv, v), dim = -2)
+
+        # attention
+
+        out, _ = self.attend(q, k, v)
+
+        out = self.merge_heads(out)
+
+        out = self.to_out(out)
+
+        if exists(output_gating):
+            out = out * output_gating
+
+        return out, AttnIntermediates(orig_v, next_cache)
+
     def forward_flex(
         self,
         seq,
         value_residual = None,
-        flex_attn_fn: Callable | None = None
+        flex_attn_fn: Callable | None = None,
+        output_gating = None,
+        cache = None
     ):

         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
```
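The new `forward_inference` path is a standard single-token decoding step: normalize the incoming token, project to q/k/v, append the new key/value pair to the running cache, then attend over the cached sequence plus persistent memory. The cache convention is simply a `(keys, values)` pair that grows along `dim = -2`; a minimal sketch of that bookkeeping with hypothetical shapes (the real path also mixes in the value residual, rotary positions and persistent memory):

```python
import torch
from torch import cat

b, h, d = 1, 4, 16

# empty cache before the first decoded token
k_cache = torch.zeros(b, h, 0, d)
v_cache = torch.zeros(b, h, 0, d)

for step in range(3):
    # one new token's worth of keys / values, as produced from a single input token
    k_new = torch.randn(b, h, 1, d)
    v_new = torch.randn(b, h, 1, d)

    # same concatenation as in forward_inference: the cache grows along the sequence axis
    k_cache = cat((k_cache, k_new), dim = -2)
    v_cache = cat((v_cache, v_new), dim = -2)

    next_cache = (k_cache, v_cache)  # handed back as AttnIntermediates.cached_key_values

assert k_cache.shape == (b, h, 3, d)
```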
```diff
@@ -239,6 +310,10 @@ class SegmentedAttention(Module):
             mix = self.to_learned_v_mix(seq)
             v = v.lerp(value_residual, mix)

+        # caching
+
+        next_cache = tuple(map(inverse_segment, (k, v)))
+
         # take care of persistent memory key / values

         pmk, pmv = repeat(self.persistent_memory, 'kv h n d -> kv b h n d', b = batch)
@@ -267,17 +342,28 @@ class SegmentedAttention(Module):

         out = self.to_out(out)

-
+        if exists(output_gating):
+            out = out * output_gating
+
+        return out, AttnIntermediates(orig_v, next_cache)

     def forward(
         self,
         seq,
         value_residual = None,
         flex_attn_fn: Callable | None = None,
-        disable_flex_attn = False
+        disable_flex_attn = False,
+        output_gating = None,
+        cache = None
     ):
+        is_inferencing = exists(cache)
+
+        if is_inferencing:
+            assert seq.shape[-2] == 1
+            return self.forward_inference(seq, cache, value_residual, output_gating = output_gating)
+
         if seq.is_cuda and self.use_flex_attn and not disable_flex_attn:
-            return self.forward_flex(seq, value_residual, flex_attn_fn)
+            return self.forward_flex(seq, value_residual, flex_attn_fn, output_gating = output_gating, cache = cache)

         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))

@@ -305,6 +391,10 @@ class SegmentedAttention(Module):
             mix = self.to_learned_v_mix(seq)
             v = v.lerp(value_residual, mix)

+        # caching
+
+        next_cache = tuple(map(inverse_segment, (k, v)))
+
         # relative positions

         q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
@@ -361,50 +451,10 @@ class SegmentedAttention(Module):

         out = inverse_segment(out)

-
-
-# Attention + Neural Memory gating configuration, as depicted in Figure 2
-
-class NeuralMemoryGatingWrapper(Module):
-    def __init__(
-        self,
-        dim,
-        attn: SegmentedAttention,
-        neural_mem: NeuralMemory | None = None,
-        gate_attn_output = True
-    ):
-        super().__init__()
-        self.attn = attn
-        self.neural_mem = neural_mem
-        self.gate_attn_output = gate_attn_output
-
-    def forward(
-        self,
-        seq,
-        *args,
-        **kwargs
-    ):
-        batch, seq_len = seq.shape[:2]
-        mem = self.neural_mem
-
-        if not exists(mem):
-            return self.attn(seq, *args, **kwargs), 0.
-
-        # initial retrieve, still should store first, it doesn't make sense not to, unless if all layers share the same neural memory
-
-        retrieved, kv_aux_loss = mem(seq, return_aux_kv_loss = True)
-
-        if not self.gate_attn_output:
-            seq = seq + retrieved
+        if exists(output_gating):
+            out = out * output_gating

-
-
-        attn_out, values = self.attn(seq, *args, **kwargs)
-
-        if self.gate_attn_output:
-            attn_out = attn_out * retrieved.sigmoid()
-
-        return (attn_out, values), kv_aux_loss
+        return out, AttnIntermediates(orig_v, next_cache)

 # MAC transformer

```
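The removed `NeuralMemoryGatingWrapper` is not lost functionality: its retrieve-then-gate logic moves into the transformer's layer loop (further down in this diff), and the gate itself is now applied inside `SegmentedAttention` through the new `output_gating` argument. The arithmetic is unchanged; a quick sketch with hypothetical tensors:

```python
import torch

attn_out  = torch.randn(2, 10, 512)   # attention branch output
retrieved = torch.randn(2, 10, 512)   # neural memory read-out (hypothetical values)

# 0.1.17: the wrapper gated the attention output after the fact
gated_old = attn_out * retrieved.sigmoid()

# 0.1.20: the transformer computes the gates up front and passes them in,
# and the attention module ends with `out = out * output_gating`
attn_out_gates = retrieved.sigmoid()
gated_new = attn_out * attn_out_gates

assert torch.allclose(gated_old, gated_new)
```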
```diff
@@ -448,6 +498,7 @@ class MemoryAsContextTransformer(Module):
         # maybe sliding window attn

         self.sliding_window_attn = sliding_window_attn
+        self.attn_window_size = segment_len + num_longterm_mem_tokens

         # hyper conection

@@ -494,16 +545,10 @@ class MemoryAsContextTransformer(Module):
                    **neural_memory_kwargs
                )

-            attn = NeuralMemoryGatingWrapper(
-                dim,
-                attn = attn,
-                neural_mem = mem,
-                gate_attn_output = neural_mem_gate_attn_output
-            )
-
             ff = FeedForward(dim = dim, mult = ff_mult)

             self.layers.append(ModuleList([
+                init_hyper_conn(dim = dim, branch = mem, add_branch_out_to_residual = not neural_mem_gate_attn_output) if exists(mem) else None,
                 init_hyper_conn(dim = dim, branch = attn),
                 init_hyper_conn(dim = dim, branch = ff)
             ]))
@@ -512,6 +557,10 @@ class MemoryAsContextTransformer(Module):

         self.to_logits = LinearNoBias(dim, num_tokens)

+        # whether to gate the attention output with the retrieved memories
+
+        self.gate_attn_output = neural_mem_gate_attn_output
+
         # auxiliary loss on kv recon

         self.has_aux_kv_recon_loss = aux_kv_recon_loss_weight > 0.
@@ -524,7 +573,6 @@ class MemoryAsContextTransformer(Module):
         assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
         self.use_flex_attn = use_flex_attn

-        self.segment_len = segment_len
         self.num_persist_mem_tokens = num_persist_mem_tokens

     @torch.no_grad()
```
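With the wrapper gone, each entry of `self.layers` becomes a three-slot `ModuleList`: an optional hyper-connection branch around the neural memory (or `None` when the layer has no memory), then the attention branch, then the feedforward branch, unpacked later as `for mem, attn, ff in self.layers`. A toy sketch of that bookkeeping, with stand-in `Identity` branches in place of the real `init_hyper_conn(...)` wrappers from the hyper-connections package:

```python
from torch.nn import Identity, ModuleList

# stand-ins for the real hyper-connection branches
layers = ModuleList([
    ModuleList([None, Identity(), Identity()]),        # layer without neural memory
    ModuleList([Identity(), Identity(), Identity()]),  # layer with neural memory
])

for mem, attn, ff in layers:
    if mem is not None:     # the diff spells this exists(mem)
        pass                # retrieve from memory, build attention output gates
    # attention then feedforward follow
```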
```diff
@@ -606,7 +654,7 @@ class MemoryAsContextTransformer(Module):

         # math

-        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens
+        batch, seq_len, neural_mem_segment_len, segment_len, num_longterm_mem_tokens, attn_window_size = *x.shape, self.neural_memory_segment_len, self.segment_len, self.num_longterm_mem_tokens, self.attn_window_size

         # token embedding

@@ -640,6 +688,12 @@ class MemoryAsContextTransformer(Module):
             block_mask = create_mac_block_mask(seq_len_with_mem, segment_len + num_longterm_mem_tokens, self.num_persist_mem_tokens, self.sliding_window_attn)
             flex_attn_fn = partial(flex_attention, block_mask = block_mask)

+        # kv caching
+
+        is_inferencing = exists(cache)
+        cache = iter(default(cache, []))
+        next_kv_caches = []
+
         # value residual

         value_residual = None
@@ -648,23 +702,48 @@ class MemoryAsContextTransformer(Module):

         kv_recon_losses = self.zero

+        # when inferencing, only do one token at a time
+
+        if is_inferencing:
+            x = x[:, -1:]
+
         # expand and reduce streams for hyper connections

         x = self.expand_streams(x)

-        for attn, ff in self.layers:
+        for mem, attn, ff in self.layers:

-
+            retrieved = None
+            attn_out_gates = None
+
+            # maybe neural memory
+
+            if exists(mem):
+                retrieved, mem_kv_aux_loss = mem(x, return_aux_kv_loss = True)
+                kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
+
+                if self.gate_attn_output:
+                    attn_out_gates = retrieved.sigmoid()
+                else:
+                    seq = retrieved
+
+            # attention
+
+            x, (values, next_kv_cache) = attn(
                 x,
                 value_residual = value_residual,
                 disable_flex_attn = disable_flex_attn,
-                flex_attn_fn = flex_attn_fn
+                flex_attn_fn = flex_attn_fn,
+                output_gating = attn_out_gates,
+                cache = next(cache, None)
             )

-            kv_recon_losses = kv_recon_losses + maybe_mem_kv_aux_loss
-
             value_residual = default(value_residual, values)

+            next_kv_caches.append(next_kv_cache)
+
+            # feedforward
+
             x = ff(x)

         x = self.reduce_streams(x)
@@ -687,7 +766,16 @@ class MemoryAsContextTransformer(Module):
         if not return_cache:
             return logits

-
+        next_kv_caches = stack([stack(kv_cache) for kv_cache in next_kv_caches])
+
+        # handle kv cache length depending on local attention type
+
+        next_kv_caches = next_kv_caches[..., -attn_window_size:, :]
+
+        if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
+            next_kv_caches = next_kv_caches[..., 0:0, :]
+
+        return logits, next_kv_caches

         ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)

```
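Taken together, these forward changes give `MemoryAsContextTransformer` a per-layer kv cache for incremental decoding: each layer's cache is collected into `next_kv_caches`, stacked, trimmed to the last `attn_window_size` positions, and emptied outright (the `0:0` slice) when non-sliding segmented attention lands exactly on a window boundary, since the next token then starts a fresh window. Below is a hedged sketch of how the returned caches would be threaded back in; the keyword names `cache` and `return_cache` appear in the touched lines, but the full forward signature is not shown in this diff, so treat this as an illustration rather than the package's documented sampling API:

```python
import torch

def greedy_decode(model, prompt, steps = 16):
    # prompt: (batch, seq) token ids; model stands in for a MemoryAsContextTransformer
    out = prompt
    kv_caches = None

    for _ in range(steps):
        logits, kv_caches = model(out, cache = kv_caches, return_cache = True)
        next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
        out = torch.cat((out, next_token), dim = -1)

    return out
```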
{titans_pytorch-0.1.17.dist-info → titans_pytorch-0.1.20.dist-info}/METADATA

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.17
+Version: 0.1.20
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -38,7 +38,7 @@ Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.9
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
-Requires-Dist: hyper-connections>=0.1.
+Requires-Dist: hyper-connections>=0.1.9
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
```
titans_pytorch-0.1.20.dist-info/RECORD

```diff
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=Ejq1r3GQQnlT1Fo4McaOOie19t1HjwVlYbD90GLQCYI,22859
+titans_pytorch/titans.py,sha256=R0e25ly2uTHkHSZEb-9Eqb0DqtFq8wFBB8iH1T6bYVg,22240
+titans_pytorch-0.1.20.dist-info/METADATA,sha256=Y0TmkfpKQ4LAyhr6SmAGeLHs3H4ZiZ4lg-gevvUDmjI,6340
+titans_pytorch-0.1.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.20.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.20.dist-info/RECORD,,
```

titans_pytorch-0.1.17.dist-info/RECORD

```diff
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=ajsX-djEUzstig5n99yF_NimRzKNfv0MSz-EIV-Fe1A,20393
-titans_pytorch/titans.py,sha256=R0e25ly2uTHkHSZEb-9Eqb0DqtFq8wFBB8iH1T6bYVg,22240
-titans_pytorch-0.1.17.dist-info/METADATA,sha256=E9nwWCKZLSqT9Mr85nrJQzinYpKZnkLeexeaYyOIqrU,6340
-titans_pytorch-0.1.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.17.dist-info/RECORD,,
```

{titans_pytorch-0.1.17.dist-info → titans_pytorch-0.1.20.dist-info}/WHEEL: file without changes

{titans_pytorch-0.1.17.dist-info → titans_pytorch-0.1.20.dist-info}/licenses/LICENSE: file without changes