titans-pytorch 0.4.2__tar.gz → 0.4.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {titans_pytorch-0.4.2 → titans_pytorch-0.4.4}/PKG-INFO +2 -2
- {titans_pytorch-0.4.2 → titans_pytorch-0.4.4}/README.md +1 -1
- {titans_pytorch-0.4.2 → titans_pytorch-0.4.4}/pyproject.toml +1 -1
- {titans_pytorch-0.4.2 → titans_pytorch-0.4.4}/tests/test_titans.py +20 -0
- {titans_pytorch-0.4.2 → titans_pytorch-0.4.4}/titans_pytorch/__init__.py +2 -0
- {titans_pytorch-0.4.2 → titans_pytorch-0.4.4}/titans_pytorch/neural_memory.py +13 -3
- {titans_pytorch-0.4.2 → titans_pytorch-0.4.4}/.github/workflows/python-publish.yml +0 -0
- {titans_pytorch-0.4.2 → titans_pytorch-0.4.4}/.github/workflows/test.yaml +0 -0
- {titans_pytorch-0.4.2 → titans_pytorch-0.4.4}/.gitignore +0 -0
- {titans_pytorch-0.4.2 → titans_pytorch-0.4.4}/LICENSE +0 -0
- {titans_pytorch-0.4.2 → titans_pytorch-0.4.4}/data/README.md +0 -0
- {titans_pytorch-0.4.2 → titans_pytorch-0.4.4}/data/enwik8.gz +0 -0
- {titans_pytorch-0.4.2 → titans_pytorch-0.4.4}/fig1.png +0 -0
- {titans_pytorch-0.4.2 → titans_pytorch-0.4.4}/fig2.png +0 -0
- {titans_pytorch-0.4.2 → titans_pytorch-0.4.4}/titans_pytorch/associative_scan.py +0 -0
- {titans_pytorch-0.4.2 → titans_pytorch-0.4.4}/titans_pytorch/mac_transformer.py +0 -0
- {titans_pytorch-0.4.2 → titans_pytorch-0.4.4}/titans_pytorch/memory_models.py +0 -0
- {titans_pytorch-0.4.2 → titans_pytorch-0.4.4}/train_mac.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: titans-pytorch
|
3
|
-
Version: 0.4.2
|
3
|
+
Version: 0.4.4
|
4
4
|
Summary: Titans
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
|
@@ -56,7 +56,7 @@ Description-Content-Type: text/markdown
|
|
56
56
|
|
57
57
|
<img src="./fig1.png" width="400px"></img>
|
58
58
|
|
59
|
-
## Titans - Pytorch
|
59
|
+
## Titans - Pytorch
|
60
60
|
|
61
61
|
Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
|
62
62
|
|
@@ -2,7 +2,7 @@
|
|
2
2
|
|
3
3
|
<img src="./fig1.png" width="400px"></img>
|
4
4
|
|
5
|
-
## Titans - Pytorch
|
5
|
+
## Titans - Pytorch
|
6
6
|
|
7
7
|
Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
|
8
8
|
|
@@ -405,3 +405,23 @@ def test_assoc_scan(
|
|
405
405
|
assert second_half.shape == inputs2.shape
|
406
406
|
|
407
407
|
assert torch.allclose(output[:, -1], second_half[:, -1], atol = 1e-5)
|
408
|
+
|
409
|
+
def test_mem_state_detach():
|
410
|
+
from titans_pytorch.neural_memory import mem_state_detach
|
411
|
+
|
412
|
+
mem = NeuralMemory(
|
413
|
+
dim = 384,
|
414
|
+
chunk_size = 2,
|
415
|
+
qk_rmsnorm = True,
|
416
|
+
dim_head = 64,
|
417
|
+
heads = 4,
|
418
|
+
)
|
419
|
+
|
420
|
+
seq = torch.randn(4, 64, 384)
|
421
|
+
|
422
|
+
state = None
|
423
|
+
|
424
|
+
for _ in range(2):
|
425
|
+
parallel_retrieved, state = mem(seq, state = state)
|
426
|
+
state = mem_state_detach(state)
|
427
|
+
parallel_retrieved.sum().backward()
|
@@ -7,10 +7,11 @@ from itertools import zip_longest
|
|
7
7
|
from collections import namedtuple
|
8
8
|
|
9
9
|
import torch
|
10
|
-
from torch import nn, stack, cat, tensor, Tensor
|
10
|
+
from torch import nn, stack, cat, is_tensor, tensor, Tensor
|
11
11
|
import torch.nn.functional as F
|
12
12
|
from torch.nn import Linear, Module, Parameter, ParameterList, ParameterDict
|
13
13
|
from torch.func import functional_call, vmap, grad
|
14
|
+
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
|
14
15
|
|
15
16
|
from tensordict import TensorDict
|
16
17
|
|
@@ -40,6 +41,8 @@ u - key / value updates - allowing a token to emit multiple key / values
|
|
40
41
|
|
41
42
|
LinearNoBias = partial(Linear, bias = False)
|
42
43
|
|
44
|
+
# neural mem state related
|
45
|
+
|
43
46
|
NeuralMemState = namedtuple('NeuralMemState', [
|
44
47
|
'seq_index',
|
45
48
|
'weights',
|
@@ -48,6 +51,13 @@ NeuralMemState = namedtuple('NeuralMemState', [
|
|
48
51
|
'updates',
|
49
52
|
])
|
50
53
|
|
54
|
+
def mem_state_detach(
|
55
|
+
state: NeuralMemState
|
56
|
+
):
|
57
|
+
assert isinstance(state, NeuralMemState)
|
58
|
+
state = tree_map(lambda t: t.detach() if is_tensor(t) else t, tuple(state))
|
59
|
+
return NeuralMemState(*state)
|
60
|
+
|
51
61
|
# functions
|
52
62
|
|
53
63
|
def exists(v):
|
@@ -940,7 +950,7 @@ class NeuralMemory(Module):
|
|
940
950
|
|
941
951
|
# whether to allow network to slowly adjust from initial weight throughout (residual path) to fully updating weights every batch
|
942
952
|
|
943
|
-
surprises = None
|
953
|
+
surprises = (None, None)
|
944
954
|
gate = None
|
945
955
|
|
946
956
|
if exists(self.transition_gate):
|
@@ -967,7 +977,7 @@ class NeuralMemory(Module):
|
|
967
977
|
|
968
978
|
updates = accum_updates(updates, next_updates)
|
969
979
|
|
970
|
-
surprises = safe_cat((surprises, chunk_surprises), dim = -1)
|
980
|
+
surprises = tuple(safe_cat(args, dim = -1) for args in zip(surprises, chunk_surprises))
|
971
981
|
|
972
982
|
if is_last and not update_after_final_store:
|
973
983
|
continue
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|