titans-pytorch 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +36 -58
- {titans_pytorch-0.1.17.dist-info → titans_pytorch-0.1.18.dist-info}/METADATA +2 -2
- titans_pytorch-0.1.18.dist-info/RECORD +8 -0
- titans_pytorch-0.1.17.dist-info/RECORD +0 -8
- {titans_pytorch-0.1.17.dist-info → titans_pytorch-0.1.18.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.1.17.dist-info → titans_pytorch-0.1.18.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py
@@ -217,7 +217,8 @@ class SegmentedAttention(Module):
         self,
         seq,
         value_residual = None,
-        flex_attn_fn: Callable | None = None
+        flex_attn_fn: Callable | None = None,
+        output_gating = None
     ):
 
         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
@@ -267,6 +268,9 @@ class SegmentedAttention(Module):
 
         out = self.to_out(out)
 
+        if exists(output_gating):
+            out = out * output_gating
+
         return out, orig_v
 
     def forward(
@@ -274,10 +278,11 @@ class SegmentedAttention(Module):
         seq,
         value_residual = None,
         flex_attn_fn: Callable | None = None,
-        disable_flex_attn = False
+        disable_flex_attn = False,
+        output_gating = None
     ):
         if seq.is_cuda and self.use_flex_attn and not disable_flex_attn:
-            return self.forward_flex(seq, value_residual, flex_attn_fn)
+            return self.forward_flex(seq, value_residual, flex_attn_fn, output_gating = output_gating)
 
         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
 
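As a quick illustration (hypothetical tensor shapes, not part of the package), the new `output_gating` argument added above is simply an elementwise multiplier applied to the attention output before it is returned:

```python
import torch

# hypothetical shapes: (batch, seq, dim)
out = torch.randn(2, 8, 16)              # attention output, pre-gating
gates = torch.randn(2, 8, 16).sigmoid()  # e.g. sigmoid of a neural memory read

# what forward / forward_flex now do when output_gating is supplied
gated = out * gates
assert gated.shape == out.shape
```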
@@ -361,50 +366,10 @@ class SegmentedAttention(Module):
 
         out = inverse_segment(out)
 
-
-
-# Attention + Neural Memory gating configuration, as depicted in Figure 2
-
-class NeuralMemoryGatingWrapper(Module):
-    def __init__(
-        self,
-        dim,
-        attn: SegmentedAttention,
-        neural_mem: NeuralMemory | None = None,
-        gate_attn_output = True
-    ):
-        super().__init__()
-        self.attn = attn
-        self.neural_mem = neural_mem
-        self.gate_attn_output = gate_attn_output
-
-    def forward(
-        self,
-        seq,
-        *args,
-        **kwargs
-    ):
-        batch, seq_len = seq.shape[:2]
-        mem = self.neural_mem
-
-        if not exists(mem):
-            return self.attn(seq, *args, **kwargs), 0.
+        if exists(output_gating):
+            out = out * output_gating
 
-
-
-        retrieved, kv_aux_loss = mem(seq, return_aux_kv_loss = True)
-
-        if not self.gate_attn_output:
-            seq = seq + retrieved
-
-        # attention
-
-        attn_out, values = self.attn(seq, *args, **kwargs)
-
-        if self.gate_attn_output:
-            attn_out = attn_out * retrieved.sigmoid()
-
-        return (attn_out, values), kv_aux_loss
+        return out, orig_v
 
 # MAC transformer
 
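The deleted wrapper's gated path (attention output multiplied by the sigmoid of the memory read) survives as a caller-side pattern built on the new argument. A minimal sketch with a stand-in attention function and toy tensors, none of which are the library's API:

```python
import torch

def toy_attention(x, output_gating = None):
    # stand-in for SegmentedAttention.forward: fake attention, then the new gating hook
    out = x * 2.
    if output_gating is not None:
        out = out * output_gating
    return out

x = torch.randn(2, 8, 16)      # hypothetical (batch, seq, dim)
retrieved = torch.tanh(x)      # stand-in for a NeuralMemory retrieval

# previously NeuralMemoryGatingWrapper computed attn_out * retrieved.sigmoid() itself;
# now the caller passes sigmoid(retrieved) in as output_gating
gated_out = toy_attention(x, output_gating = retrieved.sigmoid())
```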
@@ -494,16 +459,10 @@ class MemoryAsContextTransformer(Module):
                     **neural_memory_kwargs
                 )
 
-            attn = NeuralMemoryGatingWrapper(
-                dim,
-                attn = attn,
-                neural_mem = mem,
-                gate_attn_output = neural_mem_gate_attn_output
-            )
-
             ff = FeedForward(dim = dim, mult = ff_mult)
 
             self.layers.append(ModuleList([
+                init_hyper_conn(dim = dim, branch = mem, add_branch_out_to_residual = not neural_mem_gate_attn_output) if exists(mem) else None,
                 init_hyper_conn(dim = dim, branch = attn),
                 init_hyper_conn(dim = dim, branch = ff)
             ]))
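Each entry of `self.layers` is now a triple: an optional hyper-connection branch wrapping the neural memory (or `None` when the layer has no memory), followed by the attention and feedforward branches. A rough sketch of that layout, using `nn.Identity` stand-ins for the hyper-connection wrappers and a made-up depth and memory-layer choice:

```python
from torch import nn

layers = nn.ModuleList([])

for layer_index in range(1, 4):              # hypothetical depth of 3
    has_mem = layer_index in (1, 3)          # hypothetical neural_memory_layers
    layers.append(nn.ModuleList([
        nn.Identity() if has_mem else None,  # stand-in for init_hyper_conn(branch = mem)
        nn.Identity(),                       # stand-in for init_hyper_conn(branch = attn)
        nn.Identity(),                       # stand-in for init_hyper_conn(branch = ff)
    ]))

# ModuleList accepts None entries, so the forward loop can unpack a uniform triple per layer
for mem, attn, ff in layers:
    print(mem is not None)
```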
@@ -512,6 +471,10 @@ class MemoryAsContextTransformer(Module):
 
         self.to_logits = LinearNoBias(dim, num_tokens)
 
+        # whether to gate the attention output with the retrieved memories
+
+        self.gate_attn_output = neural_mem_gate_attn_output
+
         # auxiliary loss on kv recon
 
         self.has_aux_kv_recon_loss = aux_kv_recon_loss_weight > 0.
@@ -652,19 +615,34 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for attn, ff in self.layers:
+        for mem, attn, ff in self.layers:
+
+            retrieved = None
+            attn_out_gates = None
+
+            if exists(mem):
+                retrieved, mem_kv_aux_loss = mem(x, return_aux_kv_loss = True)
+                kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
 
-            (x, values), maybe_mem_kv_aux_loss = attn(
+                if self.gate_attn_output:
+                    attn_out_gates = retrieved.sigmoid()
+                else:
+                    seq = retrieved
+
+            # attention
+
+            x, values = attn(
                 x,
                 value_residual = value_residual,
                 disable_flex_attn = disable_flex_attn,
-                flex_attn_fn = flex_attn_fn
+                flex_attn_fn = flex_attn_fn,
+                output_gating = attn_out_gates
            )
 
-            kv_recon_losses = kv_recon_losses + maybe_mem_kv_aux_loss
-
             value_residual = default(value_residual, values)
 
+            # feedforward
+
             x = ff(x)
 
         x = self.reduce_streams(x)
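Putting the pieces together: the rewritten loop reads from the optional per-layer memory, accumulates its auxiliary loss, and, when `gate_attn_output` is set, turns the retrieval into a sigmoid gate on the attention output. A self-contained sketch of that data flow with toy stand-in modules and invented shapes (not the library's classes):

```python
import torch
from torch import nn

class ToyMemory(nn.Module):
    # stand-in for NeuralMemory: returns a retrieval and an auxiliary kv-reconstruction loss
    def forward(self, x, return_aux_kv_loss = True):
        retrieved = torch.tanh(x)
        return retrieved, retrieved.pow(2).mean()

class ToyAttention(nn.Module):
    # stand-in for SegmentedAttention with the new output_gating hook
    def forward(self, x, output_gating = None):
        out = x
        if output_gating is not None:
            out = out * output_gating
        return out, x                   # (output, values for the value residual)

layers = [(ToyMemory(), ToyAttention()), (None, ToyAttention())]  # second layer has no memory
gate_attn_output = True

x = torch.randn(2, 8, 16)
kv_recon_losses = torch.tensor(0.)

for mem, attn in layers:
    attn_out_gates = None

    if mem is not None:
        retrieved, mem_kv_aux_loss = mem(x, return_aux_kv_loss = True)
        kv_recon_losses = kv_recon_losses + mem_kv_aux_loss

        if gate_attn_output:
            attn_out_gates = retrieved.sigmoid()

    x, values = attn(x, output_gating = attn_out_gates)
```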
{titans_pytorch-0.1.17.dist-info → titans_pytorch-0.1.18.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.17
+Version: 0.1.18
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -38,7 +38,7 @@ Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.9
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
-Requires-Dist: hyper-connections>=0.1.
+Requires-Dist: hyper-connections>=0.1.9
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
titans_pytorch-0.1.18.dist-info/RECORD
@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=9HHZ-KyWrnIXMG68WdevnuX_yigVoxKeikT0J5yDFY8,19957
+titans_pytorch/titans.py,sha256=R0e25ly2uTHkHSZEb-9Eqb0DqtFq8wFBB8iH1T6bYVg,22240
+titans_pytorch-0.1.18.dist-info/METADATA,sha256=w5G49IHPqC1eqCRhAgJYPLHYVmaO8zucI-GGVPemoKg,6340
+titans_pytorch-0.1.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.1.18.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.1.18.dist-info/RECORD,,
titans_pytorch-0.1.17.dist-info/RECORD
@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=u0tta_KqhOdfzCEDWT9P4_jejJEK2q1XxhsEzB5MnQU,223
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=ajsX-djEUzstig5n99yF_NimRzKNfv0MSz-EIV-Fe1A,20393
-titans_pytorch/titans.py,sha256=R0e25ly2uTHkHSZEb-9Eqb0DqtFq8wFBB8iH1T6bYVg,22240
-titans_pytorch-0.1.17.dist-info/METADATA,sha256=E9nwWCKZLSqT9Mr85nrJQzinYpKZnkLeexeaYyOIqrU,6340
-titans_pytorch-0.1.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.1.17.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.1.17.dist-info/RECORD,,
{titans_pytorch-0.1.17.dist-info → titans_pytorch-0.1.18.dist-info}/WHEEL: file without changes
{titans_pytorch-0.1.17.dist-info → titans_pytorch-0.1.18.dist-info}/licenses/LICENSE: file without changes