titans-pytorch 0.1.15__tar.gz → 0.1.17__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.17}/PKG-INFO +1 -1
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.17}/pyproject.toml +1 -1
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.17}/tests/test_titans.py +20 -3
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.17}/titans_pytorch/mac_transformer.py +41 -5
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.17}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.17}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.17}/.gitignore +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.17}/LICENSE +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.17}/README.md +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.17}/data/README.md +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.17}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.17}/fig1.png +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.17}/fig2.png +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.17}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.17}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.17}/titans_pytorch/titans.py +0 -0
- {titans_pytorch-0.1.15 → titans_pytorch-0.1.17}/train_mac.py +0 -0
```diff
--- titans_pytorch-0.1.15/tests/test_titans.py
+++ titans_pytorch-0.1.17/tests/test_titans.py
@@ -3,7 +3,7 @@ from torch import nn
 
 import pytest
 from titans_pytorch import NeuralMemory
-from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention
+from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention, MemoryAsContextTransformer
 
 def exists(v):
     return v is not None
@@ -92,8 +92,6 @@ def test_mac(
     num_longterm_mem_tokens,
     neural_mem_gate_attn_output
 ):
-    from titans_pytorch.mac_transformer import MemoryAsContextTransformer
-
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
         dim = 256,
@@ -109,6 +107,25 @@ def test_mac(
     logits = transformer(x)
     assert logits.shape == (1, seq_len, 256)
 
+def test_mac_sampling():
+    transformer = MemoryAsContextTransformer(
+        num_tokens = 256,
+        dim = 256,
+        depth = 2,
+        segment_len = 32,
+        num_persist_mem_tokens = 4,
+        num_longterm_mem_tokens = 16,
+    )
+
+    ids = torch.randint(0, 256, (1, 1023))
+
+    # after much training
+
+    sampled = transformer.sample(ids[:, :4], 53, use_cache = False, temperature = 0.)
+    sampled_with_cache = transformer.sample(ids[:, :4], 53, use_cache = True, temperature = 0.)
+
+    assert torch.allclose(sampled, sampled_with_cache)
+
 @pytest.mark.parametrize('seq_len', (1023, 17))
 @pytest.mark.parametrize('sliding', (True, False))
 def test_flex(
```
```diff
--- titans_pytorch-0.1.15/titans_pytorch/mac_transformer.py
+++ titans_pytorch-0.1.17/titans_pytorch/mac_transformer.py
@@ -537,7 +537,8 @@ class MemoryAsContextTransformer(Module):
         filter_kwargs: dict = dict(
             min_p = 0.1,
         ),
-        show_progress = True
+        show_progress = True,
+        use_cache = False
     ):
         was_training = self.training
         self.eval()
@@ -547,8 +548,37 @@ class MemoryAsContextTransformer(Module):
 
         iter_wrap = tqdm.tqdm if show_progress else identity
 
+        # cache for axial pos, attention, and neural memory
+
+        cache = None
+        factorized_pos_emb = None
+
+        # precompute factorized pos emb
+
+        if use_cache:
+            round_up_seq_len = round_up_multiple(seq_len, self.segment_len)
+            longterm_mem_lens = (round_up_seq_len // self.segment_len) * self.num_longterm_mem_tokens
+            seq_len_with_mem = round_up_seq_len + longterm_mem_lens
+
+            axial_dims = self.axial_pos_emb.maybe_derive_outer_dim(seq_len_with_mem, (self.neural_memory_segment_len,))
+
+            factorized_pos_emb = self.axial_pos_emb(axial_dims, return_factorized = True)
+
+        # sample
+
         for _ in iter_wrap(range(sample_num_times)):
-            logits = self.forward(out, disable_flex_attn = True)
+
+            logits, next_cache = self.forward(
+                out,
+                disable_flex_attn = True,
+                cache = cache,
+                return_cache = True,
+                factorized_pos_emb = factorized_pos_emb
+            )
+
+            if use_cache:
+                cache = next_cache
+
             logits = logits[:, -1]
 
             logits = filter_fn(logits, **filter_kwargs)
@@ -565,7 +595,10 @@ class MemoryAsContextTransformer(Module):
         x,
         return_loss = False,
         return_loss_breakdown = False,
-        disable_flex_attn = False
+        disable_flex_attn = False,
+        cache = None,
+        return_cache = False,
+        factorized_pos_emb = None
     ):
 
         if return_loss:
@@ -593,7 +626,7 @@ class MemoryAsContextTransformer(Module):
         # apply axial positional embedding
         # so intra and inter segment can be more easily discerned by the network
 
-        pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,))
+        pos_emb = self.axial_pos_emb.forward_with_seq_len(seq_len_with_mem, (neural_mem_segment_len,), factorized = factorized_pos_emb)
 
         x = x + pos_emb
 
@@ -651,7 +684,10 @@ class MemoryAsContextTransformer(Module):
         logits = self.to_logits(x)
 
         if not return_loss:
-            return logits
+            if not return_cache:
+                return logits
+
+            return logits, cache
 
         ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
 
```
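Taken together, these hunks add an optional inference cache to `MemoryAsContextTransformer`: `sample` gains a `use_cache` flag, and `forward` gains `cache`, `return_cache`, and `factorized_pos_emb` parameters so the logits and the updated cache can be threaded through the decoding loop. The sketch below shows how the new flag would be exercised; it simply mirrors the `test_mac_sampling` test added in this release, and the constructor arguments are copied from that test as illustrative values, not requirements.

```python
import torch
from titans_pytorch.mac_transformer import MemoryAsContextTransformer

# configuration mirrors the new test_mac_sampling test in this diff
transformer = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 32,
    num_persist_mem_tokens = 4,
    num_longterm_mem_tokens = 16,
)

prompt = torch.randint(0, 256, (1, 4))

# greedy decoding with and without the new cache path
sampled = transformer.sample(prompt, 53, use_cache = False, temperature = 0.)
sampled_with_cache = transformer.sample(prompt, 53, use_cache = True, temperature = 0.)

# per the new test, both paths are expected to produce identical tokens
assert torch.allclose(sampled, sampled_with_cache)
```

Setting `temperature = 0.` makes the decoding greedy, which is what allows the cached and uncached paths to be compared token for token.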