titans-pytorch 0.4.3__tar.gz → 0.4.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.4.3 → titans_pytorch-0.4.5}/PKG-INFO +2 -2
- {titans_pytorch-0.4.3 → titans_pytorch-0.4.5}/README.md +1 -1
- {titans_pytorch-0.4.3 → titans_pytorch-0.4.5}/pyproject.toml +1 -1
- {titans_pytorch-0.4.3 → titans_pytorch-0.4.5}/tests/test_titans.py +20 -0
- {titans_pytorch-0.4.3 → titans_pytorch-0.4.5}/titans_pytorch/__init__.py +2 -0
- {titans_pytorch-0.4.3 → titans_pytorch-0.4.5}/titans_pytorch/neural_memory.py +17 -1
- {titans_pytorch-0.4.3 → titans_pytorch-0.4.5}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.4.3 → titans_pytorch-0.4.5}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.4.3 → titans_pytorch-0.4.5}/.gitignore +0 -0
- {titans_pytorch-0.4.3 → titans_pytorch-0.4.5}/LICENSE +0 -0
- {titans_pytorch-0.4.3 → titans_pytorch-0.4.5}/data/README.md +0 -0
- {titans_pytorch-0.4.3 → titans_pytorch-0.4.5}/data/enwik8.gz +0 -0
- {titans_pytorch-0.4.3 → titans_pytorch-0.4.5}/fig1.png +0 -0
- {titans_pytorch-0.4.3 → titans_pytorch-0.4.5}/fig2.png +0 -0
- {titans_pytorch-0.4.3 → titans_pytorch-0.4.5}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.4.3 → titans_pytorch-0.4.5}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.4.3 → titans_pytorch-0.4.5}/titans_pytorch/memory_models.py +0 -0
- {titans_pytorch-0.4.3 → titans_pytorch-0.4.5}/train_mac.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: titans-pytorch
|
3
|
-
Version: 0.4.3
|
+
Version: 0.4.5
|
4
4
|
Summary: Titans
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
|
@@ -56,7 +56,7 @@ Description-Content-Type: text/markdown
|
|
56
56
|
|
57
57
|
<img src="./fig1.png" width="400px"></img>
|
58
58
|
|
59
|
-
## Titans - Pytorch
|
59
|
+
## Titans - Pytorch
|
60
60
|
|
61
61
|
Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
|
62
62
|
|
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
<img src="./fig1.png" width="400px"></img>
|
4
4
|
|
5
|
-
## Titans - Pytorch
|
5
|
+
## Titans - Pytorch
|
6
6
|
|
7
7
|
Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
|
8
8
|
|
@@ -405,3 +405,23 @@ def test_assoc_scan(
|
|
405
405
|
assert second_half.shape == inputs2.shape
|
406
406
|
|
407
407
|
assert torch.allclose(output[:, -1], second_half[:, -1], atol = 1e-5)
|
408
|
+
|
409
|
+
def test_mem_state_detach():
    # checks that the neural memory state can be carried across calls and detached,
    # both via the standalone `mem_state_detach` helper and via the `detach_mem_state`
    # keyword on the forward pass

    from titans_pytorch.neural_memory import mem_state_detach

    mem = NeuralMemory(
        dim = 384,
        chunk_size = 2,
        qk_rmsnorm = True,
        dim_head = 64,
        heads = 4,
    )

    seq = torch.randn(4, 64, 384)

    # path 1 - detach with the standalone helper
    # presumably, without detaching, the second backward would reach into the first
    # iteration's graph (already freed by the first backward) and raise

    state = None

    for _ in range(2):
        parallel_retrieved, state = mem(seq, state = state)
        state = mem_state_detach(state)
        parallel_retrieved.sum().backward()

    # path 2 - detach through the forward keyword added alongside the helper

    state = None

    for _ in range(2):
        parallel_retrieved, state = mem(seq, state = state, detach_mem_state = True)
        parallel_retrieved.sum().backward()
|
@@ -7,10 +7,11 @@ from itertools import zip_longest
|
|
7
7
|
from collections import namedtuple
|
8
8
|
|
9
9
|
import torch
|
10
|
-
from torch import nn, stack, cat, tensor, Tensor
|
10
|
+
from torch import nn, stack, cat, is_tensor, tensor, Tensor
|
11
11
|
import torch.nn.functional as F
|
12
12
|
from torch.nn import Linear, Module, Parameter, ParameterList, ParameterDict
|
13
13
|
from torch.func import functional_call, vmap, grad
|
14
|
+
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
|
14
15
|
|
15
16
|
from tensordict import TensorDict
|
16
17
|
|
@@ -40,6 +41,8 @@ u - key / value updates - allowing a token to emit multiple key / values
|
|
40
41
|
|
41
42
|
LinearNoBias = partial(Linear, bias = False)
|
42
43
|
|
44
|
+
# neural mem state related
|
45
|
+
|
43
46
|
NeuralMemState = namedtuple('NeuralMemState', [
|
44
47
|
'seq_index',
|
45
48
|
'weights',
|
@@ -48,6 +51,13 @@ NeuralMemState = namedtuple('NeuralMemState', [
|
|
48
51
|
'updates',
|
49
52
|
])
|
50
53
|
|
54
|
+
def mem_state_detach(
    state: NeuralMemState
):
    """Return a copy of `state` with every tensor leaf detached from the autograd graph.

    The state is flattened with torch's pytree `tree_map`; every leaf that is a
    tensor is replaced by `t.detach()`, and every other leaf is passed through
    unchanged. The result is rebuilt as a `NeuralMemState`.

    NOTE(review): whether tensors nested inside a TensorDict (e.g. in `weights`)
    are reached depends on TensorDict being registered as a pytree node —
    confirm for the pinned tensordict version.

    Raises:
        TypeError: if `state` is not a `NeuralMemState`. (This used to be an
            `assert`, which is stripped under `python -O`.)
    """
    if not isinstance(state, NeuralMemState):
        raise TypeError(f'mem_state_detach expects a NeuralMemState, got {type(state).__name__}')

    # convert to a plain tuple so pytree walks the fields positionally
    detached = tree_map(lambda t: t.detach() if is_tensor(t) else t, tuple(state))

    return NeuralMemState._make(detached)
|
60
|
+
|
51
61
|
# functions
|
52
62
|
|
53
63
|
def exists(v):
|
@@ -854,6 +864,7 @@ class NeuralMemory(Module):
|
|
854
864
|
seq,
|
855
865
|
store_seq = None,
|
856
866
|
state: NeuralMemState | None = None,
|
867
|
+
detach_mem_state = False,
|
857
868
|
prev_weights = None,
|
858
869
|
store_mask: Tensor | None = None,
|
859
870
|
return_surprises = False
|
@@ -1003,6 +1014,11 @@ class NeuralMemory(Module):
|
|
1003
1014
|
updates
|
1004
1015
|
)
|
1005
1016
|
|
1017
|
+
# maybe detach
|
1018
|
+
|
1019
|
+
if detach_mem_state:
|
1020
|
+
next_neural_mem_state = mem_state_detach(next_neural_mem_state)
|
1021
|
+
|
1006
1022
|
# returning
|
1007
1023
|
|
1008
1024
|
if not return_surprises:
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|