titans-pytorch 0.0.50__py3-none-any.whl → 0.0.51__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- titans_pytorch/mac_transformer.py +33 -1
- titans_pytorch/titans.py +14 -5
- {titans_pytorch-0.0.50.dist-info → titans_pytorch-0.0.51.dist-info}/METADATA +10 -1
- titans_pytorch-0.0.51.dist-info/RECORD +8 -0
- titans_pytorch-0.0.50.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.50.dist-info → titans_pytorch-0.0.51.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.50.dist-info → titans_pytorch-0.0.51.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py
CHANGED

@@ -7,16 +7,48 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
+# flex attention
+# https://pytorch.org/blog/flexattention/
+
+flex_attention = None
+
+try:
+    from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+    if torch.cuda.is_available():
+        flex_attention = torch.compile(flex_attention)
+except ImportError:
+    pass
+
+def create_mac_block_mask(seq_len, window_size, persist_mem_len):
+
+    def create_mac_mask(b, h, q_idx, kv_idx):
+        is_persist_mem = kv_idx < persist_mem_len
+        causal_mask = q_idx >= (kv_idx - is_persist_mem)
+        block_diagonal = (q_idx // window_size) == ((kv_idx - is_persist_mem) // window_size)
+        return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
+
+    block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
+    return block_mask
+
+# einstein notation related
+
 from einops import einsum, repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
-
+# b - batch
+# n - sequence
+# h - heads
+# d - feature dimension
 
 # absolute and relative positions
 
 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
+
+# hyper connections / attend from x-transformers, which handles different queries and key lengths better
+
 from x_transformers.attend import Attend
+from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # proposed neural memory
 
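Note: the added `create_mac_block_mask` builds a flex-attention block mask that combines causal, block-diagonal (windowed) attention over the sequence with unrestricted attention to the persistent memory tokens prepended on the key/value side. A minimal usage sketch, not part of the diff — the tensor shapes, hyperparameter values, and CUDA device below are illustrative assumptions (flex attention requires PyTorch 2.5+, and `create_block_mask` defaults to a CUDA device as called in the helper):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention
from titans_pytorch.mac_transformer import create_mac_block_mask

seq_len, window_size, persist_mem_len = 512, 64, 16

# persistent memory tokens are prepended on the key / value side only,
# so KV is longer than Q by persist_mem_len
q = torch.randn(1, 8, seq_len, 64, device = 'cuda')
k = torch.randn(1, 8, seq_len + persist_mem_len, 64, device = 'cuda')
v = torch.randn(1, 8, seq_len + persist_mem_len, 64, device = 'cuda')

block_mask = create_mac_block_mask(seq_len, window_size, persist_mem_len)

out = flex_attention(q, k, v, block_mask = block_mask)  # (1, 8, 512, 64)
```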
titans_pytorch/titans.py
CHANGED

@@ -56,6 +56,17 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+def Sequential(*modules):
+    modules = [*filter(exists, modules)]
+
+    if len(modules) == 0:
+        return nn.Identity()
+
+    if len(modules) == 1:
+        return modules[0]
+
+    return nn.Sequential(*modules)
+
 # softclamping gradients
 
 def softclamp_max(t, max_value):

@@ -124,9 +135,6 @@ class MemoryAttention(Module):
         ])
 
     def forward(self, x):
-
-        assert x.shape[-2] > 1, 'chunk size needs to be greater than 1 for using attention as memory'
-
         wq, wk, wv, ffw1, ffw2 = self.weights
 
         q = F.normalize(x @ wq, dim = -1)

@@ -168,6 +176,7 @@ class NeuralMemory(Module):
         post_rmsnorm = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
+        activation: Module | None = None,
         default_model_kwargs: dict = dict(
             depth = 2
         )

@@ -225,11 +234,11 @@ class NeuralMemory(Module):
 
         # queries for retrieving from the model
 
-        self.to_queries = LinearNoBias(dim, dim_inner)
+        self.to_queries = Sequential(LinearNoBias(dim, dim_inner), activation)
 
         # keys and values for storing to the model
 
-        self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
+        self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
 
         self.store_memory_loss_fn = store_memory_loss_fn
 
         # empty memory embed

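Note: the new `Sequential` helper drops `None` modules, so the `activation` keyword added to `NeuralMemory` is purely opt-in: left at its default of `None`, `to_queries` and `to_keys_values` remain single bias-free linear projections; given an activation, each becomes that Linear followed by the activation. A minimal sketch of the effect, not part of the diff — the constructor arguments other than `activation` (here `dim` and `chunk_size`) follow the project README and are assumptions:

```python
from torch import nn
from titans_pytorch import NeuralMemory

# default: activation = None is filtered out by Sequential,
# so the query projection stays a single Linear (no bias)
mem_plain = NeuralMemory(dim = 384, chunk_size = 64)
assert isinstance(mem_plain.to_queries, nn.Linear)

# with an activation: Linear -> SiLU, wrapped in nn.Sequential by the helper
mem_act = NeuralMemory(dim = 384, chunk_size = 64, activation = nn.SiLU())
assert isinstance(mem_act.to_queries, nn.Sequential)
```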
{titans_pytorch-0.0.50.dist-info → titans_pytorch-0.0.51.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.50
+Version: 0.0.51
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch

@@ -135,3 +135,12 @@ $ python train_mac.py
     year = {2024}
 }
 ```
+
+```bibtex
+@inproceedings{Yang2024GatedDN,
+    title = {Gated Delta Networks: Improving Mamba2 with Delta Rule},
+    author = {Songlin Yang and Jan Kautz and Ali Hatamizadeh},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:274598177}
+}
+```

titans_pytorch-0.0.51.dist-info/RECORD
ADDED

@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=R_lWRe_iBkiWZ1iZAt1tNyjaTUyB5mb80mcYZHUKww0,11369
+titans_pytorch/titans.py,sha256=T04onF0xhcrosS-Qkx7fcx-Cqgh0TdU5JLdq9l8ayGg,15911
+titans_pytorch-0.0.51.dist-info/METADATA,sha256=OKqRYWucpjsLKKOksBrdHnvCAqgHgv2FUtL6l8hN2-Y,4484
+titans_pytorch-0.0.51.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.51.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.51.dist-info/RECORD,,

titans_pytorch-0.0.50.dist-info/RECORD
DELETED

@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=EMhxPt86Vr6LFvPm0OLMFYLaIY19khU9yIHkIhl2EMA,10316
-titans_pytorch/titans.py,sha256=TklMAxNDxgFBpJZFJa8hEhqA_DITmT6EM0p0ueE1jo8,15712
-titans_pytorch-0.0.50.dist-info/METADATA,sha256=KU7TTrH89eNVPP10NKKTDKnW-ik344_kVQkAXW7NRL8,4210
-titans_pytorch-0.0.50.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.50.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.50.dist-info/RECORD,,

{titans_pytorch-0.0.50.dist-info → titans_pytorch-0.0.51.dist-info}/WHEEL
File without changes

{titans_pytorch-0.0.50.dist-info → titans_pytorch-0.0.51.dist-info}/licenses/LICENSE
File without changes