titans-pytorch 0.4.3__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/__init__.py +2 -0
- titans_pytorch/neural_memory.py +11 -1
- {titans_pytorch-0.4.3.dist-info → titans_pytorch-0.4.4.dist-info}/METADATA +2 -2
- titans_pytorch-0.4.4.dist-info/RECORD +9 -0
- titans_pytorch-0.4.3.dist-info/RECORD +0 -9
- {titans_pytorch-0.4.3.dist-info → titans_pytorch-0.4.4.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.4.3.dist-info → titans_pytorch-0.4.4.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/__init__.py
CHANGED
titans_pytorch/neural_memory.py
CHANGED
@@ -7,10 +7,11 @@ from itertools import zip_longest
|
|
7
7
|
from collections import namedtuple
|
8
8
|
|
9
9
|
import torch
|
10
|
-
from torch import nn, stack, cat, tensor, Tensor
|
10
|
+
from torch import nn, stack, cat, is_tensor, tensor, Tensor
|
11
11
|
import torch.nn.functional as F
|
12
12
|
from torch.nn import Linear, Module, Parameter, ParameterList, ParameterDict
|
13
13
|
from torch.func import functional_call, vmap, grad
|
14
|
+
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
|
14
15
|
|
15
16
|
from tensordict import TensorDict
|
16
17
|
|
@@ -40,6 +41,8 @@ u - key / value updates - allowing a token to emit multiple key / values
|
|
40
41
|
|
41
42
|
LinearNoBias = partial(Linear, bias = False)
|
42
43
|
|
44
|
+
# neural mem state related
|
45
|
+
|
43
46
|
NeuralMemState = namedtuple('NeuralMemState', [
|
44
47
|
'seq_index',
|
45
48
|
'weights',
|
@@ -48,6 +51,13 @@ NeuralMemState = namedtuple('NeuralMemState', [
|
|
48
51
|
'updates',
|
49
52
|
])
|
50
53
|
|
54
|
+
def mem_state_detach(
|
55
|
+
state: NeuralMemState
|
56
|
+
):
|
57
|
+
assert isinstance(state, NeuralMemState)
|
58
|
+
state = tree_map(lambda t: t.detach() if is_tensor(t) else t, tuple(state))
|
59
|
+
return NeuralMemState(*state)
|
60
|
+
|
51
61
|
# functions
|
52
62
|
|
53
63
|
def exists(v):
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: titans-pytorch
|
3
|
-
Version: 0.4.
|
3
|
+
Version: 0.4.4
|
4
4
|
Summary: Titans
|
5
5
|
Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
|
6
6
|
Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
|
@@ -56,7 +56,7 @@ Description-Content-Type: text/markdown
|
|
56
56
|
|
57
57
|
<img src="./fig1.png" width="400px"></img>
|
58
58
|
|
59
|
-
## Titans - Pytorch
|
59
|
+
## Titans - Pytorch
|
60
60
|
|
61
61
|
Unofficial implementation of [Titans](https://arxiv.org/abs/2501.00663) in Pytorch. Will also contain some explorations into architectures beyond their simple 1-4 layer MLP for the neural memory module, if it works well to any degree.
|
62
62
|
|
@@ -0,0 +1,9 @@
|
|
1
|
+
titans_pytorch/__init__.py,sha256=sVTOuRUkaIYabFExdLY6s1qXm1UwHHz_J19H8ZV-X74,338
|
2
|
+
titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
|
3
|
+
titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
|
4
|
+
titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
|
5
|
+
titans_pytorch/neural_memory.py,sha256=eDC8rl8S241zp8pxKrIEi_6nFm6CEoaH9K4hnDfgzu8,33145
|
6
|
+
titans_pytorch-0.4.4.dist-info/METADATA,sha256=CWciTl1VeOvwyL_lqr0JsdmDDIjXG85N8ykwd3w2TxQ,6810
|
7
|
+
titans_pytorch-0.4.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
+
titans_pytorch-0.4.4.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
+
titans_pytorch-0.4.4.dist-info/RECORD,,
|
@@ -1,9 +0,0 @@
|
|
1
|
-
titans_pytorch/__init__.py,sha256=AyEUlcXWpnqrvyeihRAXWIfQlzLA4NhBjOqQU4edL-4,297
|
2
|
-
titans_pytorch/associative_scan.py,sha256=esaLbukFlgvy2aqopsqBy6KEcZ64B3rsNhG8moKdPSc,5159
|
3
|
-
titans_pytorch/mac_transformer.py,sha256=tz72141G5t3AOnxSVsOLtLptGtl8T7zROUvaTw2_XCY,26960
|
4
|
-
titans_pytorch/memory_models.py,sha256=wnH9i9kUSoVZhEWUlj8LpBSbB400L9kLt1zP8CO45QQ,5835
|
5
|
-
titans_pytorch/neural_memory.py,sha256=HdBaRGURJ84Qy-a6PdfeQoc5ZzY7H0c5YHUASaSVu1A,32824
|
6
|
-
titans_pytorch-0.4.3.dist-info/METADATA,sha256=SIq5KS2xehsUAwuFpRSFNdnLbgamWUMLN5xj4MJGRe0,6816
|
7
|
-
titans_pytorch-0.4.3.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
8
|
-
titans_pytorch-0.4.3.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
|
9
|
-
titans_pytorch-0.4.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|