titans-pytorch 0.1.0__tar.gz → 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/PKG-INFO +8 -3
- {titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/README.md +6 -1
- {titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/pyproject.toml +2 -2
- {titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/tests/test_titans.py +33 -2
- {titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/titans_pytorch/mac_transformer.py +58 -6
- {titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/titans_pytorch/titans.py +6 -1
- {titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/train_mac.py +1 -47
- titans_pytorch-0.1.0/assert_flex.py +0 -18
- {titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/.gitignore +0 -0
- {titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/LICENSE +0 -0
- {titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/data/README.md +0 -0
- {titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/fig1.png +0 -0
- {titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/fig2.png +0 -0
- {titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/titans_pytorch/associative_scan.py +0 -0
{titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.0
+Version: 0.1.2
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -43,9 +43,9 @@ Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
 Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
+Requires-Dist: tqdm
 Requires-Dist: x-transformers
 Provides-Extra: examples
-Requires-Dist: tqdm; extra == 'examples'
 Requires-Dist: wandb; extra == 'examples'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
@@ -104,7 +104,12 @@ transformer = MemoryAsContextTransformer(

 token_ids = torch.randint(0, 256, (1, 1023))

-
+loss = transformer(token_ids, return_loss = True) # (1, 1023, 256)
+loss.backward()
+
+# after much training
+
+sampled = transformer.sample(token_ids[:, :4], 512)
 ```

 ## Experiments
{titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/README.md

@@ -51,7 +51,12 @@ transformer = MemoryAsContextTransformer(

 token_ids = torch.randint(0, 256, (1, 1023))

-
+loss = transformer(token_ids, return_loss = True) # (1, 1023, 256)
+loss.backward()
+
+# after much training
+
+sampled = transformer.sample(token_ids[:, :4], 512)
 ```

 ## Experiments
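The usage snippet in the README (mirrored in PKG-INFO above) now trains with an explicit loss and decodes with the new sample method. Below is a minimal training-step sketch built around that snippet; the constructor hyperparameters, learning rate, and the Adam optimizer are illustrative assumptions, not values taken from this diff:

    import torch
    from torch.optim import Adam
    from titans_pytorch import MemoryAsContextTransformer

    # model sizes and learning rate below are assumptions for illustration only
    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 128,
        num_persist_mem_tokens = 4,
        num_longterm_mem_tokens = 16,
    )

    optim = Adam(transformer.parameters(), lr = 3e-4)

    token_ids = torch.randint(0, 256, (1, 1023))

    # return_loss = True makes the forward pass return a scalar training loss
    loss = transformer(token_ids, return_loss = True)
    loss.backward()
    optim.step()
    optim.zero_grad()

    # after much training, decode continuations from a short prompt
    sampled = transformer.sample(token_ids[:, :4], 512)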
{titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.0"
+version = "0.1.2"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -34,6 +34,7 @@ dependencies = [
     "rotary-embedding-torch",
     "tensordict",
     "torch>=2.2",
+    "tqdm",
     "x-transformers"
 ]

@@ -44,7 +45,6 @@ Repository = "https://github.com/lucidrains/titans-pytorch"
 [project.optional-dependencies]

 examples = [
-    "tqdm",
     "wandb"
 ]

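Because tqdm moves from the examples extra into the core dependency list, a plain install of 0.1.2 pulls it in. A quick check against an installed 0.1.2, using only the standard library:

    from importlib.metadata import requires, version

    assert version('titans-pytorch') == '0.1.2'

    # core requirements carry no environment marker; in 0.1.0 tqdm only
    # appeared as "tqdm; extra == 'examples'"
    assert 'tqdm' in requires('titans-pytorch')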
{titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/tests/test_titans.py

@@ -3,6 +3,10 @@ from torch import nn

 import pytest
 from titans_pytorch import NeuralMemory
+from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention
+
+def exists(v):
+    return v is not None

 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
 @pytest.mark.parametrize('silu', (False, True))
@@ -46,10 +50,12 @@ def test_titans_attn_memory():

     assert seq.shape == retrieved.shape

+@pytest.mark.parametrize('seq_len', (1023, 17))
 @pytest.mark.parametrize('num_persist_mem_tokens', (0, 16))
 @pytest.mark.parametrize('num_longterm_mem_tokens', (0, 16))
 @pytest.mark.parametrize('neural_mem_gate_attn_output', (False, True))
 def test_mac(
+    seq_len,
     num_persist_mem_tokens,
     num_longterm_mem_tokens,
     neural_mem_gate_attn_output
@@ -66,7 +72,32 @@ def test_mac(
         neural_mem_gate_attn_output = neural_mem_gate_attn_output
     )

-    x = torch.randint(0, 256, (1,
+    x = torch.randint(0, 256, (1, seq_len))

     logits = transformer(x)
-    assert logits.shape == (1,
+    assert logits.shape == (1, seq_len, 256)
+
+@pytest.mark.parametrize('seq_len', (1023, 17))
+@pytest.mark.parametrize('sliding', (True, False))
+def test_flex(
+    seq_len,
+    sliding
+):
+    if not (torch.cuda.is_available() and exists(flex_attention)):
+        pytest.skip()
+
+    attn = SegmentedAttention(
+        dim = 512,
+        segment_len = 32,
+        num_persist_mem_tokens = 1,
+        num_longterm_mem_tokens = 1,
+        use_flex_attn = True,
+        sliding = sliding
+    ).cuda()
+
+    seq = torch.randn(1, seq_len, 512).cuda()
+
+    out_flex, _ = attn(seq)
+    out_non_flex, _ = attn(seq, disable_flex_attn = True)
+
+    assert torch.allclose(out_flex, out_non_flex, atol = 1e-5)
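The new test_flex case skips unless CUDA and torch's flex_attention are both available, then asserts that the flex and non-flex attention paths agree. The same parity check can be run as a standalone sketch mirroring the test body (and the assert_flex.py script removed further down):

    import torch
    from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention

    # the flex attention path needs a GPU and a torch build that exposes flex_attention
    if torch.cuda.is_available() and flex_attention is not None:
        attn = SegmentedAttention(
            dim = 512,
            segment_len = 32,
            num_persist_mem_tokens = 1,
            num_longterm_mem_tokens = 1,
            use_flex_attn = True,
            sliding = True
        ).cuda()

        seq = torch.randn(1, 1023, 512).cuda()

        out_flex, _ = attn(seq)
        out_non_flex, _ = attn(seq, disable_flex_attn = True)

        assert torch.allclose(out_flex, out_non_flex, atol = 1e-5)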
{titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/titans_pytorch/mac_transformer.py

@@ -3,6 +3,8 @@ from typing import Callable
 from math import ceil
 from functools import partial

+import tqdm
+
 import torch
 from torch import nn, cat
 import torch.nn.functional as F
@@ -88,12 +90,6 @@ def pad_at_dim(t, pad, dim = -1, value = 0.):

 def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):
     batch, seq_len = seq.shape[:2]
-
-    need_segment = seq_len >= segment_len
-
-    if not need_segment:
-        return seq, identity
-
     next_seq_len_mult = round_up_multiple(seq_len, segment_len)

     padding = next_seq_len_mult - seq_len
@@ -116,6 +112,29 @@ def pad_and_segment_with_inverse(seq, segment_len, fold_into_batch = True):

     return seq, inverse

+# sampling related
+
+def log(t, eps = 1e-20):
+    return torch.log(t.clamp(min = eps))
+
+def gumbel_noise(t):
+    noise = torch.rand_like(t)
+    return -log(-log(noise))
+
+def gumbel_sample(t, temperature = 1.):
+    if temperature > 0.:
+        t = t / temperature + gumbel_noise(t)
+    return t.argmax(dim = -1, keepdim = True)
+
+# min_p
+# https://arxiv.org/abs/2407.01082
+
+def min_p_filter(logits, min_p = 0.1):
+    probs = logits.softmax(dim = -1)
+    max_probs = probs.amax(dim = -1, keepdim = True)
+    limit = min_p * max_probs
+    return torch.where(probs < limit, float('-inf'), logits)
+
 # feedforward and attention

 class GEGLU(Module):
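The sampling helpers added above implement min-p filtering (per the linked paper) followed by Gumbel sampling: tokens whose probability falls below min_p times the top token's probability are masked to -inf, and gumbel_sample then draws one token per row via argmax over temperature-scaled, Gumbel-noised logits. A small sketch on toy logits, assuming the helpers are imported from titans_pytorch.mac_transformer where this hunk defines them:

    import torch
    from titans_pytorch.mac_transformer import min_p_filter, gumbel_sample

    # toy logits: batch of 2, vocabulary of 5
    logits = torch.tensor([
        [2.0, 1.0, 0.5, -1.0, -3.0],
        [0.1, 0.0, 4.0,  0.2, -2.0],
    ])

    # entries with probability < 0.1 * max probability are set to -inf
    filtered = min_p_filter(logits, min_p = 0.1)

    # one sampled token id per row, kept as a trailing dimension of size 1
    token_ids = gumbel_sample(filtered, temperature = 1.)
    assert token_ids.shape == (2, 1)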
@@ -500,6 +519,39 @@ class MemoryAsContextTransformer(Module):
         self.segment_len = segment_len
         self.num_persist_mem_tokens = num_persist_mem_tokens

+    @torch.no_grad()
+    def sample(
+        self,
+        prompt: Tensor,
+        seq_len: int,
+        temperature = 1.5,
+        filter_fn: Callable = min_p_filter,
+        filter_kwargs: dict = dict(
+            min_p = 0.1,
+        ),
+        show_progress = True
+    ):
+        was_training = self.training
+        self.eval()
+
+        prompt_seq_len, out = prompt.shape[-1], prompt.clone()
+        sample_num_times = max(0, seq_len - prompt_seq_len)
+
+        iter_wrap = tqdm.tqdm if show_progress else identity
+
+        for _ in iter_wrap(range(sample_num_times)):
+            logits = self.forward(out, disable_flex_attn = True)
+            logits = logits[:, -1]
+
+            logits = filter_fn(logits, **filter_kwargs)
+            sample = gumbel_sample(logits, temperature = temperature)
+
+            out = torch.cat((out, sample), dim = -1)
+
+        self.train(was_training)
+
+        return out[..., prompt_seq_len:]
+
     def forward(
         self,
         x,
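The new sample method wraps the whole decode loop (forward pass on the running sequence with flex attention disabled, last-position logits, filtering, Gumbel draw, append) behind a single call and restores the module's training mode afterwards. A usage sketch with the keyword arguments spelled out; the tiny untrained model here exists only to exercise the API, and its sizes are assumptions rather than values from this diff:

    import torch
    from titans_pytorch import MemoryAsContextTransformer

    # sizes are illustrative assumptions
    transformer = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 256,
        depth = 2,
        segment_len = 32,
        num_persist_mem_tokens = 4,
        num_longterm_mem_tokens = 4,
    )

    prompt = torch.randint(0, 256, (1, 4))

    generated = transformer.sample(
        prompt,
        seq_len = 64,                          # total length, prompt included
        temperature = 1.,
        filter_kwargs = dict(min_p = 0.1),
        show_progress = False
    )

    # only the newly sampled tokens come back
    assert generated.shape == (1, 64 - 4)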
{titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/titans_pytorch/titans.py

@@ -592,7 +592,12 @@ class NeuralMemory(Module):
         batch, seq_len = seq.shape[:2]

         if seq_len < self.chunk_size:
-
+            out = self.init_empty_memory_embed(batch, seq_len)
+
+            if not return_aux_kv_loss:
+                return out
+
+            return out, self.zero

         if exists(past_state):
             past_state = tuple(TensorDict(d) for d in past_state)
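With this change, a sequence shorter than the configured chunk size gets the empty memory embedding back, plus a zero auxiliary loss when one is requested, instead of only the bare early return. A sketch of both call modes; the constructor values follow the package README, and treating return_aux_kv_loss as a forward keyword is an assumption read off the branch above:

    import torch
    from titans_pytorch import NeuralMemory

    mem = NeuralMemory(dim = 384, chunk_size = 64)

    short_seq = torch.randn(1, 16, 384)    # shorter than chunk_size

    # default path returns just the retrieved memories, matching the input shape
    retrieved = mem(short_seq)
    assert retrieved.shape == short_seq.shape

    # with the auxiliary loss requested, the short path now returns a zero scalar for it
    retrieved, aux_kv_loss = mem(short_seq, return_aux_kv_loss = True)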
{titans_pytorch-0.1.0 → titans_pytorch-0.1.2}/train_mac.py

@@ -68,52 +68,6 @@ def decode_token(token):
 def decode_tokens(tokens):
     return ''.join(list(map(decode_token, tokens)))

-# sampling helpers
-
-def log(t, eps = 1e-20):
-    return torch.log(t.clamp(min = eps))
-
-def gumbel_noise(t):
-    noise = torch.zeros_like(t).uniform_(0, 1)
-    return -log(-log(noise))
-
-def gumbel_sample(t, temperature = 1., keepdim = True):
-    if temperature <= 0.:
-        return t.argmax(dim = dim, keepdim = keepdim)
-
-    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = -1, keepdim = keepdim)
-
-# min_p
-# https://arxiv.org/abs/2407.01082
-
-def min_p_filter(logits, min_p = 0.1):
-    probs = logits.softmax(dim = -1)
-    max_probs = probs.amax(dim = -1, keepdim = True)
-    limit = min_p * max_probs
-    return torch.where(probs < limit, float('-inf'), logits)
-
-def base_decoding(
-    net,
-    prompt: Tensor,
-    seq_len: int,
-    temperature = 1.5,
-    min_p = 1e-1,
-    filter_thres = 0.9,
-):
-    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
-    sample_num_times = max(0, seq_len - prompt_seq_len)
-
-    for _ in tqdm.tqdm(range(sample_num_times)):
-        logits = net(out, disable_flex_attn = True)
-        logits = logits[:, -1]
-
-        logits = min_p_filter(logits, min_p = min_p)
-        sample = gumbel_sample(logits, temperature = temperature)
-
-        out = torch.cat((out, sample), dim = -1)
-
-    return out[..., prompt_seq_len:]
-
 # instantiate memory-as-context transformer

 model = MemoryAsContextTransformer(
@@ -197,6 +151,6 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
         prime = decode_tokens(inp)
         print(f'%s \n\n %s', (prime, '*' * 100))

-        sample =
+        sample = model.sample(inp[None, ...], GENERATE_LENGTH)
         output_str = decode_tokens(sample[0])
         print(output_str)
titans_pytorch-0.1.0/assert_flex.py (deleted)

@@ -1,18 +0,0 @@
-import torch
-from titans_pytorch.mac_transformer import SegmentedAttention
-
-attn = SegmentedAttention(
-    dim = 512,
-    segment_len = 32,
-    num_persist_mem_tokens = 1,
-    num_longterm_mem_tokens = 1,
-    use_flex_attn = True,
-    sliding = False
-).cuda()
-
-seq = torch.randn(1, 1019, 512).cuda()
-
-out_flex, _ = attn(seq)
-out_non_flex, _ = attn(seq, disable_flex_attn = True)
-
-assert torch.allclose(out_flex, out_non_flex, atol = 1e-5)