titans-pytorch 0.0.41__tar.gz → 0.0.43__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/PKG-INFO +22 -10
- {titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/README.md +21 -7
- {titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/pyproject.toml +1 -3
- {titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/tests/test_titans.py +5 -4
- titans_pytorch-0.0.43/titans_pytorch/__init__.py +8 -0
- {titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/titans_pytorch/mac_transformer.py +0 -1
- {titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/titans_pytorch/titans.py +43 -12
- {titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/train_mac.py +3 -3
- titans_pytorch-0.0.41/requirements.txt +0 -1
- titans_pytorch-0.0.41/titans_pytorch/__init__.py +0 -6
- titans_pytorch-0.0.41/titans_pytorch/titans_attn_memory.py +0 -419
- titans_pytorch-0.0.41/train.py +0 -152
- {titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/.gitignore +0 -0
- {titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/LICENSE +0 -0
- {titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/data/README.md +0 -0
- {titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/fig1.png +0 -0
- {titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/fig2.png +0 -0
- {titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/titans_pytorch/associative_scan.py +0 -0
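Taken together, the list above describes a small API reshuffle: the experimental `titans_attn_memory.py` module and the old `train.py` script were removed, a reworked `MemoryAttention` now lives in `titans_pytorch/titans.py`, and the package `__init__.py` was replaced (6 lines removed, 8 added). As a minimal sanity check of the resulting import surface, inferred only from the README and test changes shown below (this diff does not display the new `__init__.py` itself):

```python
# both names are imported from the package root in the 0.0.43 README and tests
from titans_pytorch import NeuralMemory, MemoryAsContextTransformer
```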
{titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.41
+Version: 0.0.43
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -45,8 +45,6 @@ Requires-Dist: tensordict
 Requires-Dist: torch>=2.2
 Requires-Dist: x-transformers
 Provides-Extra: examples
-Requires-Dist: local-attention>=1.10.1; extra == 'examples'
-Requires-Dist: taylor-series-linear-attention; extra == 'examples'
 Requires-Dist: tqdm; extra == 'examples'
 Requires-Dist: wandb; extra == 'examples'
 Provides-Extra: test
@@ -85,22 +83,36 @@ retrieved = mem(seq)
 assert seq.shape == retrieved.shape
 ```

-
+A transformer with the `MAC` configuration can be used as

-```
-
+```python
+import torch
+from titans_pytorch import MemoryAsContextTransformer
+
+transformer = MemoryAsContextTransformer(
+    num_tokens = 256,
+    dim = 256,
+    depth = 2,
+    segment_len = 128, # local attention window size
+    num_persist_mem_tokens = 4,
+    num_longterm_mem_tokens = 16,
+)
+
+token_ids = torch.randint(0, 256, (1, 1023))
+
+logits = transformer(token_ids) # (1, 1023, 256)
 ```

-
+## Experiments

 ```bash
-$ pip install
+$ pip install .[examples]
 ```

-Then modify `
+Then modify `train_mac.py` and run it to query nature

 ```bash
-$ python
+$ python train_mac.py
 ```

 ## Citations
{titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/README.md
@@ -30,22 +30,36 @@ retrieved = mem(seq)
 assert seq.shape == retrieved.shape
 ```

-
+A transformer with the `MAC` configuration can be used as

-```
-
+```python
+import torch
+from titans_pytorch import MemoryAsContextTransformer
+
+transformer = MemoryAsContextTransformer(
+    num_tokens = 256,
+    dim = 256,
+    depth = 2,
+    segment_len = 128, # local attention window size
+    num_persist_mem_tokens = 4,
+    num_longterm_mem_tokens = 16,
+)
+
+token_ids = torch.randint(0, 256, (1, 1023))
+
+logits = transformer(token_ids) # (1, 1023, 256)
 ```

-
+## Experiments

 ```bash
-$ pip install
+$ pip install .[examples]
 ```

-Then modify `
+Then modify `train_mac.py` and run it to query nature

 ```bash
-$ python
+$ python train_mac.py
 ```

 ## Citations
{titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.41"
+version = "0.0.43"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -44,8 +44,6 @@ Repository = "https://github.com/lucidrains/titans-pytorch"
 [project.optional-dependencies]

 examples = [
-    "local-attention>=1.10.1",
-    "taylor-series-linear-attention",
     "tqdm",
     "wandb"
 ]
{titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/tests/test_titans.py
@@ -1,5 +1,6 @@
 import torch
 import pytest
+from titans_pytorch import NeuralMemory

 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
@@ -7,9 +8,6 @@ def test_titans(
     seq_len,
     max_grad_norm
 ):
-
-    from titans_pytorch import NeuralMemory
-
     mem = NeuralMemory(
         dim = 384,
         chunk_size = 64,
@@ -22,11 +20,14 @@ def test_titans(
     assert seq.shape == retrieved.shape

 def test_titans_attn_memory():
-    from titans_pytorch.
+    from titans_pytorch.titans import MemoryAttention

     mem = NeuralMemory(
         dim = 384,
         chunk_size = 64,
+        model = MemoryAttention(
+            dim = 384
+        )
     )

     seq = torch.randn(2, 1024, 384)
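The updated test above is also the clearest picture of the new usage: `MemoryAttention` is handed to `NeuralMemory` as its `model`. A condensed sketch of that test, reusing its shapes and mirroring the shape check from `test_titans`:

```python
import torch
from titans_pytorch import NeuralMemory
from titans_pytorch.titans import MemoryAttention

# the attention module stands in for the default memory MLP as the memory model
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    model = MemoryAttention(
        dim = 384
    )
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)

assert seq.shape == retrieved.shape
```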
{titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/titans_pytorch/titans.py
@@ -27,7 +27,7 @@ n - sequence
 d - feature dimension
 c - intra-chunk
 """
-
+
 LinearNoBias = partial(Linear, bias = False)

 # functions
@@ -107,6 +107,44 @@ class MemoryMLP(Module):

         return x

+# improvised attention as memory module
+
+class MemoryAttention(Module):
+    def __init__(
+        self,
+        dim
+    ):
+        super().__init__()
+        self.weights = nn.ParameterList([
+            nn.Parameter(torch.randn(dim, dim)), # queries
+            nn.Parameter(torch.randn(dim, dim)), # keys
+            nn.Parameter(torch.randn(dim, dim)), # values
+            nn.Parameter(torch.randn(dim, dim * 2)), # ff w1
+            nn.Parameter(torch.randn(dim * 2, dim)), # ff w2
+        ])
+
+    def forward(self, x):
+
+        assert x.shape[-2] > 1, 'chunk size needs to be greater than 1 for using attention as memory'
+
+        wq, wk, wv, ffw1, ffw2 = self.weights
+
+        q = F.normalize(x @ wq, dim = -1)
+        k = F.normalize(x @ wk, dim = -1)
+        v = x @ wv
+
+        attn_out = F.scaled_dot_product_attention(
+            q, k, v,
+            is_causal = True
+        )
+
+        x = x + attn_out
+
+        h = F.silu(x @ ffw1)
+        out = h @ ffw2
+
+        return out
+
 # main neural memory

 def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
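For orientation, the module added above is a single causal self-attention step (with normalized queries and keys and a residual) followed by a SiLU feedforward, all parameterized by plain `nn.Parameter` matrices so the whole block can serve as the fast weights of the neural memory. A minimal standalone forward pass, with hypothetical sizes:

```python
import torch
from titans_pytorch.titans import MemoryAttention

attn_memory = MemoryAttention(dim = 64)

# the second-to-last (chunk) dimension must be greater than 1, per the assert in forward
x = torch.randn(2, 16, 64)
out = attn_memory(x)

assert out.shape == (2, 16, 64)
```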
@@ -129,7 +167,7 @@ class NeuralMemory(Module):
         post_rmsnorm = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
-
+        default_model_kwargs: dict = dict(
             depth = 2
         )
     ):
@@ -162,7 +200,7 @@ class NeuralMemory(Module):
         # memory mlp

         if not exists(model):
-            model = MemoryMLP(dim_head, **
+            model = MemoryMLP(dim_head, **default_model_kwargs)

         assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'

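The two hunks above rename the constructor argument to `default_model_kwargs`; it only matters when no explicit `model` is passed, in which case it is forwarded to the default `MemoryMLP`. A sketch of both paths, with hypothetical dimensions:

```python
from titans_pytorch import NeuralMemory
from titans_pytorch.titans import MemoryMLP

# rely on the default MemoryMLP, configured through default_model_kwargs
mem_default = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    default_model_kwargs = dict(
        depth = 2
    )
)

# or pass an explicit memory model, in which case default_model_kwargs is unused
mem_explicit = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    model = MemoryMLP(384, depth = 2)
)
```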
@@ -387,11 +425,7 @@ class NeuralMemory(Module):
         next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)

         padding = next_seq_len - curtailed_seq_len
-
-        needs_pad = padding > 0
-
-        if needs_pad:
-            seq = pad_at_dim(seq, (0, padding), dim = 1)
+        seq = pad_at_dim(seq, (0, padding), dim = 1)

         # the parameters of the memory model stores the memories of the key / values
         # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
@@ -443,10 +477,7 @@ class NeuralMemory(Module):
         empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
         values = torch.cat((empty_memory_embeds, values), dim = -2)

-
-        values = values[:, :-padding]
-
-        return values
+        return values[:, :seq_len]

     def forward(
         self,
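One plausible reason for the two simplifications above (always padding, then returning `values[:, :seq_len]` instead of `values[:, :-padding]`) is the `padding == 0` corner case: a `:-0` slice silently drops the entire sequence, whereas slicing to the known target length is a no-op when nothing was padded. A small illustration of the slicing behaviour in plain PyTorch, independent of this library:

```python
import torch

values = torch.randn(2, 10, 4)
padding = 0

print(values[:, :-padding].shape)  # torch.Size([2, 0, 4]) - everything is sliced away
print(values[:, :10].shape)        # torch.Size([2, 10, 4]) - slicing to the known length is safe
```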
{titans_pytorch-0.0.41 → titans_pytorch-0.0.43}/train_mac.py
@@ -28,9 +28,9 @@ WANDB_ONLINE = False # turn this on to pipe experiment to cloud
 NEURAL_MEMORY_DEPTH = 2
 NUM_PERSIST_MEM = 4
 NUM_LONGTERM_MEM = 4
-NEURAL_MEM_LAYERS = (4
+NEURAL_MEM_LAYERS = (2, 4)
 WINDOW_SIZE = 32
-RUN_NAME = 'mac -
+RUN_NAME = f'mac - {NUM_LONGTERM_MEM} longterm mems, layers {NEURAL_MEM_LAYERS}'

 # wandb experiment tracker

@@ -65,7 +65,7 @@ model = MemoryAsContextTransformer(
     neural_memory_kwargs = dict(
         dim_head = 64,
         heads = 4,
-
+        default_model_kwargs = dict(
             depth = NEURAL_MEMORY_DEPTH,
         )
     )
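In `train_mac.py` the same rename surfaces one level up: the per-layer neural memory is configured through the transformer's `neural_memory_kwargs`, which now forwards `default_model_kwargs` to `NeuralMemory`. A trimmed sketch of that wiring, keeping only parameters visible in this diff and in the README example (the real script's other settings are omitted, and the dimensions here are the README's, not necessarily the script's):

```python
from titans_pytorch import MemoryAsContextTransformer

model = MemoryAsContextTransformer(
    num_tokens = 256,
    dim = 256,
    depth = 2,
    segment_len = 32,             # WINDOW_SIZE in train_mac.py
    num_persist_mem_tokens = 4,   # NUM_PERSIST_MEM
    num_longterm_mem_tokens = 4,  # NUM_LONGTERM_MEM
    neural_memory_kwargs = dict(
        dim_head = 64,
        heads = 4,
        default_model_kwargs = dict(
            depth = 2             # NEURAL_MEMORY_DEPTH
        )
    )
)
```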
titans_pytorch-0.0.41/requirements.txt DELETED
@@ -1 +0,0 @@
-pytorch-fast-transformers>=0.4.0
titans_pytorch-0.0.41/titans_pytorch/titans_attn_memory.py DELETED
@@ -1,419 +0,0 @@
-from __future__ import annotations
-import math
-from functools import partial
-
-import torch
-from torch import nn, Tensor
-import torch.nn.functional as F
-from torch.nn import Linear, Module
-from torch.func import functional_call, vmap, grad
-
-from tensordict import TensorDict
-
-from titans_pytorch.associative_scan import (
-    associative_scan,
-    binary_operator,
-    pad_at_dim
-)
-
-import einx
-from einops import rearrange, pack, unpack
-from einops.layers.torch import Rearrange, Reduce
-
-"""
-ein notation:
-b - batch
-n - sequence
-d - feature dimension
-c - intra-chunk
-"""
-
-# constants
-
-LinearNoBias = partial(Linear, bias = False)
-
-# functions
-
-def exists(v):
-    return v is not None
-
-def default(v, d):
-    return v if exists(v) else d
-
-def round_down_multiple(seq, mult):
-    return seq // mult * mult
-
-def round_up_multiple(seq, mult):
-    return math.ceil(seq / mult) * mult
-
-def pack_one_with_inverse(t, pattern):
-    packed, packed_shape = pack([t], pattern)
-
-    def inverse(out, inv_pattern = None):
-        inv_pattern = default(inv_pattern, pattern)
-        return unpack(out, packed_shape, inv_pattern)[0]
-
-    return packed, inverse
-
-# classes
-
-# improvised attention as memory module
-# todo - expand if see signal in experiments (update: not seeing it)
-
-class MemoryAttention(Module):
-    def __init__(
-        self,
-        dim
-    ):
-        super().__init__()
-        self.weights = nn.ParameterList([
-            nn.Parameter(torch.randn(dim, dim)), # queries
-            nn.Parameter(torch.randn(dim, dim)), # keys
-            nn.Parameter(torch.randn(dim, dim)), # values weight 1
-            nn.Parameter(torch.randn(dim, dim)), # values weight 2
-        ])
-
-    def forward(self, x):
-
-        assert x.shape[-2] > 1, 'chunk size needs to be greater than 1 for using attention as memory'
-
-        wq, wk, wv1, wv2 = self.weights
-
-        q = x @ wq
-        k = x @ wk
-        v = x @ wv1
-
-        hidden = F.scaled_dot_product_attention(
-            q, k, v,
-            is_causal = True
-        )
-
-        return F.silu(hidden) @ wv2
-
-# main neural memory
-
-def default_loss_fn(pred, target):
-    return (pred - target).pow(2).mean(dim = -1).sum()
-
-class NeuralMemory(Module):
-    def __init__(
-        self,
-        dim,
-        chunk_size = 1,
-        dim_head = None,
-        heads = 1,
-        model: MemoryAttention | None = None,
-        store_memory_loss_fn: Callable = default_loss_fn,
-        pre_rmsnorm = True,
-        post_rmsnorm = True,
-        use_accelerated_scan = False,
-        default_model_kwargs: dict = dict()
-    ):
-        super().__init__()
-
-        # norms
-
-        self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
-        self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
-
-        self.post_rmsnorm = nn.RMSNorm(dim) if post_rmsnorm else nn.Identity()
-
-        # maybe multi-headed
-
-        dim_head = default(dim_head, dim)
-        dim_inner = dim_head * heads
-
-        self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
-        self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
-        self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
-
-        # memory mlp
-
-        if not exists(model):
-            model = MemoryAttention(dim_head, **default_model_kwargs)
-
-        assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
-
-        # the memory is the weights of the model
-
-        self.memory_model = model
-
-        # the chunk size within the paper where adaptive step, momentum, weight decay are shared
-
-        self.chunk_size = chunk_size
-
-        # prepare function for per sample gradients from model above, using torch.func
-
-        def forward_and_loss(params, inputs, target):
-            pred = functional_call(self.memory_model, params, inputs)
-            loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
-            return loss
-
-        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
-
-        # queries for retrieving from the model
-
-        self.to_queries = LinearNoBias(dim, dim_inner)
-
-        # keys and values for storing to the model
-
-        self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
-        self.store_memory_loss_fn = store_memory_loss_fn
-
-        # learned adaptive learning rate and momentum
-        # todo - explore mlp layerwise learned lr / momentum
-
-        self.to_momentum = nn.Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
-            LinearNoBias(dim, heads),
-            Rearrange('b n h -> (b h) n 1')
-        )
-
-        self.to_adaptive_step = nn.Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
-            LinearNoBias(dim, heads),
-            Rearrange('b n h -> (b h) n')
-        )
-
-        # weight decay factor
-
-        self.to_decay_factor = nn.Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
-            LinearNoBias(dim, heads),
-            Rearrange('b n h -> (b h) n 1')
-        )
-
-        # maybe use accelerated scan
-
-        self.use_accelerated_scan = use_accelerated_scan
-
-    def init_weights_and_momentum(self):
-        params = TensorDict(dict(self.memory_model.named_parameters()))
-
-        init_weights = params.clone().zero_()
-        init_momentum = params.clone().zero_()
-
-        return init_weights, init_momentum
-
-    def store_memories(
-        self,
-        seq,
-        past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
-    ):
-
-        seq = self.store_norm(seq)
-
-        # curtail sequence by multiple of the chunk size
-        # only a complete chunk of the sequence provides the memory for the next chunk
-
-        seq_len, chunk_size = seq.shape[-2], self.chunk_size
-        round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
-
-        seq = seq[:, :round_down_seq_len]
-
-        # curr weights + past weights, in the case that the initial weights are learned
-
-        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
-
-        past_state = tuple(TensorDict(d) for d in past_state)
-        past_weights, past_momentum = past_state
-
-        curr_weights = curr_weights + past_weights
-
-        # pack batch and sequence dimension
-
-        adaptive_lr = (self.to_adaptive_step(seq).sigmoid() * -15).exp() # from 1. - 1e-7
-
-        adaptive_momentum = self.to_momentum(seq).sigmoid()
-        decay_factor = self.to_decay_factor(seq).sigmoid()
-
-        # keys and values
-
-        keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
-
-        # maybe multi head
-
-        keys, values = map(self.split_heads, (keys, values))
-
-        batch = keys.shape[0]
-
-        # take care of chunking
-
-        keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
-
-        # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
-
-        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
-
-        grads = TensorDict(grads)
-
-        # restore batch and sequence dimension
-
-        grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
-
-        # multiply gradients with learned adaptive step size
-
-        surprises = grads.apply(lambda t: einx.multiply('b n ..., b n -> b n ...', t, -adaptive_lr))
-
-        # determine scan function
-
-        def default_associative_scan(gates, inputs):
-            _, outputs = associative_scan(binary_operator, (gates, inputs))
-            return outputs
-
-        if self.use_accelerated_scan:
-            from accelerated_scan.triton import scan as triton_scan
-            from accelerated_scan.warp import scan as warp_scan
-
-            scan = triton_scan if seq.is_cuda else warp_scan
-
-            def accelerate_scan_fn(gates, inputs):
-                gates = gates.expand_as(inputs)
-                gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
-
-                seq_len = gates.shape[-1]
-                next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
-
-                gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
-                inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
-
-                outputs = scan(gates, inputs)
-
-                outputs = outputs[..., :seq_len]
-                outputs = rearrange(outputs, 'b d n -> b n d')
-                return outputs
-
-            scan_fn = accelerate_scan_fn
-        else:
-            scan_fn = default_associative_scan
-
-        # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
-
-        next_momentum = TensorDict()
-        updates = TensorDict()
-
-        for param_name, surprise in surprises.items():
-
-            surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
-
-            # derive momentum with associative scan - eq (10)
-
-            momentum = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
-
-            # use associative scan again for learned forgetting (weight decay) - eq (13)
-
-            update = scan_fn(1. - decay_factor, momentum) # momentum is S / surprise in the paper
-
-            updates[param_name] = inverse_pack(update)
-            next_momentum[param_name] = inverse_pack(momentum)
-
-        # compute the next weight per batch
-
-        last_update = updates.apply(lambda t: t[:, -1])
-
-        next_state = (curr_weights + last_update, next_momentum)
-
-        return updates, next_state
-
-    def retrieve_memories(
-        self,
-        seq,
-        past_weights: dict[str, Tensor] | None = None,
-    ):
-        chunk_size = self.chunk_size
-        seq_len = seq.shape[1]
-
-        seq = self.retrieve_norm(seq)
-
-        assert seq_len > chunk_size
-
-        seq = seq[:, chunk_size:]
-        curtailed_seq_len = seq.shape[-2]
-
-        next_seq_len = round_up_multiple(curtailed_seq_len + 1, chunk_size)
-
-        padding = next_seq_len - curtailed_seq_len
-
-        seq = pad_at_dim(seq, (0, padding), dim = 1)
-
-        # the parameters of the memory model stores the memories of the key / values
-        # when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
-
-        curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
-
-        if exists(past_weights):
-            past_weights = TensorDict(past_weights)
-            assert past_weights.keys() == curr_weights.keys()
-
-            curr_weights = curr_weights + past_weights
-
-        # sequence Float['b n d'] to queries
-
-        queries = self.to_queries(seq)
-
-        # maybe multihead
-
-        queries = self.split_heads(queries)
-
-        batch = queries.shape[0]
-
-        # fetch values from memory model
-
-        curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
-        queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
-
-        # forward functional call
-
-        values = functional_call(self.memory_model, dict(curr_weights), queries)
-
-        # reconstitute batch dimension
-
-        values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
-
-        # maybe merge heads and combine
-
-        values = self.merge_heads(values)
-
-        values = self.combine_heads(values)
-
-        # post norm, somehow could not stabilize this without it, not in paper
-
-        values = self.post_rmsnorm(values)
-
-        # restore
-
-        values = pad_at_dim(values, (chunk_size, 0), dim = 1, value = 0.) # todo, used a learned null memory embedding instead of 0s for retrieving from empty neural memory
-        values = values[:, :-padding]
-
-        return values
-
-    def forward(
-        self,
-        seq,
-        store_seq = None,
-        past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
-        return_next_memories = False
-    ):
-        batch, seq_len = seq.shape[:2]
-
-        if seq_len <= self.chunk_size:
-            return torch.zeros_like(seq)
-
-        if exists(past_state):
-            past_state = tuple(TensorDict(d) for d in past_state)
-
-        if not exists(past_state):
-            past_state = self.init_weights_and_momentum()
-
-        store_seq = default(store_seq, seq)
-
-        updates, next_memories = self.store_memories(store_seq, past_state)
-
-        past_weights, _ = past_state
-
-        retrieved = self.retrieve_memories(seq, past_weights + updates)
-
-        if not return_next_memories:
-            return retrieved
-
-        return retrieved, next_memories
titans_pytorch-0.0.41/train.py DELETED
@@ -1,152 +0,0 @@
-import random
-import tqdm
-import gzip
-import numpy as np
-
-import torch
-from torch import nn
-from torch.optim import Adam
-from torch.nn import functional as F
-from torch.utils.data import DataLoader, Dataset
-
-from local_attention import LocalTransformer
-
-from taylor_series_linear_attention import TaylorSeriesLinearAttn
-
-from titans_pytorch.titans import (
-    NeuralMemory,
-    MemoryMLP
-)
-
-# constants
-
-NUM_BATCHES = int(1e5)
-BATCH_SIZE = 4
-GRADIENT_ACCUMULATE_EVERY = 4
-LEARNING_RATE = 2e-4
-VALIDATE_EVERY = 100
-GENERATE_EVERY = 500
-GENERATE_LENGTH = 512
-SHOULD_GENERATE = True
-SEQ_LEN = 512
-
-PROJECT_NAME = 'titans-neural-memory'
-WANDB_ONLINE = False # turn this on to pipe experiment to cloud
-GLOBAL_LAYERS = (2, 4)
-USE_TITANS_MEMORY = True
-NEURAL_MEMORY_DEPTH = 2
-WINDOW_SIZE = 64
-RUN_NAME = 'neural memory'
-
-# wandb experiment tracker
-
-import wandb
-wandb.init(project = PROJECT_NAME, mode = 'disabled' if not WANDB_ONLINE else 'online')
-wandb.run.name = RUN_NAME
-wandb.run.save()
-
-# helpers
-
-def cycle(loader):
-    while True:
-        for data in loader:
-            yield data
-
-def decode_token(token):
-    return str(chr(max(32, token)))
-
-def decode_tokens(tokens):
-    return ''.join(list(map(decode_token, tokens)))
-
-# instantiate GPT-like decoder model
-
-titans_neural_memory = NeuralMemory(
-    dim = 384,
-    chunk_size = 4,
-    dim_head = 64,
-    heads = 4,
-    use_accelerated_scan = True,
-    default_mlp_kwargs = dict(
-        depth = NEURAL_MEMORY_DEPTH
-    )
-)
-
-linear_attn = TaylorSeriesLinearAttn(
-    dim = 384,
-    dim_head = 16,
-    heads = 16,
-    causal = True,
-    prenorm = True
-)
-
-model = LocalTransformer(
-    num_tokens = 256,
-    dim = 384,
-    depth = 8,
-    causal = True,
-    local_attn_window_size = WINDOW_SIZE,
-    max_seq_len = SEQ_LEN,
-    global_attn_layer = linear_attn if not USE_TITANS_MEMORY else titans_neural_memory,
-    layers_insert_global_attn = GLOBAL_LAYERS
-).cuda()
-
-# prepare enwik8 data
-
-with gzip.open('./data/enwik8.gz') as file:
-    data = np.frombuffer(file.read(int(95e6)), dtype = np.uint8).copy()
-    data_train, data_val = np.split(data, [int(90e6)])
-    data_train, data_val = map(torch.from_numpy, (data_train, data_val))
-
-class TextSamplerDataset(Dataset):
-    def __init__(self, data, seq_len):
-        super().__init__()
-        self.data = data
-        self.seq_len = seq_len
-
-    def __getitem__(self, index):
-        rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
-        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
-        return full_seq.cuda()
-
-    def __len__(self):
-        return self.data.size(0) // self.seq_len
-
-train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
-val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
-train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
-val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
-
-# optimizer
-
-optim = Adam(model.parameters(), lr=LEARNING_RATE)
-
-# training
-
-for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
-    model.train()
-
-    for __ in range(GRADIENT_ACCUMULATE_EVERY):
-        loss = model(next(train_loader), return_loss = True)
-        loss.backward()
-
-    print(f'training loss: {loss.item()}')
-    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
-    optim.step()
-    optim.zero_grad()
-    wandb.log(dict(loss = loss.item()))
-
-    if i % VALIDATE_EVERY == 0:
-        model.eval()
-        with torch.no_grad():
-            loss = model(next(val_loader), return_loss = True)
-            print(f'validation loss: {loss.item()}')
-
-    if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
-        model.eval()
-        inp = random.choice(val_dataset)[:-1]
-        prime = decode_tokens(inp)
-        print(f'%s \n\n %s', (prime, '*' * 100))
-
-        sample = model.generate(inp[None, ...], GENERATE_LENGTH, use_kv_cache = False)
-        output_str = decode_tokens(sample[0])
-        print(output_str)
The remaining files listed above are unchanged between 0.0.41 and 0.0.43.