titans-pytorch 0.0.21__tar.gz → 0.0.23__tar.gz
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.23}/PKG-INFO +1 -1
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.23}/pyproject.toml +1 -1
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.23}/tests/test_titans.py +4 -1
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.23}/titans_pytorch/titans.py +23 -10
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.23}/train.py +2 -1
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.23}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.23}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.23}/.gitignore +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.23}/LICENSE +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.23}/README.md +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.23}/data/README.md +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.23}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.23}/fig1.png +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.23}/fig2.png +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.23}/requirements.txt +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.23}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.23}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.21 → titans_pytorch-0.0.23}/titans_pytorch/titans_attn_memory.py +0 -0
tests/test_titans.py

@@ -2,8 +2,10 @@ import torch
 import pytest
 
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
+@pytest.mark.parametrize('max_grad_norm', (None, 2.))
 def test_titans(
-    seq_len
+    seq_len,
+    max_grad_norm
 ):
 
     from titans_pytorch import NeuralMemory
@@ -11,6 +13,7 @@ def test_titans(
     mem = NeuralMemory(
         dim = 384,
         chunk_size = 64,
+        max_grad_norm = max_grad_norm
     )
 
     seq = torch.randn(2, seq_len, 384)
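The test now sweeps the new max_grad_norm argument alongside seq_len. A minimal sketch of the same usage outside of pytest, assuming only the constructor arguments visible in this diff (dim, chunk_size, max_grad_norm):

import torch
from titans_pytorch import NeuralMemory

# max_grad_norm = None leaves the surprise gradients unclipped; a float
# presumably caps their norm (the name suggests gradient-norm clipping
# of the inner-loop memory updates)
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    max_grad_norm = 2.
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)
print(retrieved.shape)  # expected to match seq.shape, i.e. torch.Size([2, 1024, 384])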
titans_pytorch/titans.py

@@ -100,11 +100,11 @@ class MemoryMLP(Module):
 
 # main neural memory
 
-def default_adaptive_step_transform(adaptive_step):
-    return
+def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
+    return adaptive_step.sigmoid() * max_lr
 
 def default_loss_fn(pred, target):
-    return (pred - target).pow(2).mean(dim = -1)
+    return (pred - target).pow(2).mean(dim = -1)
 
 class NeuralMemory(Module):
     def __init__(
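The rewritten default_adaptive_step_transform bounds the data-dependent inner-loop step size: the raw projection is squashed with a sigmoid and scaled by max_lr, so no token can request a learning rate above 1e-2 by default. A standalone sketch of just that function:

import torch

def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
    # squash into (0, 1), then scale so the per-token learning rate never exceeds max_lr
    return adaptive_step.sigmoid() * max_lr

raw = torch.tensor([-4., 0., 4.])             # unbounded projections from to_adaptive_step
print(default_adaptive_step_transform(raw))   # ~ tensor([0.0002, 0.0050, 0.0098])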
@@ -142,6 +142,12 @@ class NeuralMemory(Module):
         self.merge_heads = Rearrange('(b h) n d -> b n (h d)', h = heads)
         self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
 
+        self.retrieve_gate = nn.Sequential(
+            LinearNoBias(dim, heads),
+            Rearrange('b n h -> (b h) n 1'),
+            nn.Sigmoid()
+        ) if heads > 1 else None
+
         # memory mlp
 
         if not exists(model):
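The added retrieve_gate produces one sigmoid-squashed scalar per head and per token from the input sequence, which is later multiplied onto the retrieved values (see the retrieval hunk further below). A self-contained sketch of those shapes; dim, heads and dim_head are illustrative, and nn.Linear(..., bias = False) stands in for the library's LinearNoBias:

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, heads, dim_head = 384, 4, 64
b, n = 2, 128

# mirrors the module added above: one gate value per head per token
retrieve_gate = nn.Sequential(
    nn.Linear(dim, heads, bias = False),
    Rearrange('b n h -> (b h) n 1'),
    nn.Sigmoid()
)

seq = torch.randn(b, n, dim)                   # input sequence
values = torch.randn(b * heads, n, dim_head)   # retrieved values, heads folded into batch

gated = values * retrieve_gate(seq)            # broadcasts over the feature dimension
print(gated.shape)                             # torch.Size([8, 128, 64])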
@@ -159,12 +165,13 @@ class NeuralMemory(Module):
 
         # prepare function for per sample gradients from model above, using torch.func
 
-        def forward_and_loss(params, inputs, target):
+        def forward_and_loss(params, inputs, loss_weights, target):
             pred = functional_call(self.memory_model, params, inputs)
             loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
-
+            loss = loss * loss_weights
+            return loss.sum()
 
-        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0))
+        self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))
 
         # queries for retrieving from the model
 
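forward_and_loss now takes the adaptive learning rates as per-token loss weights, so the per-chunk gradients returned by vmap(grad(...)) are already scaled before they become surprises. A minimal sketch of the same torch.func pattern on a toy linear memory; the model and shapes are stand-ins, not the library's memory MLP:

import torch
from torch import nn
from torch.func import functional_call, vmap, grad

memory_model = nn.Linear(16, 16, bias = False)    # toy stand-in for the memory model

def forward_and_loss(params, inputs, loss_weights, target):
    pred = functional_call(memory_model, params, inputs)
    loss = (pred - target).pow(2).mean(dim = -1)  # per-token loss, as in default_loss_fn
    loss = loss * loss_weights                    # adaptive lr folded in as a loss weight
    return loss.sum()

# params are shared across chunks (None); keys, lrs and values are batched (0)
per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))

chunks, chunk_size = 6, 8
keys        = torch.randn(chunks, chunk_size, 16)
values      = torch.randn(chunks, chunk_size, 16)
adaptive_lr = torch.rand(chunks, chunk_size) * 1e-2

params = dict(memory_model.named_parameters())
grads = per_sample_grad_fn(params, keys, adaptive_lr, values)
print(grads['weight'].shape)   # torch.Size([6, 16, 16]) - one gradient per chunk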
@@ -190,7 +197,6 @@ class NeuralMemory(Module):
         )
 
         self.to_adaptive_step = nn.Sequential(
-            Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
             LinearNoBias(dim, heads),
             Rearrange('b n h -> (b h) n')
         )
@@ -271,9 +277,11 @@ class NeuralMemory(Module):
 
         keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
 
+        adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = self.chunk_size)
+
         # get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
 
-        grads = self.per_sample_grad_fn(dict(curr_weights), keys, values)
+        grads = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
 
         grads = TensorDict(grads)
 
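With the Reduce removed from to_adaptive_step (previous hunk), the adaptive learning rate is kept per token and only regrouped into chunks right before the gradient call. A small einops sketch of that reshape; batch size, sequence length and chunk size are illustrative:

import torch
from einops import rearrange

batch_heads, seq_len, chunk_size = 8, 512, 64   # heads already folded into the batch dim

adaptive_lr = torch.rand(batch_heads, seq_len)  # one learning rate per token
adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)

print(adaptive_lr.shape)   # torch.Size([64, 64]) - (batch_heads * num_chunks, chunk_size)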
@@ -286,9 +294,9 @@ class NeuralMemory(Module):
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
 
-        #
+        # negative gradients, adaptive lr already applied as loss weight
 
-        surprises = grads.apply(lambda t:
+        surprises = grads.apply(lambda t: -t)
 
         # determine scan function
 
@@ -405,6 +413,11 @@ class NeuralMemory(Module):
 
         values = rearrange(values, '(b n) c d -> b (n c) d', b = batch)
 
+        # maybe gate
+
+        if exists(self.retrieve_gate):
+            values = values * self.retrieve_gate(seq)
+
         # maybe merge heads and combine
 
         values = self.merge_heads(values)
train.py

@@ -27,7 +27,7 @@ LEARNING_RATE = 2e-4
 VALIDATE_EVERY = 100
 GENERATE_EVERY = 500
 GENERATE_LENGTH = 512
-SHOULD_GENERATE =
+SHOULD_GENERATE = True
 SEQ_LEN = 512
 
 PROJECT_NAME = 'titans-neural-memory'
@@ -67,6 +67,7 @@ titans_neural_memory = NeuralMemory(
     post_rmsnorm = True,
     dim_head = 64,
     heads = 4,
+    max_grad_norm = 1.,
     use_accelerated_scan = True,
     default_mlp_kwargs = dict(
         depth = NEURAL_MEMORY_DEPTH
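In train.py the memory is now constructed with gradient clipping enabled. A sketch of the resulting call using the surrounding arguments from this hunk; dim, chunk_size and the NEURAL_MEMORY_DEPTH value are assumptions, since they are not shown in the diff:

from titans_pytorch import NeuralMemory

NEURAL_MEMORY_DEPTH = 2             # assumed value, only the name appears in the diff

titans_neural_memory = NeuralMemory(
    dim = 384,                      # assumed, not shown in this hunk
    chunk_size = 64,                # assumed, not shown in this hunk
    post_rmsnorm = True,
    dim_head = 64,
    heads = 4,
    max_grad_norm = 1.,             # new in 0.0.23
    use_accelerated_scan = True,
    default_mlp_kwargs = dict(
        depth = NEURAL_MEMORY_DEPTH
    )
)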