titans-pytorch 0.1.15__tar.gz → 0.1.18__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/PKG-INFO +2 -2
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/pyproject.toml +2 -2
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/tests/test_titans.py +20 -3
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/titans_pytorch/mac_transformer.py +77 -63
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/.gitignore +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/LICENSE +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/README.md +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/data/README.md +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/fig1.png +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/fig2.png +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/train_mac.py +0 -0
{titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.15
+Version: 0.1.18
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -38,7 +38,7 @@ Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.9
 Requires-Dist: einops>=0.8.0
 Requires-Dist: einx>=0.3.0
-Requires-Dist: hyper-connections>=0.1.
+Requires-Dist: hyper-connections>=0.1.9
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
{titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.15"
+version = "0.1.18"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -29,7 +29,7 @@ dependencies = [
     "axial_positional_embedding>=0.3.9",
     "einops>=0.8.0",
     "einx>=0.3.0",
-    "hyper-connections>=0.1.
+    "hyper-connections>=0.1.9",
     "Ninja",
     "rotary-embedding-torch",
     "tensordict",
{titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/tests/test_titans.py

@@ -3,7 +3,7 @@ from torch import nn
 
 import pytest
 from titans_pytorch import NeuralMemory
-from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention
+from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention, MemoryAsContextTransformer
 
 def exists(v):
     return v is not None
@@ -92,8 +92,6 @@ def test_mac(
     num_longterm_mem_tokens,
     neural_mem_gate_attn_output
 ):
-    from titans_pytorch.mac_transformer import MemoryAsContextTransformer
-
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
         dim = 256,
@@ -109,6 +107,25 @@ def test_mac(
     logits = transformer(x)
     assert logits.shape == (1, seq_len, 256)
 
+def test_mac_sampling():
+    transformer = MemoryAsContextTransformer(
+        num_tokens = 256,
+        dim = 256,
+        depth = 2,
+        segment_len = 32,
+        num_persist_mem_tokens = 4,
+        num_longterm_mem_tokens = 16,
+    )
+
+    ids = torch.randint(0, 256, (1, 1023))
+
+    # after much training
+
+    sampled = transformer.sample(ids[:, :4], 53, use_cache = False, temperature = 0.)
+    sampled_with_cache = transformer.sample(ids[:, :4], 53, use_cache = True, temperature = 0.)
+
+    assert torch.allclose(sampled, sampled_with_cache)
+
 @pytest.mark.parametrize('seq_len', (1023, 17))
 @pytest.mark.parametrize('sliding', (True, False))
 def test_flex(
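The new `test_mac_sampling` exercises the cached sampling path added in this release: greedy decoding (`temperature = 0.`) with `use_cache = True` must produce the same tokens as decoding without the cache. A minimal usage sketch along the same lines, assuming `titans-pytorch` 0.1.18 is installed; the model below is untrained, so the sampled tokens are arbitrary, but the two paths should still agree:

```python
import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 32,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
)

prompt = torch.randint(0, 256, (1, 4))

# temperature = 0. makes decoding deterministic, so the two runs are comparable

sampled            = transformer.sample(prompt, 53, use_cache = False, temperature = 0.)
sampled_with_cache = transformer.sample(prompt, 53, use_cache = True,  temperature = 0.)

assert torch.allclose(sampled, sampled_with_cache)
```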
{titans_pytorch-0.1.15 → titans_pytorch-0.1.18}/titans_pytorch/mac_transformer.py

@@ -217,7 +217,8 @@ class SegmentedAttention(Module):
         self,
         seq,
         value_residual = None,
-        flex_attn_fn: Callable | None = None
+        flex_attn_fn: Callable | None = None,
+        output_gating = None
     ):
 
         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
@@ -267,6 +268,9 @@ class SegmentedAttention(Module):
 
         out = self.to_out(out)
 
+        if exists(output_gating):
+            out = out * output_gating
+
         return out, orig_v
 
     def forward(
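The `output_gating` hook above is how the "memory as gate" configuration (Figure 2 of the paper) is now wired in: the attention output is multiplied elementwise by a gate derived from the neural memory retrieval (see the `retrieved.sigmoid()` call further down in this diff). An illustrative tensor-level sketch, using placeholder names (`attn_out`, `retrieved`) rather than library API:

```python
import torch

batch, seq_len, dim = 2, 128, 256

attn_out  = torch.randn(batch, seq_len, dim)   # output of the attention branch
retrieved = torch.randn(batch, seq_len, dim)   # retrieval from the neural memory

gates = retrieved.sigmoid()   # squashed to (0, 1), one gate per feature
gated = attn_out * gates      # elementwise gating, i.e. `out = out * output_gating`

assert gated.shape == attn_out.shape
```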
@@ -274,10 +278,11 @@ class SegmentedAttention(Module):
         seq,
         value_residual = None,
         flex_attn_fn: Callable | None = None,
-        disable_flex_attn = False
+        disable_flex_attn = False,
+        output_gating = None
     ):
         if seq.is_cuda and self.use_flex_attn and not disable_flex_attn:
-            return self.forward_flex(seq, value_residual, flex_attn_fn)
+            return self.forward_flex(seq, value_residual, flex_attn_fn, output_gating = output_gating)
 
         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
 
@@ -361,50 +366,10 @@ class SegmentedAttention(Module):
 
         out = inverse_segment(out)
 
-
-
-# Attention + Neural Memory gating configuration, as depicted in Figure 2
-
-class NeuralMemoryGatingWrapper(Module):
-    def __init__(
-        self,
-        dim,
-        attn: SegmentedAttention,
-        neural_mem: NeuralMemory | None = None,
-        gate_attn_output = True
-    ):
-        super().__init__()
-        self.attn = attn
-        self.neural_mem = neural_mem
-        self.gate_attn_output = gate_attn_output
-
-    def forward(
-        self,
-        seq,
-        *args,
-        **kwargs
-    ):
-        batch, seq_len = seq.shape[:2]
-        mem = self.neural_mem
-
-        if not exists(mem):
-            return self.attn(seq, *args, **kwargs), 0.
-
-        # initial retrieve, still should store first, it doesn't make sense not to, unless if all layers share the same neural memory
-
-        retrieved, kv_aux_loss = mem(seq, return_aux_kv_loss = True)
-
-        if not self.gate_attn_output:
-            seq = seq + retrieved
-
-        # attention
-
-        attn_out, values = self.attn(seq, *args, **kwargs)
-
-        if self.gate_attn_output:
-            attn_out = attn_out * retrieved.sigmoid()
+        if exists(output_gating):
+            out = out * output_gating
 
-        return
+        return out, orig_v
 
 # MAC transformer
 
@@ -494,16 +459,10 @@ class MemoryAsContextTransformer(Module):
                 **neural_memory_kwargs
             )
 
-            attn = NeuralMemoryGatingWrapper(
-                dim,
-                attn = attn,
-                neural_mem = mem,
-                gate_attn_output = neural_mem_gate_attn_output
-            )
-
             ff = FeedForward(dim = dim, mult = ff_mult)
 
             self.layers.append(ModuleList([
+                init_hyper_conn(dim = dim, branch = mem, add_branch_out_to_residual = not neural_mem_gate_attn_output) if exists(mem) else None,
                 init_hyper_conn(dim = dim, branch = attn),
                 init_hyper_conn(dim = dim, branch = ff)
             ]))
@@ -512,6 +471,10 @@ class MemoryAsContextTransformer(Module):
 
         self.to_logits = LinearNoBias(dim, num_tokens)
 
+        # whether to gate the attention output with the retrieved memories
+
+        self.gate_attn_output = neural_mem_gate_attn_output
+
         # auxiliary loss on kv recon
 
         self.has_aux_kv_recon_loss = aux_kv_recon_loss_weight > 0.
@@ -537,7 +500,8 @@ class MemoryAsContextTransformer(Module):
         filter_kwargs: dict = dict(
             min_p = 0.1,
         ),
-        show_progress = True
+        show_progress = True,
+        use_cache = False
     ):
         was_training = self.training
         self.eval()
@@ -547,8 +511,37 @@ class MemoryAsContextTransformer(Module):
 
         iter_wrap = tqdm.tqdm if show_progress else identity
 
+        # cache for axial pos, attention, and neural memory
+
+        cache = None
+        factorized_pos_emb = None
+
+        # precompute factorized pos emb
+
+        if use_cache:
+            round_up_seq_len = round_up_multiple(seq_len, self.segment_len)
+            longterm_mem_lens = (round_up_seq_len // self.segment_len) * self.num_longterm_mem_tokens
+            seq_len_with_mem = round_up_seq_len + longterm_mem_lens
+
+            axial_dims = self.axial_pos_emb.maybe_derive_outer_dim(seq_len_with_mem, (self.neural_memory_segment_len,))
+
+            factorized_pos_emb = self.axial_pos_emb(axial_dims, return_factorized = True)
+
+        # sample
+
         for _ in iter_wrap(range(sample_num_times)):
-
+
+            logits, next_cache = self.forward(
+                out,
+                disable_flex_attn = True,
+                cache = cache,
+                return_cache = True,
+                factorized_pos_emb = factorized_pos_emb
+            )
+
+            if use_cache:
+                cache = next_cache
+
             logits = logits[:, -1]
 
             logits = filter_fn(logits, **filter_kwargs)
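The sampling loop now threads a cache through `self.forward`, so earlier positions need not be recomputed on every decode step. A generic sketch of that pattern (not the library's `sample` implementation, just the shape of a cached greedy decode loop against a model whose forward accepts `cache` and `return_cache`, as the signature change in the next hunk shows):

```python
import torch

def greedy_decode(model, prompt, num_steps, use_cache = True):
    out = prompt
    cache = None

    for _ in range(num_steps):
        # the model returns (logits, cache) when return_cache = True
        logits, next_cache = model(out, cache = cache, return_cache = True)

        if use_cache:
            cache = next_cache

        # greedy pick of the next token from the last position
        next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
        out = torch.cat((out, next_token), dim = -1)

    return out[:, prompt.shape[-1]:]
```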
@@ -565,7 +558,10 @@ class MemoryAsContextTransformer(Module):
         x,
         return_loss = False,
         return_loss_breakdown = False,
-        disable_flex_attn = False
+        disable_flex_attn = False,
+        cache = None,
+        return_cache = False,
+        factorized_pos_emb = None
     ):
 
         if return_loss:
@@ -593,7 +589,7 @@ class MemoryAsContextTransformer(Module):
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
-        pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,))
+        pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,), factorized = factorized_pos_emb)
 
         x = x + pos_emb
 
@@ -619,19 +615,34 @@ class MemoryAsContextTransformer(Module):
 
         x = self.expand_streams(x)
 
-        for attn, ff in self.layers:
+        for mem, attn, ff in self.layers:
+
+            retrieved = None
+            attn_out_gates = None
+
+            if exists(mem):
+                retrieved, mem_kv_aux_loss = mem(x, return_aux_kv_loss = True)
+                kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
 
-
+                if self.gate_attn_output:
+                    attn_out_gates = retrieved.sigmoid()
+                else:
+                    seq = retrieved
+
+            # attention
+
+            x, values = attn(
                 x,
                 value_residual = value_residual,
                 disable_flex_attn = disable_flex_attn,
-                flex_attn_fn = flex_attn_fn
+                flex_attn_fn = flex_attn_fn,
+                output_gating = attn_out_gates
             )
 
-            kv_recon_losses = kv_recon_losses + maybe_mem_kv_aux_loss
-
             value_residual = default(value_residual, values)
 
+            # feedforward
+
             x = ff(x)
 
         x = self.reduce_streams(x)
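With `NeuralMemoryGatingWrapper` gone, the per-layer flow lives directly in this loop: retrieve from the neural memory, then either gate the attention output with `retrieved.sigmoid()` or let the hyper-connection add the memory branch back to the residual, then run attention and feedforward. A toy stand-in, with plain `nn.Linear` modules in place of `NeuralMemory`, `SegmentedAttention`, `FeedForward` and the hyper-connection wrappers, just to show the control flow:

```python
import torch
from torch import nn

class ToyMACBlock(nn.Module):
    def __init__(self, dim, use_memory = True, gate_attn_output = True):
        super().__init__()
        self.mem  = nn.Linear(dim, dim) if use_memory else None  # stand-in for NeuralMemory
        self.attn = nn.Linear(dim, dim)                          # stand-in for SegmentedAttention
        self.ff   = nn.Linear(dim, dim)                          # stand-in for FeedForward
        self.gate_attn_output = gate_attn_output

    def forward(self, x):
        gates = None

        if self.mem is not None:
            retrieved = self.mem(x)

            if self.gate_attn_output:
                gates = retrieved.sigmoid()   # memory gates the attention output
            else:
                x = x + retrieved             # memory output joins the residual stream

        attn_out = self.attn(x)

        if gates is not None:
            attn_out = attn_out * gates

        x = x + attn_out
        return x + self.ff(x)

out = ToyMACBlock(dim = 64)(torch.randn(1, 8, 64))
```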
@@ -651,7 +662,10 @@ class MemoryAsContextTransformer(Module):
         logits = self.to_logits(x)
 
         if not return_loss:
-
+            if not return_cache:
+                return logits
+
+            return logits, cache
 
         ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
 