titans-pytorch 0.0.49__tar.gz → 0.0.51__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.49 → titans_pytorch-0.0.51}/PKG-INFO +10 -1
- {titans_pytorch-0.0.49 → titans_pytorch-0.0.51}/README.md +9 -0
- {titans_pytorch-0.0.49 → titans_pytorch-0.0.51}/pyproject.toml +1 -1
- {titans_pytorch-0.0.49 → titans_pytorch-0.0.51}/tests/test_titans.py +6 -1
- {titans_pytorch-0.0.49 → titans_pytorch-0.0.51}/titans_pytorch/mac_transformer.py +33 -1
- {titans_pytorch-0.0.49 → titans_pytorch-0.0.51}/titans_pytorch/titans.py +19 -6
- {titans_pytorch-0.0.49 → titans_pytorch-0.0.51}/train_mac.py +54 -7
- {titans_pytorch-0.0.49 → titans_pytorch-0.0.51}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.49 → titans_pytorch-0.0.51}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.49 → titans_pytorch-0.0.51}/.gitignore +0 -0
- {titans_pytorch-0.0.49 → titans_pytorch-0.0.51}/LICENSE +0 -0
- {titans_pytorch-0.0.49 → titans_pytorch-0.0.51}/data/README.md +0 -0
- {titans_pytorch-0.0.49 → titans_pytorch-0.0.51}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.49 → titans_pytorch-0.0.51}/fig1.png +0 -0
- {titans_pytorch-0.0.49 → titans_pytorch-0.0.51}/fig2.png +0 -0
- {titans_pytorch-0.0.49 → titans_pytorch-0.0.51}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.49 → titans_pytorch-0.0.51}/titans_pytorch/associative_scan.py +0 -0

PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.49
+Version: 0.0.51
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

@@ -135,3 +135,12 @@ $ python train_mac.py
     year = {2024}
 }
 ```
+
+```bibtex
+@inproceedings{Yang2024GatedDN,
+    title = {Gated Delta Networks: Improving Mamba2 with Delta Rule},
+    author = {Songlin Yang and Jan Kautz and Ali Hatamizadeh},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:274598177}
+}
+```

README.md

@@ -82,3 +82,12 @@ $ python train_mac.py
     year = {2024}
 }
 ```
+
+```bibtex
+@inproceedings{Yang2024GatedDN,
+    title = {Gated Delta Networks: Improving Mamba2 with Delta Rule},
+    author = {Songlin Yang and Jan Kautz and Ali Hatamizadeh},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:274598177}
+}
+```

tests/test_titans.py

@@ -1,16 +1,21 @@
 import torch
+from torch import nn
+
 import pytest
 from titans_pytorch import NeuralMemory

 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
+@pytest.mark.parametrize('silu', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 def test_titans(
     seq_len,
-    max_grad_norm
+    silu,
+    max_grad_norm,
 ):
     mem = NeuralMemory(
         dim = 384,
         chunk_size = 64,
+        activation = nn.SiLU() if silu else None,
         max_grad_norm = max_grad_norm
     )
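
For orientation, a minimal usage sketch of the new `activation` option exercised by the test above. The constructor arguments come straight from the hunk; the forward call and the shape check mirror the package's existing test and should be read as assumptions, not as part of this diff.

import torch
from torch import nn
from titans_pytorch import NeuralMemory

# neural memory with the new optional activation (0.0.51),
# applied after the query / key-value projections when provided
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    activation = nn.SiLU(),
    max_grad_norm = 2.
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)   # assumed to return a tensor shaped like `seq`, as in the package's test
assert retrieved.shape == seq.shape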

titans_pytorch/mac_transformer.py

@@ -7,16 +7,48 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear

+# flex attention
+# https://pytorch.org/blog/flexattention/
+
+flex_attention = None
+
+try:
+    from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+    if torch.cuda.is_available():
+        flex_attention = torch.compile(flex_attention)
+except ImportError:
+    pass
+
+def create_mac_block_mask(seq_len, window_size, persist_mem_len):
+
+    def create_mac_mask(b, h, q_idx, kv_idx):
+        is_persist_mem = kv_idx < persist_mem_len
+        causal_mask = q_idx >= (kv_idx - is_persist_mem)
+        block_diagonal = (q_idx // window_size) == ((kv_idx - is_persist_mem) // window_size)
+        return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
+
+    block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
+    return block_mask
+
+# einstein notation related
+
 from einops import einsum, repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange

-
+# b - batch
+# n - sequence
+# h - heads
+# d - feature dimension

 # absolute and relative positions

 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
+
+# hyper connections / attend from x-transformers, which handles different queries and key lengths better
+
 from x_transformers.attend import Attend
+from hyper_connections import get_init_and_expand_reduce_stream_functions

 # proposed neural memory

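
A hedged sketch (not part of the diff) of how the block mask above could be consumed by flex attention. The tensor shapes, and the idea that the persistent memory tokens sit in front of the keys and values, are illustrative assumptions; it needs a PyTorch build that ships torch.nn.attention.flex_attention (2.5+), and depending on version and defaults may want a CUDA device.

import torch
from torch.nn.attention.flex_attention import flex_attention
from titans_pytorch.mac_transformer import create_mac_block_mask   # added in this diff

seq_len, window_size, persist_mem_len = 512, 32, 4
heads, dim_head = 8, 64

# block-diagonal causal mask over windows, with persistent memory positions always attendable
block_mask = create_mac_block_mask(seq_len, window_size, persist_mem_len)

q = torch.randn(1, heads, seq_len, dim_head)
k = torch.randn(1, heads, seq_len + persist_mem_len, dim_head)   # persistent memories assumed prepended
v = torch.randn(1, heads, seq_len + persist_mem_len, dim_head)

out = flex_attention(q, k, v, block_mask = block_mask)           # (1, heads, seq_len, dim_head)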

titans_pytorch/titans.py

@@ -56,6 +56,17 @@ def pack_one_with_inverse(t, pattern):

     return packed, inverse

+def Sequential(*modules):
+    modules = [*filter(exists, modules)]
+
+    if len(modules) == 0:
+        return nn.Identity()
+
+    if len(modules) == 1:
+        return modules[0]
+
+    return nn.Sequential(*modules)
+
 # softclamping gradients

 def softclamp_max(t, max_value):
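
The Sequential helper above is what lets the new activation stay optional: None entries are filtered out, so the projections remain a bare Linear when no activation is passed. A self-contained sketch of that behavior, with the helper copied from the hunk and `exists` assumed to be the package's usual is-not-None check:

from torch import nn

def exists(v):              # assumption: same helper the package already defines
    return v is not None

def Sequential(*modules):   # copied from the hunk above
    modules = [*filter(exists, modules)]

    if len(modules) == 0:
        return nn.Identity()

    if len(modules) == 1:
        return modules[0]

    return nn.Sequential(*modules)

print(type(Sequential(None, None)))                   # Identity
print(type(Sequential(nn.Linear(4, 8), None)))        # Linear, unwrapped
print(type(Sequential(nn.Linear(4, 8), nn.SiLU())))   # Sequential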

@@ -124,9 +135,6 @@ class MemoryAttention(Module):
         ])

     def forward(self, x):
-
-        assert x.shape[-2] > 1, 'chunk size needs to be greater than 1 for using attention as memory'
-
         wq, wk, wv, ffw1, ffw2 = self.weights

         q = F.normalize(x @ wq, dim = -1)

@@ -162,11 +170,13 @@ class NeuralMemory(Module):
         heads = 1,
         model: Module | None = None,
         store_memory_loss_fn: Callable = default_loss_fn,
-        adaptive_step_transform: Callable =
+        adaptive_step_transform: Callable | None = None,
+        default_step_transform_max_lr = 1e-2,
         pre_rmsnorm = True,
         post_rmsnorm = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
+        activation: Module | None = None,
         default_model_kwargs: dict = dict(
             depth = 2
         )

@@ -224,11 +234,11 @@ class NeuralMemory(Module):

         # queries for retrieving from the model

-        self.to_queries = LinearNoBias(dim, dim_inner)
+        self.to_queries = Sequential(LinearNoBias(dim, dim_inner), activation)

         # keys and values for storing to the model

-        self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
+        self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
         self.store_memory_loss_fn = store_memory_loss_fn

         # empty memory embed

@@ -250,6 +260,9 @@ class NeuralMemory(Module):
             Rearrange('b n h -> (b h) n')
         )

+        if not exists(adaptive_step_transform):
+            adaptive_step_transform = partial(default_adaptive_step_transform, max_lr = default_step_transform_max_lr)
+
         self.adaptive_step_transform = adaptive_step_transform

         # allow for softclamp the gradient norms for storing memories
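
A hedged sketch of the two ways the adaptive step size can now be configured, following the constructor changes above. The one-argument call convention for a custom adaptive_step_transform, and the sigmoid-style example transform, are assumptions beyond what the diff shows.

from titans_pytorch import NeuralMemory

# rely on the new default: partial(default_adaptive_step_transform, max_lr = default_step_transform_max_lr)
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    default_step_transform_max_lr = 1e-2,
)

# or supply an explicit transform over the raw adaptive step tensor (assumed to be called with one argument)
mem_custom = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    adaptive_step_transform = lambda t: t.sigmoid() * 5e-3,
)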

train_mac.py

@@ -4,7 +4,7 @@ import gzip
 import numpy as np

 import torch
-from torch import nn
+from torch import nn, Tensor
 from torch.optim import Adam
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset

@@ -19,19 +19,20 @@ GRADIENT_ACCUMULATE_EVERY = 4
 LEARNING_RATE = 2e-4
 VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
+PRIME_LENGTH = 100
 GENERATE_LENGTH = 512
-SHOULD_GENERATE =
+SHOULD_GENERATE = True
 SEQ_LEN = 512

 PROJECT_NAME = 'titans-mac-transformer'
-WANDB_ONLINE =
+WANDB_ONLINE = False # turn this on to pipe experiment to cloud
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
 NEURAL_MEM_LAYERS = (2, 4)
 WINDOW_SIZE = 32
-KV_RECON_LOSS_WEIGHT = 0.
-RUN_NAME = f'mac - {NUM_LONGTERM_MEM} longterm mems, layers {NEURAL_MEM_LAYERS}
+KV_RECON_LOSS_WEIGHT = 0.
+RUN_NAME = f'mac - {NUM_LONGTERM_MEM} longterm mems, layers {NEURAL_MEM_LAYERS}'

 # wandb experiment tracker


@@ -53,6 +54,52 @@ def decode_token(token):
 def decode_tokens(tokens):
     return ''.join(list(map(decode_token, tokens)))

+# sampling helpers
+
+def log(t, eps = 1e-20):
+    return torch.log(t.clamp(min = eps))
+
+def gumbel_noise(t):
+    noise = torch.zeros_like(t).uniform_(0, 1)
+    return -log(-log(noise))
+
+def gumbel_sample(t, temperature = 1., keepdim = True):
+    if temperature <= 0.:
+        return t.argmax(dim = dim, keepdim = keepdim)
+
+    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = -1, keepdim = keepdim)
+
+# min_p
+# https://arxiv.org/abs/2407.01082
+
+def min_p_filter(logits, min_p = 0.1):
+    probs = logits.softmax(dim = -1)
+    max_probs = probs.amax(dim = -1, keepdim = True)
+    limit = min_p * max_probs
+    return torch.where(probs < limit, float('-inf'), logits)
+
+def base_decoding(
+    net,
+    prompt: Tensor,
+    seq_len: int,
+    temperature = 1.5,
+    min_p = 1e-1,
+    filter_thres = 0.9,
+):
+    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
+    sample_num_times = max(0, seq_len - prompt_seq_len)
+
+    for _ in tqdm.tqdm(range(sample_num_times)):
+        logits = net(out)
+        logits = logits[:, -1]
+
+        logits = min_p_filter(logits, min_p = min_p)
+        sample = gumbel_sample(logits, temperature = temperature)
+
+        out = torch.cat((out, sample), dim = -1)
+
+    return out[..., prompt_seq_len:]
+
 # instantiate memory-as-context transformer

 model = MemoryAsContextTransformer(
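
A short, self-contained demonstration of the min_p filter and Gumbel sampling helpers added above, run on dummy logits (assuming the helpers are in scope, e.g. pasted alongside them in a REPL):

import torch

logits = torch.randn(2, 256)                         # (batch, vocab)

filtered = min_p_filter(logits, min_p = 0.1)         # logits below min_p * max prob become -inf
sample = gumbel_sample(filtered, temperature = 1.)   # Gumbel-noise argmax, keepdim = True

print(sample.shape)                                  # torch.Size([2, 1])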

@@ -127,10 +174,10 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):

     if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
         model.eval()
-        inp = random.choice(val_dataset)[
+        inp = random.choice(val_dataset)[:PRIME_LENGTH]
         prime = decode_tokens(inp)
         print(f'%s \n\n %s', (prime, '*' * 100))

-        sample = model
+        sample = base_decoding(model, inp[None, ...], GENERATE_LENGTH)
         output_str = decode_tokens(sample[0])
         print(output_str)