titans-pytorch 0.0.20__tar.gz → 0.0.22__tar.gz
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/PKG-INFO +1 -1
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/pyproject.toml +1 -1
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/tests/test_titans.py +4 -1
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/titans_pytorch/titans.py +11 -1
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/train.py +4 -3
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/.gitignore +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/LICENSE +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/README.md +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/data/README.md +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/fig1.png +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/fig2.png +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/requirements.txt +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/titans_pytorch/titans_attn_memory.py +0 -0
{titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/tests/test_titans.py

@@ -2,8 +2,10 @@ import torch
 import pytest
 
 @pytest.mark.parametrize('seq_len', (32, 1024, 77))
+@pytest.mark.parametrize('max_grad_norm', (None, 2.))
 def test_titans(
-    seq_len
+    seq_len,
+    max_grad_norm
 ):
 
     from titans_pytorch import NeuralMemory
@@ -11,6 +13,7 @@ def test_titans(
     mem = NeuralMemory(
         dim = 384,
         chunk_size = 64,
+        max_grad_norm = max_grad_norm
     )
 
     seq = torch.randn(2, seq_len, 384)
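The test changes above exercise the new `max_grad_norm` argument end to end. A minimal usage sketch distilled from the parametrized test (the forward call and its output shape are assumptions based on how the test builds `seq`; they are not shown in this diff):

```python
import torch
from titans_pytorch import NeuralMemory

# max_grad_norm = 2. softclamps the gradient norms used when storing
# memories; max_grad_norm = None leaves them unclamped
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    max_grad_norm = 2.
)

seq = torch.randn(2, 1024, 384)  # (batch, seq_len, dim)
retrieved = mem(seq)             # assumed to match seq's shape
```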
{titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/titans_pytorch/titans.py

@@ -40,6 +40,9 @@ def exists(v):
 def default(v, d):
     return v if exists(v) else d
 
+def identity(t):
+    return t
+
 def round_down_multiple(seq, mult):
     return seq // mult * mult
 
@@ -97,6 +100,9 @@ class MemoryMLP(Module):
 
 # main neural memory
 
+def default_adaptive_step_transform(adaptive_step, max_lr = 1e-1):
+    return adaptive_step.sigmoid() * max_lr
+
 def default_loss_fn(pred, target):
     return (pred - target).pow(2).mean(dim = -1).sum()
 
@@ -109,6 +115,7 @@ class NeuralMemory(Module):
         heads = 1,
         model: Module | None = None,
         store_memory_loss_fn: Callable = default_loss_fn,
+        adaptive_step_transform: Callable = default_adaptive_step_transform,
         pre_rmsnorm = True,
         post_rmsnorm = True,
         max_grad_norm: float | None = None,
@@ -188,6 +195,8 @@ class NeuralMemory(Module):
             Rearrange('b n h -> (b h) n')
         )
 
+        self.adaptive_step_transform = adaptive_step_transform
+
         # allow for softclamp the gradient norms for storing memories
 
         self.max_grad_norm = max_grad_norm
@@ -242,7 +251,8 @@ class NeuralMemory(Module):
 
         # pack batch and sequence dimension
 
-        adaptive_lr =
+        adaptive_lr = self.to_adaptive_step(seq)
+        adaptive_lr = self.adaptive_step_transform(adaptive_lr)
 
         adaptive_momentum = self.to_momentum(seq).sigmoid()
         decay_factor = self.to_decay_factor(seq).sigmoid()
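The adaptive step refactor above separates the raw per-token step (from `to_adaptive_step`) from the squashing function, which is now the pluggable `adaptive_step_transform`. The default squashes with a sigmoid scaled by `max_lr = 1e-1`, bounding the learned inner-loop learning rate to (0, 0.1). A quick check of that range (assuming the function is importable from `titans_pytorch.titans`, where the diff defines it at module level):

```python
import torch
from titans_pytorch.titans import default_adaptive_step_transform

raw = torch.tensor([-4., 0., 4.])
print(default_adaptive_step_transform(raw))
# ~tensor([0.0018, 0.0500, 0.0982]); sigmoid squashes to (0, 1), * 1e-1 caps at 0.1
```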
{titans_pytorch-0.0.20 → titans_pytorch-0.0.22}/train.py

@@ -62,11 +62,12 @@ def decode_tokens(tokens):
 
 titans_neural_memory = NeuralMemory(
     dim = 384,
-    chunk_size =
+    chunk_size = 4,
     pre_rmsnorm = True,
     post_rmsnorm = True,
-    dim_head =
-    heads =
+    dim_head = 64,
+    heads = 4,
+    max_grad_norm = 1.,
     use_accelerated_scan = True,
     default_mlp_kwargs = dict(
         depth = NEURAL_MEMORY_DEPTH
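With `adaptive_step_transform` exposed on the constructor, the squashing can also be swapped out. A hedged sketch using the hyperparameters from the train.py hunk above, with a hypothetical softplus-based transform (the transform function and its 1e-2 scale are illustrative, not part of the library):

```python
import torch
import torch.nn.functional as F
from titans_pytorch import NeuralMemory

# hypothetical transform: softplus keeps the step positive but unbounded,
# unlike the default sigmoid * max_lr; the 1e-2 scale is illustrative
def softplus_step_transform(adaptive_step):
    return F.softplus(adaptive_step) * 1e-2

mem = NeuralMemory(
    dim = 384,
    chunk_size = 4,
    dim_head = 64,
    heads = 4,
    max_grad_norm = 1.,  # softclamp stored-memory gradient norms
    adaptive_step_transform = softplus_step_transform
)
```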