titans-pytorch 0.0.18__tar.gz → 0.0.20__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.18 → titans_pytorch-0.0.20}/PKG-INFO +1 -1
- {titans_pytorch-0.0.18 → titans_pytorch-0.0.20}/pyproject.toml +1 -1
- {titans_pytorch-0.0.18 → titans_pytorch-0.0.20}/tests/test_titans.py +6 -2
- {titans_pytorch-0.0.18 → titans_pytorch-0.0.20}/titans_pytorch/titans.py +39 -5
- {titans_pytorch-0.0.18 → titans_pytorch-0.0.20}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.18 → titans_pytorch-0.0.20}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.18 → titans_pytorch-0.0.20}/.gitignore +0 -0
- {titans_pytorch-0.0.18 → titans_pytorch-0.0.20}/LICENSE +0 -0
- {titans_pytorch-0.0.18 → titans_pytorch-0.0.20}/README.md +0 -0
- {titans_pytorch-0.0.18 → titans_pytorch-0.0.20}/data/README.md +0 -0
- {titans_pytorch-0.0.18 → titans_pytorch-0.0.20}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.18 → titans_pytorch-0.0.20}/fig1.png +0 -0
- {titans_pytorch-0.0.18 → titans_pytorch-0.0.20}/fig2.png +0 -0
- {titans_pytorch-0.0.18 → titans_pytorch-0.0.20}/requirements.txt +0 -0
- {titans_pytorch-0.0.18 → titans_pytorch-0.0.20}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.18 → titans_pytorch-0.0.20}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.18 → titans_pytorch-0.0.20}/titans_pytorch/titans_attn_memory.py +0 -0
- {titans_pytorch-0.0.18 → titans_pytorch-0.0.20}/train.py +0 -0
tests/test_titans.py

```diff
@@ -1,7 +1,11 @@
 import torch
 import pytest
 
-def test_titans():
+@pytest.mark.parametrize('seq_len', (32, 1024, 77))
+def test_titans(
+    seq_len
+):
+
     from titans_pytorch import NeuralMemory
 
     mem = NeuralMemory(
@@ -9,7 +13,7 @@ def test_titans():
         chunk_size = 64,
     )
 
-    seq = torch.randn(2, 1024, 384)
+    seq = torch.randn(2, seq_len, 384)
     retrieved = mem(seq)
 
     assert seq.shape == retrieved.shape
```
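The test is now parametrized over three sequence lengths that exercise different paths relative to the chunk_size = 64 used in the test: 32 (shorter than one chunk), 1024 (an exact multiple), and 77 (requires padding). A minimal sketch of what the parametrized test checks, assuming titans-pytorch 0.0.20 is installed:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
)

# 32 is shorter than one chunk, 1024 divides evenly, 77 requires padding
for seq_len in (32, 1024, 77):
    seq = torch.randn(2, seq_len, 384)
    retrieved = mem(seq)

    # the memory module is expected to preserve the input shape
    assert retrieved.shape == seq.shape
```

The 32-token case only passes because of the forward change in titans.py below, where a sequence shorter than one chunk now returns the learned empty memory embed instead of None.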
titans_pytorch/titans.py

```diff
@@ -17,7 +17,7 @@ from titans_pytorch.associative_scan import (
 )
 
 import einx
-from einops import rearrange, pack, unpack
+from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
 """
@@ -55,6 +55,21 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+# softclamping gradients
+
+def softclamp_max(t, max_value):
+    half_max_value = max_value / 2
+    return ((t / half_max_value).tanh() * half_max_value) + half_max_value
+
+def softclamp_grad_norm(t, max_value):
+    t, inverse = pack_one_with_inverse(t, 'bn *')
+
+    norm = t.norm(dim = -1, keepdim = True)
+    clamped_norm = softclamp_max(norm, max_value)
+
+    t = t * (clamped_norm / norm)
+    return inverse(t)
+
 # classes
 
 class MemoryMLP(Module):
```
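The two new helpers implement a smooth form of gradient-norm clipping: softclamp_max squashes any real input into the open interval (0, max_value) with a scaled and shifted tanh (an input of 0 lands at max_value / 2), and softclamp_grad_norm rescales each flattened tensor so its norm becomes that squashed value while its direction is preserved. A self-contained sketch of the same idea, with a plain reshape standing in for the package's pack_one_with_inverse helper (the demo values are illustrative only):

```python
import torch

def softclamp_max(t, max_value):
    # scaled, shifted tanh: output always lies strictly between 0 and max_value
    half = max_value / 2
    return (t / half).tanh() * half + half

def softclamp_grad_norm(t, max_value):
    # flatten everything after the leading dimension, softly clamp the norm,
    # and rescale the tensor without changing its direction
    flat = t.reshape(t.shape[0], -1)

    norm = flat.norm(dim = -1, keepdim = True)
    clamped = softclamp_max(norm, max_value)

    flat = flat * (clamped / norm)
    return flat.reshape(t.shape)

grads = torch.randn(4, 128) * 100.          # deliberately huge "gradients"
clamped = softclamp_grad_norm(grads, 10.)

print(grads.norm(dim = -1))                 # norms far above 10
print(clamped.norm(dim = -1))               # norms just under 10
```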
titans_pytorch/titans.py (continued)

```diff
@@ -96,6 +111,7 @@ class NeuralMemory(Module):
         store_memory_loss_fn: Callable = default_loss_fn,
         pre_rmsnorm = True,
         post_rmsnorm = True,
+        max_grad_norm: float | None = None,
         use_accelerated_scan = False,
         default_mlp_kwargs: dict = dict(
             depth = 4
@@ -152,6 +168,11 @@ class NeuralMemory(Module):
         self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
         self.store_memory_loss_fn = store_memory_loss_fn
 
+        # empty memory embed
+
+        self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
+        nn.init.normal_(self.empty_memory_embed, std = 0.02)
+
         # learned adaptive learning rate and momentum
         # todo - explore mlp layerwise learned lr / momentum
 
@@ -167,6 +188,10 @@ class NeuralMemory(Module):
             Rearrange('b n h -> (b h) n')
         )
 
+        # allow for softclamp the gradient norms for storing memories
+
+        self.max_grad_norm = max_grad_norm
+
         # weight decay factor
 
         self.to_decay_factor = nn.Sequential(
@@ -187,6 +212,9 @@ class NeuralMemory(Module):
 
         return init_weights, init_momentum
 
+    def init_empty_memory_embed(self, batch, seq_len):
+        return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
+
     def store_memories(
         self,
         seq,
@@ -239,6 +267,11 @@ class NeuralMemory(Module):
 
         grads = TensorDict(grads)
 
+        # maybe softclamp grad norm
+
+        if exists(self.max_grad_norm):
+            grads = grads.apply(lambda t: softclamp_grad_norm(t, self.max_grad_norm))
+
         # restore batch and sequence dimension
 
         grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
```
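On the module side, NeuralMemory gains an optional max_grad_norm argument (default None, meaning no clamping), a learned empty_memory_embed parameter initialized with std 0.02, and an init_empty_memory_embed helper that broadcasts that embedding over a batch and sequence length; store_memories applies softclamp_grad_norm to the per-chunk gradients only when max_grad_norm is set. A hedged usage sketch — dim = 384 mirrors the test configuration, and 10. is an arbitrary illustrative value for max_grad_norm, not a recommended setting:

```python
import torch
from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    max_grad_norm = 10.,   # new in this release; softly clamps the stored-memory gradient norms
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)

assert retrieved.shape == seq.shape
```

Leaving max_grad_norm at its default None skips the clamp entirely, which preserves the 0.0.18 storing behaviour.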
titans_pytorch/titans.py (continued)

```diff
@@ -372,11 +405,12 @@ class NeuralMemory(Module):
 
         values = self.post_rmsnorm(values)
 
-        # restore
+        # restore, pad with empty memory embed
 
-
-        values = values[:, :-padding]
+        empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
+        values = torch.cat((empty_memory_embeds, values), dim = -2)
 
+        values = values[:, :-padding]
         return values
 
     def forward(
@@ -389,7 +423,7 @@ class NeuralMemory(Module):
         batch, seq_len = seq.shape[:2]
 
         if seq_len < self.chunk_size:
-            return
+            return self.init_empty_memory_embed(batch, seq_len)
 
         if exists(past_state):
             past_state = tuple(TensorDict(d) for d in past_state)
```
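Retrieval now prepends chunk_size - 1 copies of the learned empty memory embed before trimming the trailing padding, so positions in the first chunk (which have no stored memories to read from yet) receive a learned "no memory" vector, and forward returns that same embedding broadcast over the whole sequence when the input is shorter than one chunk, rather than returning None. A standalone sketch of the new prefix step, using placeholder tensors rather than the module's internals:

```python
import torch
from einops import repeat

batch, chunk_size, dim = 2, 64, 384

# learned embedding standing in for "no memory stored yet"
empty_memory_embed = torch.nn.Parameter(torch.zeros(dim))
torch.nn.init.normal_(empty_memory_embed, std = 0.02)

# stand-in for the post-norm retrieved values
values = torch.randn(batch, 1024 - (chunk_size - 1), dim)

# the first chunk_size - 1 positions now read out the learned embedding
prefix = repeat(empty_memory_embed, 'd -> b n d', b = batch, n = chunk_size - 1)
values = torch.cat((prefix, values), dim = -2)

assert values.shape == (batch, 1024, dim)
```

Returning a correctly shaped tensor for short sequences (instead of None) is what lets the new 32-token test case assert shape equality like the longer ones.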