titans-pytorch 0.0.50__tar.gz → 0.0.51__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/PKG-INFO +10 -1
- {titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/README.md +9 -0
- {titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/pyproject.toml +1 -1
- {titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/tests/test_titans.py +6 -1
- {titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/titans_pytorch/mac_transformer.py +33 -1
- {titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/titans_pytorch/titans.py +14 -5
- {titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/train_mac.py +52 -5
- {titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/.gitignore +0 -0
- {titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/LICENSE +0 -0
- {titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/data/README.md +0 -0
- {titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/fig1.png +0 -0
- {titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/fig2.png +0 -0
- {titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/titans_pytorch/associative_scan.py +0 -0
{titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.50
+Version: 0.0.51
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

@@ -135,3 +135,12 @@ $ python train_mac.py
     year = {2024}
 }
 ```
+
+```bibtex
+@inproceedings{Yang2024GatedDN,
+    title = {Gated Delta Networks: Improving Mamba2 with Delta Rule},
+    author = {Songlin Yang and Jan Kautz and Ali Hatamizadeh},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:274598177}
+}
+```
{titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/README.md

@@ -82,3 +82,12 @@ $ python train_mac.py
     year = {2024}
 }
 ```
+
+```bibtex
+@inproceedings{Yang2024GatedDN,
+    title = {Gated Delta Networks: Improving Mamba2 with Delta Rule},
+    author = {Songlin Yang and Jan Kautz and Ali Hatamizadeh},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:274598177}
+}
+```
{titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/tests/test_titans.py

@@ -1,16 +1,21 @@
 import torch
+from torch import nn
+
 import pytest
 from titans_pytorch import NeuralMemory
 
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
+@pytest.mark.parametrize('silu', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 def test_titans(
     seq_len,
-    max_grad_norm
+    silu,
+    max_grad_norm,
 ):
     mem = NeuralMemory(
         dim = 384,
         chunk_size = 64,
+        activation = nn.SiLU() if silu else None,
         max_grad_norm = max_grad_norm
     )
 
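Outside the test suite, the new `activation` argument can be exercised the same way. Below is a minimal sketch, assuming the `NeuralMemory` usage shown in the project README at this point in time (forward returns retrieved memories with the same shape as the input); the dimensions and the SiLU choice simply mirror the updated test.

```python
# minimal sketch, not part of the package - mirrors tests/test_titans.py
# assumes NeuralMemory's forward returns retrieved memories shaped like its input,
# as in the project README usage

import torch
from torch import nn
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    activation = nn.SiLU(),   # new in 0.0.51: applied after the query / key-value projections
    max_grad_norm = 2.
)

seq = torch.randn(2, 1024, 384)   # (batch, sequence, dim)
retrieved = mem(seq)              # retrieved memories, same shape as seq per the README usage

print(retrieved.shape)
```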
{titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/titans_pytorch/mac_transformer.py

@@ -7,16 +7,48 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
+# flex attention
+# https://pytorch.org/blog/flexattention/
+
+flex_attention = None
+
+try:
+    from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+    if torch.cuda.is_available():
+        flex_attention = torch.compile(flex_attention)
+except ImportError:
+    pass
+
+def create_mac_block_mask(seq_len, window_size, persist_mem_len):
+
+    def create_mac_mask(b, h, q_idx, kv_idx):
+        is_persist_mem = kv_idx < persist_mem_len
+        causal_mask = q_idx >= (kv_idx - is_persist_mem)
+        block_diagonal = (q_idx // window_size) == ((kv_idx - is_persist_mem) // window_size)
+        return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
+
+    block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
+    return block_mask
+
+# einstein notation related
+
 from einops import einsum, repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
-from hyper_connections import get_init_and_expand_reduce_stream_functions
+# b - batch
+# n - sequence
+# h - heads
+# d - feature dimension
 
 # absolute and relative positions
 
 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
+
+# hyper connections / attend from x-transformers, which handles different queries and key lengths better
+
 from x_transformers.attend import Attend
+from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # proposed neural memory
 
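To make the new mask pattern concrete: a query may always attend to the prepended persistent-memory slots, and otherwise attends causally within its own window (block-diagonal). The sketch below is an illustrative re-derivation with plain dense tensors, not the packaged `create_mac_block_mask`: the sizes are arbitrary, and sequence key positions are offset by `persist_mem_len` rather than by the boolean flag used in the diff, to keep the dense version simple. It prints the boolean mask so the always-visible persistent-memory columns and the windowed causal region are easy to see.

```python
# illustrative sketch of the memory-as-context attention mask, not the packaged function
# columns [0, persist_mem_len) are persistent memory tokens, the rest are the sequence

import torch

seq_len, window_size, persist_mem_len = 8, 4, 2   # arbitrary toy sizes

q_idx = torch.arange(seq_len)[:, None]                      # query positions
kv_idx = torch.arange(seq_len + persist_mem_len)[None, :]   # keys: persistent memory first, then sequence

is_persist_mem = kv_idx < persist_mem_len
seq_kv_idx = kv_idx - persist_mem_len                        # key position within the sequence proper

causal = q_idx >= seq_kv_idx
same_window = (q_idx // window_size) == (seq_kv_idx // window_size)

mask = is_persist_mem | (~is_persist_mem & causal & same_window)

print(mask.int())   # rows: queries, columns: persistent memory then sequence keys
```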
{titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/titans_pytorch/titans.py

@@ -56,6 +56,17 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+def Sequential(*modules):
+    modules = [*filter(exists, modules)]
+
+    if len(modules) == 0:
+        return nn.Identity()
+
+    if len(modules) == 1:
+        return modules[0]
+
+    return nn.Sequential(*modules)
+
 # softclamping gradients
 
 def softclamp_max(t, max_value):

@@ -124,9 +135,6 @@ class MemoryAttention(Module):
         ])
 
     def forward(self, x):
-
-        assert x.shape[-2] > 1, 'chunk size needs to be greater than 1 for using attention as memory'
-
         wq, wk, wv, ffw1, ffw2 = self.weights
 
         q = F.normalize(x @ wq, dim = -1)

@@ -168,6 +176,7 @@ class NeuralMemory(Module):
         post_rmsnorm = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
+        activation: Module | None = None,
         default_model_kwargs: dict = dict(
             depth = 2
         )

@@ -225,11 +234,11 @@
 
         # queries for retrieving from the model
 
-        self.to_queries = LinearNoBias(dim, dim_inner)
+        self.to_queries = Sequential(LinearNoBias(dim, dim_inner), activation)
 
         # keys and values for storing to the model
 
-        self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
+        self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
         self.store_memory_loss_fn = store_memory_loss_fn
 
         # empty memory embed
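The `Sequential` helper added above is what makes the new `activation` argument optional: `None` entries are filtered out before anything gets wrapped, so passing no activation leaves the bare projection in place. A small standalone sketch of the same idea follows; the helper and an `exists` None-check are re-stated locally here (with a plain `nn.Linear(bias = False)` standing in for the package's `LinearNoBias`) rather than imported.

```python
# standalone sketch of the None-filtering Sequential helper added in titans.py
import torch
from torch import nn

def exists(v):
    return v is not None

def Sequential(*modules):
    modules = [*filter(exists, modules)]   # drop any None entries

    if len(modules) == 0:
        return nn.Identity()               # nothing to do

    if len(modules) == 1:
        return modules[0]                  # avoid a needless nn.Sequential wrapper

    return nn.Sequential(*modules)

proj = nn.Linear(384, 768, bias = False)   # stand-in for LinearNoBias(dim, dim_inner)

assert Sequential(proj, None) is proj                          # activation = None -> bare projection
assert isinstance(Sequential(proj, nn.SiLU()), nn.Sequential)  # activation given -> projection + activation
assert isinstance(Sequential(None, None), nn.Identity)
```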
{titans_pytorch-0.0.50 → titans_pytorch-0.0.51}/train_mac.py

@@ -4,7 +4,7 @@ import gzip
 import numpy as np
 
 import torch
-from torch import nn
+from torch import nn, Tensor
 from torch.optim import Adam
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset

@@ -19,12 +19,13 @@ GRADIENT_ACCUMULATE_EVERY = 4
 LEARNING_RATE = 2e-4
 VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
+PRIME_LENGTH = 100
 GENERATE_LENGTH = 512
-SHOULD_GENERATE = False
+SHOULD_GENERATE = True
 SEQ_LEN = 512
 
 PROJECT_NAME = 'titans-mac-transformer'
-WANDB_ONLINE =
+WANDB_ONLINE = False # turn this on to pipe experiment to cloud
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4

@@ -53,6 +54,52 @@ def decode_token(token):
 def decode_tokens(tokens):
     return ''.join(list(map(decode_token, tokens)))
 
+# sampling helpers
+
+def log(t, eps = 1e-20):
+    return torch.log(t.clamp(min = eps))
+
+def gumbel_noise(t):
+    noise = torch.zeros_like(t).uniform_(0, 1)
+    return -log(-log(noise))
+
+def gumbel_sample(t, temperature = 1., keepdim = True):
+    if temperature <= 0.:
+        return t.argmax(dim = dim, keepdim = keepdim)
+
+    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = -1, keepdim = keepdim)
+
+# min_p
+# https://arxiv.org/abs/2407.01082
+
+def min_p_filter(logits, min_p = 0.1):
+    probs = logits.softmax(dim = -1)
+    max_probs = probs.amax(dim = -1, keepdim = True)
+    limit = min_p * max_probs
+    return torch.where(probs < limit, float('-inf'), logits)
+
+def base_decoding(
+    net,
+    prompt: Tensor,
+    seq_len: int,
+    temperature = 1.5,
+    min_p = 1e-1,
+    filter_thres = 0.9,
+):
+    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
+    sample_num_times = max(0, seq_len - prompt_seq_len)
+
+    for _ in tqdm.tqdm(range(sample_num_times)):
+        logits = net(out)
+        logits = logits[:, -1]
+
+        logits = min_p_filter(logits, min_p = min_p)
+        sample = gumbel_sample(logits, temperature = temperature)
+
+        out = torch.cat((out, sample), dim = -1)
+
+    return out[..., prompt_seq_len:]
+
 # instantiate memory-as-context transformer
 
 model = MemoryAsContextTransformer(

@@ -127,10 +174,10 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
 
     if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
         model.eval()
-        inp = random.choice(val_dataset)[
+        inp = random.choice(val_dataset)[:PRIME_LENGTH]
         prime = decode_tokens(inp)
         print(f'%s \n\n %s', (prime, '*' * 100))
 
-        sample = model
+        sample = base_decoding(model, inp[None, ...], GENERATE_LENGTH)
         output_str = decode_tokens(sample[0])
         print(output_str)
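The new sampling path in `train_mac.py` combines min-p filtering with Gumbel sampling: tokens whose probability falls below a fraction of the most likely token's probability are masked out, then one token is drawn from the survivors. The sketch below reproduces just that filter-then-sample step on dummy logits so the effect is easy to inspect; the helpers are re-stated locally rather than imported from the training script, and the explicit `dim` argument in `gumbel_sample` is a small deviation added so the greedy (temperature ≤ 0) branch is self-contained.

```python
# standalone sketch of min-p filtering + Gumbel sampling on dummy next-token logits
import torch

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def gumbel_noise(t):
    return -log(-log(torch.zeros_like(t).uniform_(0, 1)))

def gumbel_sample(t, temperature = 1., dim = -1, keepdim = True):
    if temperature <= 0.:
        return t.argmax(dim = dim, keepdim = keepdim)           # greedy fallback
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim, keepdim = keepdim)

def min_p_filter(logits, min_p = 0.1):
    probs = logits.softmax(dim = -1)
    limit = min_p * probs.amax(dim = -1, keepdim = True)        # threshold relative to the mode
    return torch.where(probs < limit, float('-inf'), logits)    # mask tokens far below the mode

logits = torch.randn(1, 256)                 # one decoding step over a byte-level vocabulary
filtered = min_p_filter(logits, min_p = 0.1)
token = gumbel_sample(filtered, temperature = 1.5)

print(token.shape)                           # (1, 1) - ready to concatenate onto the running sequence
```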
Files without changes: the remaining ten entries listed above (+0 -0) are identical between 0.0.50 and 0.0.51.