titans-pytorch 0.1.2__tar.gz → 0.1.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/PKG-INFO +2 -1
- {titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/pyproject.toml +2 -1
- {titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/tests/test_titans.py +3 -0
- {titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/titans_pytorch/__init__.py +2 -1
- {titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/titans_pytorch/titans.py +85 -5
- {titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/train_mac.py +5 -3
- {titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/.gitignore +0 -0
- {titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/LICENSE +0 -0
- {titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/README.md +0 -0
- {titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/data/README.md +0 -0
- {titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/data/enwik8.gz +0 -0
- {titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/fig1.png +0 -0
- {titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/fig2.png +0 -0
- {titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/titans_pytorch/mac_transformer.py +0 -0

{titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.1.2
+Version: 0.1.6
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -46,6 +46,7 @@ Requires-Dist: torch>=2.2
 Requires-Dist: tqdm
 Requires-Dist: x-transformers
 Provides-Extra: examples
+Requires-Dist: adam-atan2-pytorch>=0.1.18; extra == 'examples'
 Requires-Dist: wandb; extra == 'examples'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'

{titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.2"
+version = "0.1.6"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -45,6 +45,7 @@ Repository = "https://github.com/lucidrains/titans-pytorch"
 [project.optional-dependencies]

 examples = [
+    "adam-atan2-pytorch>=0.1.18",
     "wandb"
 ]

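
Both metadata files record the same change: the `examples` extra now pulls in `adam-atan2-pytorch>=0.1.18` alongside `wandb`, matching the optimizer swap in `train_mac.py` below. As a quick sanity check (a minimal sketch using only the standard library, not part of the package), the declared requirements of an installed copy can be listed like this:

```python
# Minimal sketch: print the requirements recorded for an installed
# titans-pytorch, including conditional ones such as extra == 'examples'.
# Assumes titans-pytorch 0.1.6 is installed in the current environment.
from importlib.metadata import requires

for requirement in requires('titans-pytorch') or []:
    print(requirement)

# expected to include a line along the lines of:
#   adam-atan2-pytorch>=0.1.18; extra == 'examples'
```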

{titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/tests/test_titans.py

@@ -11,12 +11,14 @@ def exists(v):
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
 @pytest.mark.parametrize('silu', (False, True))
 @pytest.mark.parametrize('learned_mem_model_weights', (False, True))
+@pytest.mark.parametrize('attn_pool_chunks', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
 @pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
 def test_titans(
     seq_len,
     silu,
     learned_mem_model_weights,
+    attn_pool_chunks,
     max_grad_norm,
     per_parameter_lr_modulation
 ):
@@ -24,6 +26,7 @@ def test_titans(
         dim = 384,
         chunk_size = 64,
         activation = nn.SiLU() if silu else None,
+        attn_pool_chunks = attn_pool_chunks,
         max_grad_norm = max_grad_norm,
         per_parameter_lr_modulation = per_parameter_lr_modulation,
         learned_mem_model_weights = learned_mem_model_weights
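
The new `attn_pool_chunks` flag is added to the test as one more stacked `parametrize` decorator, so every existing combination now runs once with mean pooling and once with attention pooling of chunks. A standalone illustration of how stacked decorators multiply the matrix (hypothetical test, not part of the package's suite):

```python
# Illustration only: stacked parametrize decorators generate the Cartesian
# product of their values, so each added boolean flag doubles the test matrix.
import pytest

@pytest.mark.parametrize('attn_pool_chunks', (False, True))
@pytest.mark.parametrize('max_grad_norm', (None, 2.))
def test_cross_product(attn_pool_chunks, max_grad_norm):
    # pytest collects 2 x 2 = 4 cases from these two decorators
    assert attn_pool_chunks in (False, True)
```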

{titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/titans_pytorch/titans.py

@@ -1,10 +1,11 @@
 from __future__ import annotations
 from typing import Callable
+
 import math
 from functools import partial

 import torch
-from torch import nn, Tensor
+from torch import nn, cat, Tensor
 import torch.nn.functional as F
 from torch.nn import Linear, Module, Parameter, ParameterList
 from torch.func import functional_call, vmap, grad
@@ -18,7 +19,7 @@ from titans_pytorch.associative_scan import (
 )

 import einx
-from einops import rearrange, repeat, pack, unpack
+from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce

 """
@@ -95,6 +96,37 @@ class MultiheadRMSNorm(Module):
     def forward(self, x):
         return self.rmsnorm(x) * (self.gamma + 1.)

+# attention pool
+
+class AttentionPool(Module):
+    def __init__(
+        self,
+        dim,
+        chunk_size
+    ):
+        """
+        taken from Enformer https://www.nature.com/articles/s41592-021-01252-x , in turn taken from somewhere else
+        """
+        super().__init__()
+        self.split_chunks = Rearrange('b (n c) d -> b n c d', c = chunk_size)
+        self.to_attn_logits = nn.Linear(dim, dim)
+
+        # default to average pool
+
+        nn.init.zeros_(self.to_attn_logits.weight)
+        nn.init.zeros_(self.to_attn_logits.bias)
+
+    def forward(
+        self,
+        x
+    ):
+        x = self.split_chunks(x)
+        attn_logits = self.to_attn_logits(x)
+
+        attn = attn_logits.softmax(dim = -2)
+
+        return reduce(x * attn, 'b n c d -> b n d', 'sum')
+
 # classes

 class MemoryMLP(Module):
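
Because `to_attn_logits` is zero-initialized, the softmax over each chunk starts out uniform, so `AttentionPool` behaves exactly like average pooling until training moves those weights. A self-contained sketch of the pooling math (illustrative tensors only, not the package's API):

```python
# Sketch of the AttentionPool computation: zero logits -> uniform softmax over
# the chunk dimension -> the weighted sum collapses to a plain mean pool.
import torch
from einops import rearrange, reduce

batch, num_chunks, chunk_size, dim = 2, 4, 8, 16
x = torch.randn(batch, num_chunks * chunk_size, dim)

chunks = rearrange(x, 'b (n c) d -> b n c d', c = chunk_size)

attn_logits = torch.zeros_like(chunks)   # what a zero-initialized Linear produces
attn = attn_logits.softmax(dim = -2)     # uniform 1 / chunk_size weights per chunk

pooled = reduce(chunks * attn, 'b n c d -> b n d', 'sum')
mean_pooled = reduce(chunks, 'b n c d -> b n d', 'mean')

assert torch.allclose(pooled, mean_pooled, atol = 1e-6)
```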

@@ -123,6 +155,46 @@ class MemoryMLP(Module):

         return x

+# memory mlp, but with gated residual + final projection
+
+class GatedResidualMemoryMLP(Module):
+    def __init__(
+        self,
+        dim,
+        depth
+    ):
+        super().__init__()
+        self.depth = depth
+
+        self.weights = ParameterList([
+            ParameterList([
+                Parameter(torch.randn(dim, dim)),
+                Parameter(torch.randn(dim * 2, dim)),
+            ]) for _ in range(depth)
+        ])
+
+        self.final_proj = Parameter(torch.randn(dim, dim))
+
+        for param in self.parameters():
+            nn.init.xavier_uniform_(param)
+
+    def forward(
+        self,
+        x
+    ):
+        for weight, to_gates in self.weights:
+            res = x
+
+            x = x @ weight
+            x = F.silu(x)
+
+            # gated residual
+
+            gates = cat((x, res), dim = -1) @ to_gates
+            x = res.lerp(x, gates.sigmoid())
+
+        return x @ self.final_proj
+
 # memory mlp with factorized weights
 # so can tradeoff capacity for smaller chunk sizes

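
The gated residual inside `GatedResidualMemoryMLP` relies on `Tensor.lerp`, which interpolates between the residual and the transformed branch under a sigmoid gate. A small numeric check of that identity (illustrative tensors only):

```python
# res.lerp(x, g) computes res + g * (x - res), i.e. the familiar
# highway-style update g * x + (1 - g) * res.
import torch

res   = torch.randn(3, 16)
x     = torch.randn(3, 16)
gates = torch.randn(3, 16).sigmoid()   # gate values in (0, 1)

assert torch.allclose(
    res.lerp(x, gates),
    gates * x + (1. - gates) * res,
    atol = 1e-6
)
```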

@@ -224,6 +296,7 @@ class NeuralMemory(Module):
         default_step_transform_max_lr = 1e-2,
         per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
         max_mem_layer_modulation = 1e1, # max of 10.
+        attn_pool_chunks = False,
         pre_rmsnorm = True,
         post_rmsnorm = True,
         learned_mem_model_weights = True,
@@ -304,10 +377,17 @@ class NeuralMemory(Module):
         self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
         nn.init.normal_(self.empty_memory_embed, std = 0.02)

+        # whether to use averaging of chunks, or attention pooling
+
+        if not attn_pool_chunks:
+            chunk_reduce_module = Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size)
+        else:
+            chunk_reduce_module = AttentionPool(dim, chunk_size = chunk_size)
+
         # learned adaptive learning rate and momentum

         self.to_momentum = Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            chunk_reduce_module,
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
@@ -325,7 +405,7 @@ class NeuralMemory(Module):
         # per layer learning rate modulation

         self.to_layer_modulation = Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            chunk_reduce_module,
             LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
             Rearrange('b n (h w) -> w (b h) n', h = heads),
             nn.Sigmoid()
@@ -340,7 +420,7 @@ class NeuralMemory(Module):
         # weight decay factor

         self.to_decay_factor = Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
+            chunk_reduce_module,
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n 1')
         )
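
With the flag plumbed through the constructor, the same `chunk_reduce_module` feeds the momentum, per-layer learning-rate and decay-factor projections, so switching from chunk averaging to attention pooling is a single keyword argument. A hedged usage sketch mirroring the test configuration (the forward call and its return shape are assumptions based on the test and README, not shown in this diff):

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    attn_pool_chunks = True   # attention-pool chunks instead of averaging them
)

seq = torch.randn(2, 1024, 384)

# assumption: the module consumes a (batch, seq, dim) tensor and returns
# retrieved memories of the same shape, as in the package's README
retrieved = mem(seq)
assert retrieved.shape == seq.shape
```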

{titans_pytorch-0.1.2 → titans_pytorch-0.1.6}/train_mac.py

@@ -5,10 +5,10 @@ import numpy as np

 import torch
 from torch import nn, Tensor
-from torch.optim import Adam
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset

+from adam_atan2_pytorch import AdoptAtan2
 from titans_pytorch import MemoryAsContextTransformer

 # constants
@@ -34,6 +34,7 @@ NEURAL_MEM_GATE_ATTN_OUTPUT = True
 WINDOW_SIZE = 32
 NEURAL_MEM_SEGMENT_LEN = WINDOW_SIZE // 2 # set smaller for more granularity for learning rate / momentum etc
 SLIDING_WINDOWS = True
+STORE_ATTN_POOL_CHUNKS = True # whether to use attention pooling for chunk derived momentum, per-layer lr mod, decay
 KV_RECON_LOSS_WEIGHT = 0.
 LEARNED_MEM_MODEL_WEIGHTS = True

@@ -86,6 +87,7 @@ model = MemoryAsContextTransformer(
     neural_memory_kwargs = dict(
         dim_head = 64,
         heads = 4,
+        attn_pool_chunks = STORE_ATTN_POOL_CHUNKS,
         use_accelerated_scan = USE_ACCELERATED_SCAN,
         learned_mem_model_weights = LEARNED_MEM_MODEL_WEIGHTS,
         default_model_kwargs = dict(
@@ -122,11 +124,11 @@ val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

 # optimizer

-optim = Adam(model.parameters(), lr = LEARNING_RATE)
+optim = AdoptAtan2(model.parameters(), lr = LEARNING_RATE)

 # training

-for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
+for i in tqdm.tqdm(range(NUM_BATCHES), mininterval = 10., desc = 'training'):
     model.train()

     for __ in range(GRADIENT_ACCUMULATE_EVERY):
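
The training script drops `torch.optim.Adam` in favour of `AdoptAtan2` from the newly pinned `adam-atan2-pytorch` dependency, using the same `parameters()` / `lr` call pattern. A minimal sketch of the swap (assumes `adam-atan2-pytorch>=0.1.18` is installed; the toy linear model stands in for `MemoryAsContextTransformer`):

```python
# Minimal sketch of the optimizer swap; the call pattern mirrors train_mac.py.
import torch
from torch import nn
from adam_atan2_pytorch import AdoptAtan2

model = nn.Linear(256, 256)                        # stand-in for the transformer
optim = AdoptAtan2(model.parameters(), lr = 2e-4)  # drop-in for the removed Adam line

loss = model(torch.randn(4, 256)).sum()
loss.backward()

optim.step()
optim.zero_grad()
```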