titans-pytorch 0.3.9__tar.gz → 0.3.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.3.9 → titans_pytorch-0.3.10}/PKG-INFO +1 -1
- {titans_pytorch-0.3.9 → titans_pytorch-0.3.10}/pyproject.toml +1 -1
- {titans_pytorch-0.3.9 → titans_pytorch-0.3.10}/tests/test_titans.py +33 -5
- {titans_pytorch-0.3.9 → titans_pytorch-0.3.10}/titans_pytorch/neural_memory.py +57 -18
- {titans_pytorch-0.3.9 → titans_pytorch-0.3.10}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.3.9 → titans_pytorch-0.3.10}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.3.9 → titans_pytorch-0.3.10}/.gitignore +0 -0
- {titans_pytorch-0.3.9 → titans_pytorch-0.3.10}/LICENSE +0 -0
- {titans_pytorch-0.3.9 → titans_pytorch-0.3.10}/README.md +0 -0
- {titans_pytorch-0.3.9 → titans_pytorch-0.3.10}/data/README.md +0 -0
- {titans_pytorch-0.3.9 → titans_pytorch-0.3.10}/data/enwik8.gz +0 -0
- {titans_pytorch-0.3.9 → titans_pytorch-0.3.10}/fig1.png +0 -0
- {titans_pytorch-0.3.9 → titans_pytorch-0.3.10}/fig2.png +0 -0
- {titans_pytorch-0.3.9 → titans_pytorch-0.3.10}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.3.9 → titans_pytorch-0.3.10}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.3.9 → titans_pytorch-0.3.10}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.3.9 → titans_pytorch-0.3.10}/titans_pytorch/memory_models.py +0 -0
- {titans_pytorch-0.3.9 → titans_pytorch-0.3.10}/train_mac.py +0 -0
tests/test_titans.py

@@ -60,6 +60,26 @@ def test_titans(
 
     assert seq.shape == retrieved.shape
 
+@pytest.mark.parametrize('learned_momentum_combine', (False, True))
+def test_titans_second_order_momentum(
+    learned_momentum_combine
+):
+
+    mem = NeuralMemory(
+        dim = 384,
+        dim_head = 64,
+        heads = 2,
+        chunk_size = 1,
+        batch_size = 2,
+        momentum_order = 2,
+        learned_momentum_combine = learned_momentum_combine
+    )
+
+    seq = torch.randn(2, 5, 384)
+
+    parallel_retrieved, state = mem(seq)
+    assert seq.shape == parallel_retrieved.shape
+
 def test_titans_attn_memory():
     from titans_pytorch.memory_models import MemoryAttention
 
@@ -318,12 +338,16 @@ def test_flex(
 
     assert torch.allclose(out_flex, out_non_flex, atol = 1e-5)
 
-@
-def test_assoc_scan():
+@pytest.mark.parametrize('use_accelerated', (True, False))
+def test_assoc_scan(
+    use_accelerated
+):
     from titans_pytorch.neural_memory import AssocScan
-    torch.set_default_dtype(torch.float64)
 
-
+    if use_accelerated and not torch.cuda.is_available():
+        pytest.skip()
+
+    scan = AssocScan(use_accelerated = use_accelerated)
 
     seq_len = 128
     mid_point = seq_len // 2
@@ -331,6 +355,10 @@ def test_assoc_scan():
     gates = torch.randn(2, seq_len, 16).sigmoid()
     inputs = torch.randn(2, seq_len, 16)
 
+    if use_accelerated:
+        gates = gates.cuda()
+        inputs = inputs.cuda()
+
     output = scan(gates, inputs)
 
     gates1, gates2 = gates[:, :mid_point], gates[:, mid_point:]
@@ -341,4 +369,4 @@ def test_assoc_scan():
     second_half = scan(gates2, inputs2, prev = first_half[:, -1])
     assert second_half.shape == inputs2.shape
 
-    assert torch.allclose(output[:, -1], second_half[:, -1], atol = 1e-
+    assert torch.allclose(output[:, -1], second_half[:, -1], atol = 1e-5)
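For reference, the continuation check in test_assoc_scan follows from the gated prefix recurrence that AssocScan is understood to compute: the state after scanning the full sequence equals the state obtained by scanning the second half seeded (via prev) with the first half's final state. Below is a minimal sketch of that recurrence, written as a naive loop purely for illustration, not the library's accelerated implementation:

    import torch

    def naive_assoc_scan(gates, inputs, prev = None):
        # out[t] = gates[t] * out[t-1] + inputs[t], with `prev` seeding the initial state
        state = prev if prev is not None else torch.zeros_like(inputs[:, 0])
        out = []
        for t in range(inputs.shape[1]):
            state = gates[:, t] * state + inputs[:, t]
            out.append(state)
        return torch.stack(out, dim = 1)

    gates = torch.randn(2, 8, 4).sigmoid()
    inputs = torch.randn(2, 8, 4)

    full = naive_assoc_scan(gates, inputs)
    first = naive_assoc_scan(gates[:, :4], inputs[:, :4])
    second = naive_assoc_scan(gates[:, 4:], inputs[:, 4:], prev = first[:, -1])

    # final state of the full scan matches the resumed scan, as the test asserts
    assert torch.allclose(full[:, -1], second[:, -1], atol = 1e-5)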
titans_pytorch/neural_memory.py

@@ -3,6 +3,7 @@ from typing import Callable
 
 import math
 from functools import partial
+from itertools import zip_longest
 from collections import namedtuple
 
 import torch
@@ -21,16 +22,19 @@ from titans_pytorch.memory_models import(
 )
 
 import einx
-from einops import rearrange, repeat, reduce, pack, unpack
+from einops import einsum, rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 
 """
 ein notation:
 b - batch
+h - heads
+bh - batch and heads
 n - sequence
 d - feature dimension
 c - intra-chunk
 w - num memory network weight parameters
+o - momentum orders
 """
 
 LinearNoBias = partial(Linear, bias = False)
@@ -224,6 +228,8 @@ class NeuralMemory(Module):
         per_head_learned_parameters = True,
         attn_pool_chunks = False,
         momentum = True,
+        momentum_order = 1,
+        learned_momentum_combine = False,
         pre_rmsnorm = True,
         post_rmsnorm = False,
         qk_rmsnorm = False,
@@ -367,12 +373,7 @@ class NeuralMemory(Module):
         else:
             self.reduce_to_chunk_rep = AttentionPool(dim, chunk_size = chunk_size)
 
-        # learned adaptive learning rate
-
-        self.to_momentum = Sequential(
-            nn.Linear(dim, heads),
-            Rearrange('b n h -> (b h) n 1')
-        ) if momentum else None
+        # learned adaptive learning rate
 
         self.to_adaptive_step = Sequential(
             nn.Linear(dim, heads),
@@ -384,6 +385,26 @@ class NeuralMemory(Module):
 
         self.adaptive_step_transform = adaptive_step_transform
 
+        # momentum related
+
+        self.to_momentum = Sequential(
+            nn.Linear(dim, heads * momentum_order),
+            Rearrange('b n (h o) -> o (b h) n 1', o = momentum_order)
+        ) if momentum else None
+
+        self.momentum_order = momentum_order
+        self.to_learned_momentum_combine = None
+
+        if learned_momentum_combine:
+            assert momentum
+            assert momentum_order > 1, 'only second order momentum allowed for now, but may allow learned combination of zeroth'
+
+            self.to_learned_momentum_combine = Sequential(
+                nn.Linear(dim, heads * momentum_order),
+                nn.Softmax(dim = -1),
+                Rearrange('b n (h o) -> o (b h) n', h = heads)
+            )
+
         # per layer learning rate modulation
 
         self.to_layer_modulation = Sequential(
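As a shape sketch of the two projections added above: to_momentum now produces one data-dependent momentum gate per head and per momentum order, and to_learned_momentum_combine produces softmax weights for mixing the orders, both with the order axis moved to the front. The dimensions below are taken from the new test (dim 384, 2 heads, momentum_order 2); the snippet is illustrative only and uses plain nn.Sequential rather than the package's own Sequential helper:

    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    dim, heads, momentum_order = 384, 2, 2
    batch, chunks = 2, 5

    to_momentum = nn.Sequential(
        nn.Linear(dim, heads * momentum_order),
        Rearrange('b n (h o) -> o (b h) n 1', o = momentum_order)
    )

    to_learned_momentum_combine = nn.Sequential(
        nn.Linear(dim, heads * momentum_order),
        nn.Softmax(dim = -1),
        Rearrange('b n (h o) -> o (b h) n', h = heads)
    )

    chunked_seq = torch.randn(batch, chunks, dim)

    adaptive_momentum = to_momentum(chunked_seq).sigmoid()        # (o, b*h, n, 1) -> (2, 4, 5, 1)
    combine_momentums = to_learned_momentum_combine(chunked_seq)  # (o, b*h, n)    -> (2, 4, 5)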
@@ -463,9 +484,9 @@ class NeuralMemory(Module):
         zeros = self.memory_model_parameter_dict.clone().zero_()
 
         if self.per_head_learned_parameters:
-            zeros = repeat_dict_values(zeros, 'h ... -> (b h) ...', b = batch)
+            zeros = repeat_dict_values(zeros, 'h ... -> o (b h) ...', b = batch, o = self.momentum_order)
         else:
-            zeros = repeat_dict_values(zeros, '... -> bh ...', bh = batch * self.heads)
+            zeros = repeat_dict_values(zeros, '... -> o bh ...', bh = batch * self.heads, o = self.momentum_order)
 
         return zeros
 
@@ -518,6 +539,11 @@ class NeuralMemory(Module):
         if has_momentum:
             adaptive_momentum = self.to_momentum(chunked_seq).sigmoid()
 
+            learned_combine = exists(self.to_learned_momentum_combine)
+
+            if learned_combine:
+                combine_momentums = self.to_learned_momentum_combine(chunked_seq)
+
         if need_layer_lr_mod:
             layer_lr_mod = self.to_layer_modulation(chunked_seq) * self.max_mem_layer_modulation
 
@@ -585,7 +611,7 @@ class NeuralMemory(Module):
 
         # negative gradients, adaptive lr already applied as loss weight
 
-        surprises = grads.
+        surprises = grads.mul(-1)
 
         # past states
 
@@ -611,7 +637,6 @@ class NeuralMemory(Module):
 
         # momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
 
-        next_momentum = TensorDict() if has_momentum else None
         updates = TensorDict()
 
         next_last_update = TensorDict()
@@ -624,20 +649,34 @@ class NeuralMemory(Module):
             # derive momentum with associative scan - eq (10)
 
             if has_momentum:
+                momentum = surprise
+
+                momentums = [] # stores all momentum orders starting with first, to generalize to Nth order momentum
+
                 last_momentum = past_last_momentum[param_name]
-
-                momentum
-
+
+                # go from first order momentum all the way to the Nth
+
+                for one_adaptive_momentum, one_last_momentum in zip_longest(adaptive_momentum, last_momentum):
+                    momentum = self.assoc_scan(one_adaptive_momentum, momentum, prev = one_last_momentum) # momentum is S / surprise in the paper
+
+                    momentums.append(momentum)
+
+                momentums = torch.stack(momentums)
+
+                next_last_momentum[param_name] = momentums[:, :, -1] # momentums shape is Float['o bh n 1']
+
+                if not learned_combine:
+                    update = momentums[-1]
+                else:
+                    update = einsum(combine_momentums, momentums, 'o b n, o b n ... -> b n ...')
 
             # use associative scan again for learned forgetting (weight decay) - eq (13)
 
             update = self.assoc_scan(1. - decay_factor, update, prev = last_update, remove_prev = False)
-            next_last_update[param_name] = update[:, -1]
 
             updates[param_name] = update
-
-            if has_momentum:
-                next_momentum[param_name] = momentum
+            next_last_update[param_name] = update[:, -1]
 
             # determine next state for the storing of memories
 
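The loop above generalizes the single momentum scan to Nth-order momentum by feeding each scan's output into the next: the first scan accumulates the raw surprise (negative gradients) into first-order momentum, the second scan accumulates that momentum into second-order momentum, and so on, with zip_longest pairing each order's gate with its carried-over state. When learned_momentum_combine is off the highest order is used directly; otherwise the orders are mixed by the softmax weights from to_learned_momentum_combine. A self-contained sketch of the chaining, using a naive scan and illustrative shapes rather than the library's code:

    import torch

    def scan(gates, inputs, prev = None):
        # out[t] = gates[t] * out[t-1] + inputs[t]
        state = prev if prev is not None else torch.zeros_like(inputs[:, 0])
        out = []
        for t in range(inputs.shape[1]):
            state = gates[:, t] * state + inputs[:, t]
            out.append(state)
        return torch.stack(out, dim = 1)

    surprise = torch.randn(4, 5, 1)             # negative gradients per chunk, (b*h, n, 1)
    adaptive_momentum = torch.rand(2, 4, 5, 1)  # one gate per momentum order, (o, b*h, n, 1)

    momentum = surprise
    momentums = []

    for one_adaptive_momentum in adaptive_momentum:   # order 1, then order 2
        momentum = scan(one_adaptive_momentum, momentum)
        momentums.append(momentum)

    momentums = torch.stack(momentums)          # (o, b*h, n, 1)
    update = momentums[-1]                      # highest order, used when not learned-combined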