titans-pytorch 0.0.65__tar.gz → 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.0.65 → titans_pytorch-0.1.0}/PKG-INFO +6 -1
- {titans_pytorch-0.0.65 → titans_pytorch-0.1.0}/README.md +4 -0
- {titans_pytorch-0.0.65 → titans_pytorch-0.1.0}/pyproject.toml +2 -1
- {titans_pytorch-0.0.65 → titans_pytorch-0.1.0}/tests/test_titans.py +3 -0
- {titans_pytorch-0.0.65 → titans_pytorch-0.1.0}/titans_pytorch/titans.py +32 -7
- {titans_pytorch-0.0.65 → titans_pytorch-0.1.0}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.0.65 → titans_pytorch-0.1.0}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.0.65 → titans_pytorch-0.1.0}/.gitignore +0 -0
- {titans_pytorch-0.0.65 → titans_pytorch-0.1.0}/LICENSE +0 -0
- {titans_pytorch-0.0.65 → titans_pytorch-0.1.0}/assert_flex.py +0 -0
- {titans_pytorch-0.0.65 → titans_pytorch-0.1.0}/data/README.md +0 -0
- {titans_pytorch-0.0.65 → titans_pytorch-0.1.0}/data/enwik8.gz +0 -0
- {titans_pytorch-0.0.65 → titans_pytorch-0.1.0}/fig1.png +0 -0
- {titans_pytorch-0.0.65 → titans_pytorch-0.1.0}/fig2.png +0 -0
- {titans_pytorch-0.0.65 → titans_pytorch-0.1.0}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.0.65 → titans_pytorch-0.1.0}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.0.65 → titans_pytorch-0.1.0}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.0.65 → titans_pytorch-0.1.0}/train_mac.py +0 -0
````diff
--- titans_pytorch-0.0.65/PKG-INFO
+++ titans_pytorch-0.1.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.65
+Version: 0.1.0
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -37,6 +37,7 @@ Requires-Python: >=3.9
 Requires-Dist: accelerated-scan>=0.2.0
 Requires-Dist: axial-positional-embedding>=0.3.5
 Requires-Dist: einops>=0.8.0
+Requires-Dist: einx>=0.3.0
 Requires-Dist: hyper-connections>=0.1.8
 Requires-Dist: ninja
 Requires-Dist: rotary-embedding-torch
@@ -58,6 +59,10 @@ Description-Content-Type: text/markdown

 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.

+## Appreciation
+
+- [@sentialx](https://github.com/sentialx) for sharing his early experimental results with me
+
 ## Install

 ```bash
````
````diff
--- titans_pytorch-0.0.65/README.md
+++ titans_pytorch-0.1.0/README.md
@@ -6,6 +6,10 @@

 Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.

+## Appreciation
+
+- [@sentialx](https://github.com/sentialx) for sharing his early experimental results with me
+
 ## Install

 ```bash
````
````diff
--- titans_pytorch-0.0.65/pyproject.toml
+++ titans_pytorch-0.1.0/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.0.65"
+version = "0.1.0"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -28,6 +28,7 @@ dependencies = [
     "accelerated-scan>=0.2.0",
     "axial_positional_embedding>=0.3.5",
     "einops>=0.8.0",
+    "einx>=0.3.0",
     "hyper-connections>=0.1.8",
     "Ninja",
     "rotary-embedding-torch",
````
````diff
--- titans_pytorch-0.0.65/tests/test_titans.py
+++ titans_pytorch-0.1.0/tests/test_titans.py
@@ -8,17 +8,20 @@ from titans_pytorch import NeuralMemory
 @pytest.mark.parametrize('silu', (False, True))
 @pytest.mark.parametrize('learned_mem_model_weights', (False, True))
 @pytest.mark.parametrize('max_grad_norm', (None, 2.))
+@pytest.mark.parametrize('per_parameter_lr_modulation', (False, True))
 def test_titans(
     seq_len,
     silu,
     learned_mem_model_weights,
     max_grad_norm,
+    per_parameter_lr_modulation
 ):
     mem = NeuralMemory(
         dim = 384,
         chunk_size = 64,
         activation = nn.SiLU() if silu else None,
         max_grad_norm = max_grad_norm,
+        per_parameter_lr_modulation = per_parameter_lr_modulation,
         learned_mem_model_weights = learned_mem_model_weights
     )

````
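The new `per_parameter_lr_modulation` flag is plumbed straight through the `NeuralMemory` constructor, as the parametrized test above exercises. A minimal usage sketch built from the constructor arguments in that test; the input shape and the plain `mem(seq)` forward call are illustrative assumptions, not part of this diff:

```python
import torch
from titans_pytorch import NeuralMemory

# constructor arguments mirror tests/test_titans.py, with the new
# per-parameter learning rate modulation switched on
mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,
    per_parameter_lr_modulation = True,  # new in 0.1.0
    max_mem_layer_modulation = 1e1       # new in 0.1.0 - caps the learned per-layer lr scale at 10
)

# assumed forward usage: a (batch, seq_len, dim) sequence in, a same-shaped retrieval out
seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)
assert retrieved.shape == seq.shape
```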
|
@@ -17,6 +17,7 @@ from titans_pytorch.associative_scan import (
|
|
|
17
17
|
pad_at_dim
|
|
18
18
|
)
|
|
19
19
|
|
|
20
|
+
import einx
|
|
20
21
|
from einops import rearrange, repeat, pack, unpack
|
|
21
22
|
from einops.layers.torch import Rearrange, Reduce
|
|
22
23
|
|
|
@@ -26,6 +27,7 @@ b - batch
|
|
|
26
27
|
n - sequence
|
|
27
28
|
d - feature dimension
|
|
28
29
|
c - intra-chunk
|
|
30
|
+
w - num memory network weight parameters
|
|
29
31
|
"""
|
|
30
32
|
|
|
31
33
|
LinearNoBias = partial(Linear, bias = False)
|
|
@@ -220,6 +222,8 @@ class NeuralMemory(Module):
|
|
|
220
222
|
store_memory_loss_fn: Callable = default_loss_fn,
|
|
221
223
|
adaptive_step_transform: Callable | None = None,
|
|
222
224
|
default_step_transform_max_lr = 1e-2,
|
|
225
|
+
per_parameter_lr_modulation = False, # allow outer network to control learning rate per weight matrix of memory network
|
|
226
|
+
max_mem_layer_modulation = 1e1, # max of 10.
|
|
223
227
|
pre_rmsnorm = True,
|
|
224
228
|
post_rmsnorm = True,
|
|
225
229
|
learned_mem_model_weights = True,
|
|
@@ -250,7 +254,7 @@ class NeuralMemory(Module):
|
|
|
250
254
|
self.merge_heads = Rearrange('b h n d -> b n (h d)')
|
|
251
255
|
self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
|
|
252
256
|
|
|
253
|
-
self.retrieve_gate =
|
|
257
|
+
self.retrieve_gate = Sequential(
|
|
254
258
|
LinearNoBias(dim, heads),
|
|
255
259
|
Rearrange('b n h -> b h n 1'),
|
|
256
260
|
nn.Sigmoid()
|
|
@@ -270,6 +274,8 @@ class NeuralMemory(Module):
|
|
|
270
274
|
|
|
271
275
|
self.memory_model = model
|
|
272
276
|
|
|
277
|
+
self.num_memory_parameter_tensors = len(set(model.parameters()))
|
|
278
|
+
|
|
273
279
|
# the chunk size within the paper where adaptive step, momentum, weight decay are shared
|
|
274
280
|
|
|
275
281
|
self.chunk_size = chunk_size
|
|
@@ -299,15 +305,14 @@ class NeuralMemory(Module):
|
|
|
299
305
|
nn.init.normal_(self.empty_memory_embed, std = 0.02)
|
|
300
306
|
|
|
301
307
|
# learned adaptive learning rate and momentum
|
|
302
|
-
# todo - explore mlp layerwise learned lr / momentum
|
|
303
308
|
|
|
304
|
-
self.to_momentum =
|
|
309
|
+
self.to_momentum = Sequential(
|
|
305
310
|
Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
|
|
306
311
|
LinearNoBias(dim, heads),
|
|
307
312
|
Rearrange('b n h -> (b h) n 1')
|
|
308
313
|
)
|
|
309
314
|
|
|
310
|
-
self.to_adaptive_step =
|
|
315
|
+
self.to_adaptive_step = Sequential(
|
|
311
316
|
LinearNoBias(dim, heads),
|
|
312
317
|
Rearrange('b n h -> (b h) n')
|
|
313
318
|
)
|
|
@@ -317,13 +322,24 @@ class NeuralMemory(Module):
|
|
|
317
322
|
|
|
318
323
|
self.adaptive_step_transform = adaptive_step_transform
|
|
319
324
|
|
|
325
|
+
# per layer learning rate modulation
|
|
326
|
+
|
|
327
|
+
self.to_layer_modulation = Sequential(
|
|
328
|
+
Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
|
|
329
|
+
LinearNoBias(dim, heads * self.num_memory_parameter_tensors),
|
|
330
|
+
Rearrange('b n (h w) -> w (b h) n', h = heads),
|
|
331
|
+
nn.Sigmoid()
|
|
332
|
+
) if per_parameter_lr_modulation else None
|
|
333
|
+
|
|
334
|
+
self.max_mem_layer_modulation = max_mem_layer_modulation
|
|
335
|
+
|
|
320
336
|
# allow for softclamp the gradient norms for storing memories
|
|
321
337
|
|
|
322
338
|
self.max_grad_norm = max_grad_norm
|
|
323
339
|
|
|
324
340
|
# weight decay factor
|
|
325
341
|
|
|
326
|
-
self.to_decay_factor =
|
|
342
|
+
self.to_decay_factor = Sequential(
|
|
327
343
|
Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
|
|
328
344
|
LinearNoBias(dim, heads),
|
|
329
345
|
Rearrange('b n h -> (b h) n 1')
|
|
@@ -387,6 +403,11 @@ class NeuralMemory(Module):
|
|
|
387
403
|
adaptive_momentum = self.to_momentum(seq).sigmoid()
|
|
388
404
|
decay_factor = self.to_decay_factor(seq).sigmoid()
|
|
389
405
|
|
|
406
|
+
need_layer_lr_mod = exists(self.to_layer_modulation)
|
|
407
|
+
|
|
408
|
+
if need_layer_lr_mod:
|
|
409
|
+
layer_lr_mod = self.to_layer_modulation(seq) * self.max_mem_layer_modulation
|
|
410
|
+
|
|
390
411
|
# keys and values
|
|
391
412
|
|
|
392
413
|
keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
|
|
@@ -418,6 +439,11 @@ class NeuralMemory(Module):
|
|
|
418
439
|
|
|
419
440
|
grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
|
|
420
441
|
|
|
442
|
+
# maybe per layer modulation
|
|
443
|
+
|
|
444
|
+
if need_layer_lr_mod:
|
|
445
|
+
grads = TensorDict({name: einx.multiply('b h, b h ... -> b h ...', layer_lr_mod, t) for layer_lr_mod, (name, t) in zip(layer_lr_mod, grads.items())})
|
|
446
|
+
|
|
421
447
|
# negative gradients, adaptive lr already applied as loss weight
|
|
422
448
|
|
|
423
449
|
surprises = grads.apply(lambda t: -t)
|
|
@@ -469,7 +495,7 @@ class NeuralMemory(Module):
|
|
|
469
495
|
|
|
470
496
|
# use associative scan again for learned forgetting (weight decay) - eq (13)
|
|
471
497
|
|
|
472
|
-
update = scan_fn(1. - decay_factor, momentum)
|
|
498
|
+
update = scan_fn(1. - decay_factor, momentum)
|
|
473
499
|
|
|
474
500
|
updates[param_name] = inverse_pack(update)
|
|
475
501
|
next_momentum[param_name] = inverse_pack(momentum)
|
|
@@ -580,7 +606,6 @@ class NeuralMemory(Module):
|
|
|
580
606
|
|
|
581
607
|
past_weights, _ = past_state
|
|
582
608
|
|
|
583
|
-
|
|
584
609
|
retrieved = self.retrieve_memories(seq, past_weights + updates)
|
|
585
610
|
|
|
586
611
|
if not return_aux_kv_loss:
|
|
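Taken together, the titans.py changes add one more set of learned gates: `to_layer_modulation` pools the sequence per chunk and emits a sigmoid gate per head and per memory-network parameter tensor (the new `w` axis in the einops notation), scaled up to `max_mem_layer_modulation`, and each parameter's gradients are multiplied by their own gate before being negated into surprises. The sketch below reproduces just that shape flow outside the package; the sizes, the plain `nn.Linear`/`nn.Sequential` stand-ins for the package's `LinearNoBias`/`Sequential` helpers, and the placeholder gradient dict are all illustrative assumptions.

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange, Reduce
import einx

# assumed toy sizes, not the package defaults
dim, heads, chunk_size, num_param_tensors = 64, 2, 16, 4
batch, seq_len = 3, 64  # seq_len kept a multiple of chunk_size to sidestep padding

# mirrors self.to_layer_modulation from the diff: one sigmoid gate per
# (parameter tensor w, batch * head, chunk), later scaled by max_mem_layer_modulation
to_layer_modulation = nn.Sequential(
    Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size),
    nn.Linear(dim, heads * num_param_tensors, bias = False),
    Rearrange('b n (h w) -> w (b h) n', h = heads),
    nn.Sigmoid()
)

max_mem_layer_modulation = 1e1  # "max of 10", as in the new hyperparameter

seq = torch.randn(batch, seq_len, dim)
layer_lr_mod = to_layer_modulation(seq) * max_mem_layer_modulation
print(layer_lr_mod.shape)  # (num_param_tensors, batch * heads, seq_len // chunk_size)

# placeholder per-chunk gradients, one entry per memory-network parameter tensor
# (in the package these live in a TensorDict keyed by the memory model's parameters)
grads = {
    f'layer{i}.weight': torch.randn(batch * heads, seq_len // chunk_size, dim, dim)
    for i in range(num_param_tensors)
}

# the new scaling step: each parameter tensor's gradients get their own learned gate,
# broadcast over the trailing weight dimensions (same einx pattern as in the diff)
grads = {
    name: einx.multiply('b h, b h ... -> b h ...', mod, g)
    for mod, (name, g) in zip(layer_lr_mod, grads.items())
}
```

In the package the gate sits alongside the existing `to_adaptive_step`, `to_momentum` and `to_decay_factor` projections, giving the outer network one more learned control signal over the inner memory's update rule.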