titans-pytorch 0.0.7.tar.gz → 0.0.9.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.7 → titans_pytorch-0.0.9}/PKG-INFO +22 -9
- {titans_pytorch-0.0.7 → titans_pytorch-0.0.9}/README.md +19 -7
- {titans_pytorch-0.0.7 → titans_pytorch-0.0.9}/pyproject.toml +3 -2
- {titans_pytorch-0.0.7 → titans_pytorch-0.0.9}/titans_pytorch/titans.py +13 -3
- {titans_pytorch-0.0.7 → titans_pytorch-0.0.9}/train.py +27 -3
- {titans_pytorch-0.0.7 → titans_pytorch-0.0.9}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.7 → titans_pytorch-0.0.9}/.gitignore +0 -0
- {titans_pytorch-0.0.7 → titans_pytorch-0.0.9}/LICENSE +0 -0
- {titans_pytorch-0.0.7 → titans_pytorch-0.0.9}/data/README.md +0 -0
- {titans_pytorch-0.0.7 → titans_pytorch-0.0.9}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.7 → titans_pytorch-0.0.9}/fig1.png +0 -0
- {titans_pytorch-0.0.7 → titans_pytorch-0.0.9}/fig2.png +0 -0
- {titans_pytorch-0.0.7 → titans_pytorch-0.0.9}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.7 → titans_pytorch-0.0.9}/titans_pytorch/associative_scan.py +0 -0
````diff
--- titans_pytorch-0.0.7/PKG-INFO
+++ titans_pytorch-0.0.9/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.7
+Version: 0.0.9
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -39,7 +39,8 @@ Requires-Dist: einx>=0.3.0
 Requires-Dist: tensordict>=0.6.2
 Requires-Dist: torch>=2.3
 Provides-Extra: examples
-Requires-Dist: local-attention>=1.
+Requires-Dist: local-attention>=1.10.0; extra == 'examples'
+Requires-Dist: taylor-series-linear-attention; extra == 'examples'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
 Description-Content-Type: text/markdown
@@ -64,16 +65,28 @@ $ pip install titans-pytorch
 import torch
 from titans_pytorch import NeuralMemory
 
-x = torch.randn(2, 64, 32)
-
 mem = NeuralMemory(
-    dim =
-    chunk_size =
-
+    dim = 384,
+    chunk_size = 64,
+    pre_rmsnorm = True
+).cuda()
+
+seq = torch.randn(2, 1024, 384).cuda()
+retrieved = mem(seq)
 
-
+assert seq.shape == retrieved.shape
+```
 
-
+## Experiments
+
+```bash
+$ pip install .[examples]
+```
+
+Then
+
+```bash
+$ python train.py
 ```
 
 ## Citations
````
````diff
--- titans_pytorch-0.0.7/README.md
+++ titans_pytorch-0.0.9/README.md
@@ -18,16 +18,28 @@ $ pip install titans-pytorch
 import torch
 from titans_pytorch import NeuralMemory
 
-x = torch.randn(2, 64, 32)
-
 mem = NeuralMemory(
-    dim =
-    chunk_size =
-
+    dim = 384,
+    chunk_size = 64,
+    pre_rmsnorm = True
+).cuda()
+
+seq = torch.randn(2, 1024, 384).cuda()
+retrieved = mem(seq)
 
-
+assert seq.shape == retrieved.shape
+```
 
-
+## Experiments
+
+```bash
+$ pip install .[examples]
+```
+
+Then
+
+```bash
+$ python train.py
 ```
 
 ## Citations
````
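For readers who just want to smoke-test the new API without a GPU, the same README example reads as the sketch below. This is an illustrative variant only (the published example uses `.cuda()`), and it assumes the pure-PyTorch associative scan used internally also runs on CPU.

```python
# Illustrative CPU variant of the 0.0.9 README example (assumption: the
# internal pure-PyTorch associative scan does not require CUDA).
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,           # feature dimension of the incoming sequence
    chunk_size = 64,     # chunked processing size; shorter sequences return zeros
    pre_rmsnorm = True   # new flag: presumably RMSNorms the input before the memory ops
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)

assert seq.shape == retrieved.shape   # retrieval preserves the input shape
```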
```diff
--- titans_pytorch-0.0.7/pyproject.toml
+++ titans_pytorch-0.0.9/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.7"
+version = "0.0.9"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -37,7 +37,8 @@ Repository = "https://github.com/lucidrains/titans-pytorch"
 
 [project.optional-dependencies]
 examples = [
-    "local-attention>=1.
+    "local-attention>=1.10.0",
+    "taylor-series-linear-attention"
 ]
 test = [
     "pytest"
```
```diff
--- titans_pytorch-0.0.7/titans_pytorch/titans.py
+++ titans_pytorch-0.0.9/titans_pytorch/titans.py
@@ -55,6 +55,10 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+def softclamp_max(t, max_value):
+    range_value = max_value / 2
+    return ((t / range_value).tanh() * range_value) + range_value
+
 # classes
 
 class MLP(Module):
```
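The new `softclamp_max` helper smoothly bounds a tensor: with `range_value = max_value / 2`, `tanh` keeps its argument in (-1, 1), so the result lies in the open interval (0, max_value) while remaining differentiable. It is used further down to cap the learned per-token step size. A quick sketch of the bound (illustrative only, not part of the package diff):

```python
# Sketch only: checks the range of the softclamp_max helper added in 0.0.9.
import torch

def softclamp_max(t, max_value):
    # identical to the function added in titans.py
    range_value = max_value / 2
    return ((t / range_value).tanh() * range_value) + range_value

t = torch.linspace(-1e-4, 1e-4, steps = 9)
clamped = softclamp_max(t, max_value = 1e-5)

# mathematically the output lies in (0, max_value); in floating point the
# tanh may saturate exactly at the endpoints for large |t|
assert ((clamped >= 0.) & (clamped <= 1e-5)).all()
```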
```diff
--- titans_pytorch-0.0.7/titans_pytorch/titans.py
+++ titans_pytorch-0.0.9/titans_pytorch/titans.py
@@ -92,7 +96,8 @@ class NeuralMemory(Module):
         chunk_size = 1,
         model: Module | None = None,
         store_memory_loss_fn: Callable = default_loss_fn,
-        pre_rmsnorm = False
+        pre_rmsnorm = False,
+        max_adaptive_step_size = 1e-5
     ):
         super().__init__()
 
@@ -144,6 +149,8 @@ class NeuralMemory(Module):
             Rearrange('... 1 -> ...')
         )
 
+        self.max_adaptive_step_size = max_adaptive_step_size
+
         # weight decay factor
 
         self.to_decay_factor = nn.Sequential(
@@ -188,7 +195,7 @@ class NeuralMemory(Module):
 
         batch = seq.shape[0]
 
-        adaptive_lr = self.to_adaptive_step(seq).
+        adaptive_lr = softclamp_max(self.to_adaptive_step(seq), self.max_adaptive_step_size)
 
         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()
@@ -304,7 +311,10 @@
         past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
         return_next_memories = False
     ):
-        batch = seq.shape[
+        batch, seq_len = seq.shape[:2]
+
+        if seq_len < self.chunk_size:
+            return torch.zeros_like(seq)
 
         if exists(past_state):
             past_state = tuple(TensorDict(d) for d in past_state)
```
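The final hunk above changes how very short inputs are handled: when the sequence is shorter than `chunk_size`, `forward` now returns zeros of the same shape rather than attempting storage and retrieval. A small illustrative sketch of that behavior from the caller's side (not part of the diff; the early return happens before any heavy computation, so it can be checked on CPU):

```python
# Illustrative sketch: the 0.0.9 early-return path for sub-chunk sequences.
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,     # sequences shorter than this are not memorized
    pre_rmsnorm = True
)

short_seq = torch.randn(2, 16, 384)   # 16 tokens < chunk_size of 64
retrieved = mem(short_seq)

# nothing stored or retrieved; the output is all zeros with the input's shape
assert retrieved.shape == short_seq.shape
assert (retrieved == 0).all()
```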
```diff
--- titans_pytorch-0.0.7/train.py
+++ titans_pytorch-0.0.9/train.py
@@ -4,12 +4,15 @@ import gzip
 import numpy as np
 
 import torch
+from torch import nn
 from torch.optim import Adam
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset
 
 from local_attention import LocalTransformer
 
+from taylor_series_linear_attention import TaylorSeriesLinearAttn
+
 from titans_pytorch.titans import NeuralMemory
 
 # constants
@@ -21,6 +24,7 @@ LEARNING_RATE = 2e-4
 VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
 GENERATE_LENGTH = 512
+SHOULD_GENERATE = False
 SEQ_LEN = 512
 
 # helpers
@@ -38,13 +42,33 @@ def decode_tokens(tokens):
 
 # instantiate GPT-like decoder model
 
+titans_neural_memory = NeuralMemory(
+    dim = 384,
+    chunk_size = 64,
+    pre_rmsnorm = True
+)
+
+titans_neural_memory = nn.Sequential(
+    titans_neural_memory,
+    nn.RMSNorm(384)
+)
+
+linear_attn = TaylorSeriesLinearAttn(
+    dim = 384,
+    dim_head = 16,
+    heads = 16,
+    causal = True
+)
+
 model = LocalTransformer(
     num_tokens = 256,
-    dim =
+    dim = 384,
     depth = 8,
     causal = True,
     local_attn_window_size = 64,
-    max_seq_len = SEQ_LEN
+    max_seq_len = SEQ_LEN,
+    global_attn_layer = titans_neural_memory,
+    layers_insert_global_attn = (4,)
 ).cuda()
 
 # prepare enwik8 data
```
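In the updated `train.py`, the neural memory, post-normed with `nn.RMSNorm`, is handed to `LocalTransformer` as a global layer inserted at depth 4 (`layers_insert_global_attn = (4,)`). How local-attention wires it in internally is not shown here; as a rough mental model only, the inserted module acts like a standalone sequence-to-sequence layer that can be applied residually to the token stream, sketched below (an assumption, not local-attention's actual code; like train.py it assumes a CUDA device).

```python
# Rough mental model only (assumption): a global memory layer applied
# residually to the token stream, mirroring the module built in train.py.
import torch
from torch import nn
from titans_pytorch import NeuralMemory

dim = 384

global_memory = nn.Sequential(
    NeuralMemory(dim = dim, chunk_size = 64, pre_rmsnorm = True),
    nn.RMSNorm(dim)   # normalize the retrieved memories, as in train.py
).cuda()

tokens = torch.randn(1, 512, dim).cuda()   # embedded token sequence

# retrieved memories added back onto the residual stream
tokens = tokens + global_memory(tokens)
print(tokens.shape)   # torch.Size([1, 512, 384])
```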
```diff
--- titans_pytorch-0.0.7/train.py
+++ titans_pytorch-0.0.9/train.py
@@ -97,7 +121,7 @@ for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
             loss = model(next(val_loader), return_loss = True)
             print(f'validation loss: {loss.item()}')
 
-    if i % GENERATE_EVERY == 0:
+    if SHOULD_GENERATE and i % GENERATE_EVERY == 0:
         model.eval()
         inp = random.choice(val_dataset)[:-1]
         prime = decode_tokens(inp)
```