titans-pytorch 0.3.11__tar.gz → 0.3.12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.3.11 → titans_pytorch-0.3.12}/PKG-INFO +1 -1
- {titans_pytorch-0.3.11 → titans_pytorch-0.3.12}/pyproject.toml +1 -1
- {titans_pytorch-0.3.11 → titans_pytorch-0.3.12}/tests/test_titans.py +5 -2
- {titans_pytorch-0.3.11 → titans_pytorch-0.3.12}/titans_pytorch/neural_memory.py +12 -2
- {titans_pytorch-0.3.11 → titans_pytorch-0.3.12}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.3.11 → titans_pytorch-0.3.12}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.3.11 → titans_pytorch-0.3.12}/.gitignore +0 -0
- {titans_pytorch-0.3.11 → titans_pytorch-0.3.12}/LICENSE +0 -0
- {titans_pytorch-0.3.11 → titans_pytorch-0.3.12}/README.md +0 -0
- {titans_pytorch-0.3.11 → titans_pytorch-0.3.12}/data/README.md +0 -0
- {titans_pytorch-0.3.11 → titans_pytorch-0.3.12}/data/enwik8.gz +0 -0
- {titans_pytorch-0.3.11 → titans_pytorch-0.3.12}/fig1.png +0 -0
- {titans_pytorch-0.3.11 → titans_pytorch-0.3.12}/fig2.png +0 -0
- {titans_pytorch-0.3.11 → titans_pytorch-0.3.12}/titans_pytorch/__init__.py +0 -0
- {titans_pytorch-0.3.11 → titans_pytorch-0.3.12}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.3.11 → titans_pytorch-0.3.12}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.3.11 → titans_pytorch-0.3.12}/titans_pytorch/memory_models.py +0 -0
- {titans_pytorch-0.3.11 → titans_pytorch-0.3.12}/train_mac.py +0 -0
@@ -61,8 +61,10 @@ def test_titans(
|
|
61
61
|
assert seq.shape == retrieved.shape
|
62
62
|
|
63
63
|
@pytest.mark.parametrize('learned_momentum_combine', (False, True))
|
64
|
+
@pytest.mark.parametrize('learned_combine_include_zeroth', (False, True))
|
64
65
|
def test_titans_second_order_momentum(
|
65
|
-
learned_momentum_combine
|
66
|
+
learned_momentum_combine,
|
67
|
+
learned_combine_include_zeroth
|
66
68
|
):
|
67
69
|
|
68
70
|
mem = NeuralMemory(
|
@@ -72,7 +74,8 @@ def test_titans_second_order_momentum(
|
|
72
74
|
chunk_size = 1,
|
73
75
|
batch_size = 2,
|
74
76
|
momentum_order = 2,
|
75
|
-
learned_momentum_combine = learned_momentum_combine
|
77
|
+
learned_momentum_combine = learned_momentum_combine,
|
78
|
+
learned_combine_include_zeroth = learned_combine_include_zeroth
|
76
79
|
)
|
77
80
|
|
78
81
|
seq = torch.randn(2, 5, 384)
|
@@ -7,7 +7,7 @@ from itertools import zip_longest
|
|
7
7
|
from collections import namedtuple
|
8
8
|
|
9
9
|
import torch
|
10
|
-
from torch import nn, cat, tensor, Tensor
|
10
|
+
from torch import nn, stack, cat, tensor, Tensor
|
11
11
|
import torch.nn.functional as F
|
12
12
|
from torch.nn import Linear, Module, Parameter, ParameterList, ParameterDict
|
13
13
|
from torch.func import functional_call, vmap, grad
|
@@ -230,6 +230,7 @@ class NeuralMemory(Module):
|
|
230
230
|
momentum = True,
|
231
231
|
momentum_order = 1,
|
232
232
|
learned_momentum_combine = False,
|
233
|
+
learned_combine_include_zeroth = False,
|
233
234
|
pre_rmsnorm = True,
|
234
235
|
post_rmsnorm = False,
|
235
236
|
qk_rmsnorm = False,
|
@@ -399,12 +400,17 @@ class NeuralMemory(Module):
|
|
399
400
|
assert momentum
|
400
401
|
assert momentum_order > 1, 'only second order momentum allowed for now, but may allow learned combination of zeroth'
|
401
402
|
|
403
|
+
if learned_combine_include_zeroth:
|
404
|
+
momentum_order += 1
|
405
|
+
|
402
406
|
self.to_learned_momentum_combine = Sequential(
|
403
407
|
nn.Linear(dim, heads * momentum_order),
|
404
408
|
nn.Softmax(dim = -1),
|
405
409
|
Rearrange('b n (h o) -> o (b h) n', h = heads)
|
406
410
|
)
|
407
411
|
|
412
|
+
self.learned_combine_include_zeroth = learned_combine_include_zeroth
|
413
|
+
|
408
414
|
# per layer learning rate modulation
|
409
415
|
|
410
416
|
self.to_layer_modulation = Sequential(
|
@@ -662,10 +668,14 @@ class NeuralMemory(Module):
|
|
662
668
|
|
663
669
|
momentums.append(momentum)
|
664
670
|
|
665
|
-
momentums =
|
671
|
+
momentums = stack(momentums)
|
666
672
|
|
667
673
|
next_last_momentum[param_name] = momentums[:, :, -1] # momentums shape is Float['o bh n 1']
|
668
674
|
|
675
|
+
if learned_combine and self.learned_combine_include_zeroth:
|
676
|
+
# add the original surprise if learned combination of momentums
|
677
|
+
momentums = cat((rearrange(surprise, '... -> 1 ...'), momentums), dim = 0)
|
678
|
+
|
669
679
|
if not learned_combine:
|
670
680
|
update = momentums[-1]
|
671
681
|
else:
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|