titans-pytorch 0.0.50__py3-none-any.whl → 0.0.52__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- titans_pytorch/mac_transformer.py +104 -3
- titans_pytorch/titans.py +14 -5
- {titans_pytorch-0.0.50.dist-info → titans_pytorch-0.0.52.dist-info}/METADATA +10 -1
- titans_pytorch-0.0.52.dist-info/RECORD +8 -0
- titans_pytorch-0.0.50.dist-info/RECORD +0 -8
- {titans_pytorch-0.0.50.dist-info → titans_pytorch-0.0.52.dist-info}/WHEEL +0 -0
- {titans_pytorch-0.0.50.dist-info → titans_pytorch-0.0.52.dist-info}/licenses/LICENSE +0 -0
titans_pytorch/mac_transformer.py CHANGED

@@ -7,16 +7,48 @@ from torch import nn, cat
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList, Linear
 
+# flex attention
+# https://pytorch.org/blog/flexattention/
+
+flex_attention = None
+
+try:
+    from torch.nn.attention.flex_attention import flex_attention, create_block_mask
+    if torch.cuda.is_available():
+        flex_attention = torch.compile(flex_attention)
+except ImportError:
+    pass
+
+def create_mac_block_mask(seq_len, window_size, persist_mem_len):
+
+    def create_mac_mask(b, h, q_idx, kv_idx):
+        is_persist_mem = kv_idx < persist_mem_len
+        causal_mask = q_idx >= (kv_idx - is_persist_mem)
+        block_diagonal = (q_idx // window_size) == ((kv_idx - is_persist_mem) // window_size)
+        return is_persist_mem | (~is_persist_mem & (causal_mask & block_diagonal))
+
+    block_mask = create_block_mask(create_mac_mask, B = None, H = None, Q_LEN = seq_len, KV_LEN = seq_len + persist_mem_len, _compile = True)
+    return block_mask
+
+# einstein notation related
+
 from einops import einsum, repeat, rearrange, pack, unpack
 from einops.layers.torch import Rearrange
 
-
+# b - batch
+# n - sequence
+# h - heads
+# d - feature dimension
 
 # absolute and relative positions
 
 from axial_positional_embedding import ContinuousAxialPositionalEmbedding
 from rotary_embedding_torch import RotaryEmbedding
+
+# hyper connections / attend from x-transformers, which handles different queries and key lengths better
+
 from x_transformers.attend import Attend
+from hyper_connections import get_init_and_expand_reduce_stream_functions
 
 # proposed neural memory
 
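The new `create_mac_mask` predicate keeps the persistent-memory prefix always visible to every query and otherwise restricts attention to causal positions within the query's own window. Below is a minimal standalone sketch (not part of the package) that reproduces the same predicate with dense tensors so the pattern can be printed for toy sizes; the sizes are illustrative, and the explicit `.long()` cast is only there so the subtraction also runs in eager mode.

```python
import torch

def mac_mask_dense(seq_len, window_size, persist_mem_len):
    # queries index the sequence; keys / values additionally include the persistent memory prefix
    q_idx = torch.arange(seq_len)[:, None]
    kv_idx = torch.arange(seq_len + persist_mem_len)[None, :]

    is_persist_mem = kv_idx < persist_mem_len
    shifted_kv = kv_idx - is_persist_mem.long()   # same adjustment as in create_mac_mask above

    causal = q_idx >= shifted_kv
    block_diagonal = (q_idx // window_size) == (shifted_kv // window_size)

    # persistent memory is always attendable, the rest is windowed causal attention
    return is_persist_mem | (~is_persist_mem & (causal & block_diagonal))

print(mac_mask_dense(seq_len = 8, window_size = 4, persist_mem_len = 2).int())
```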
@@ -96,6 +128,7 @@ class SegmentedAttention(Module):
         heads = 8,
         accept_value_residual = False,
         attend_kwargs: dict = dict(),
+        use_flex_attn = False
     ):
         super().__init__()
         self.norm = nn.RMSNorm(dim)
@@ -125,11 +158,79 @@ class SegmentedAttention(Module):
 
         self.persistent_memory = nn.Parameter(torch.zeros(2, heads, num_persist_mem_tokens, dim_head))
 
+        # flex attn related
+
+        assert not (use_flex_attn and not exists(flex_attention)), 'you need to be on the latest pytorch with a cuda device available'
+        self.use_flex_attn = use_flex_attn
+
+        self.segment_len = segment_len
+        self.num_persist_mem_tokens = num_persist_mem_tokens
+
+    def forward_flex(
+        self,
+        seq,
+        value_residual = None,
+        flex_attn_fn: Callable | None = None
+    ):
+
+        assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
+
+        batch, seq_len = seq.shape[:2]
+
+        # attention
+
+        seq = self.norm(seq)
+
+        q, k, v = self.to_qkv(seq).chunk(3, dim = -1)
+        q, k, v = map(self.split_heads, (q, k, v))
+
+        # value residual
+
+        orig_v = v
+
+        if exists(self.to_learned_v_mix):
+            mix = self.to_learned_v_mix(seq)
+            v = v.lerp(value_residual, mix)
+
+        # take care of persistent memory key / values
+
+        pmk, pmv = repeat(self.persistent_memory, 'kv h n d -> kv b h n d', b = batch)
+
+        # relative positions
+
+        q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # persistent memory
+
+        k = cat((pmk, k), dim = -2)
+        v = cat((pmv, v), dim = -2)
+
+        # prep flex attention
+
+        if not exists(flex_attn_fn):
+            block_mask = create_mac_block_mask(seq_len, self.segment_len, self.num_persist_mem_tokens)
+
+            flex_attn_fn = partial(flex_attention, block_mask = block_mask)
+
+        # attention
+
+        out = flex_attn_fn(q, k, v)
+
+        out = self.merge_heads(out)
+
+        out = self.to_out(out)
+
+        return out, orig_v
+
     def forward(
         self,
         seq,
-        value_residual = None
+        value_residual = None,
+        flex_attn_fn: Callable | None = None
     ):
+        if seq.is_cuda and self.use_flex_attn:
+            return self.forward_flex(seq, value_residual, flex_attn_fn)
+
         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
 
         segment_len, num_longterm_mem_tokens = self.segment_len, self.num_longterm_mem_tokens
@@ -159,7 +260,7 @@ class SegmentedAttention(Module):
 
         # take care of persistent memory key / values
 
-        pmk, pmv =
+        pmk, pmv = repeat(self.persistent_memory, 'kv ... -> kv b ...', b = seq.shape[0])
 
         # relative positions
 
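With `use_flex_attn = True`, `forward` now routes CUDA inputs through the new `forward_flex`, which rebuilds the MAC block mask on every call unless a prepared `flex_attn_fn` is passed in. Below is a hedged usage sketch of that hook: only `use_flex_attn`, the `flex_attn_fn` keyword, and `create_mac_block_mask` are confirmed by this diff; the other constructor arguments and sizes are assumptions for illustration.

```python
from functools import partial

import torch
from torch.nn.attention.flex_attention import flex_attention   # requires a recent PyTorch

from titans_pytorch.mac_transformer import SegmentedAttention, create_mac_block_mask

# constructor arguments other than use_flex_attn are assumptions, not verified against the package
attn = SegmentedAttention(
    dim = 512,
    segment_len = 128,
    num_persist_mem_tokens = 4,
    use_flex_attn = True          # new flag in this release
).cuda()

seq = torch.randn(1, 1024, 512).cuda()

# build the block mask once (seq_len queries, seq_len + persist_mem_len keys / values) and reuse it
block_mask = create_mac_block_mask(seq_len = 1024, window_size = 128, persist_mem_len = 4)
flex_attn_fn = partial(flex_attention, block_mask = block_mask)

out, values = attn(seq, flex_attn_fn = flex_attn_fn)   # forward dispatches to forward_flex on CUDA
```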
titans_pytorch/titans.py CHANGED

@@ -56,6 +56,17 @@ def pack_one_with_inverse(t, pattern):
 
     return packed, inverse
 
+def Sequential(*modules):
+    modules = [*filter(exists, modules)]
+
+    if len(modules) == 0:
+        return nn.Identity()
+
+    if len(modules) == 1:
+        return modules[0]
+
+    return nn.Sequential(*modules)
+
 # softclamping gradients
 
 def softclamp_max(t, max_value):
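The `Sequential` helper filters out `None` entries before building the container, which is what lets the optional `activation` further down be threaded straight into the projections without any branching at the call site. A minimal standalone sketch of that behavior; the `exists` helper is copied from the module's conventions and the Linear sizes are illustrative.

```python
from torch import nn

def exists(v):
    return v is not None

def Sequential(*modules):
    modules = [*filter(exists, modules)]

    if len(modules) == 0:
        return nn.Identity()

    if len(modules) == 1:
        return modules[0]

    return nn.Sequential(*modules)

proj = Sequential(nn.Linear(512, 512, bias = False), None)           # no activation -> bare Linear
proj_act = Sequential(nn.Linear(512, 512, bias = False), nn.SiLU())  # activation -> nn.Sequential

print(type(proj).__name__, type(proj_act).__name__)   # Linear Sequential
```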
@@ -124,9 +135,6 @@ class MemoryAttention(Module):
         ])
 
     def forward(self, x):
-
-        assert x.shape[-2] > 1, 'chunk size needs to be greater than 1 for using attention as memory'
-
         wq, wk, wv, ffw1, ffw2 = self.weights
 
         q = F.normalize(x @ wq, dim = -1)
@@ -168,6 +176,7 @@ class NeuralMemory(Module):
         post_rmsnorm = True,
         max_grad_norm: float | None = None,
         use_accelerated_scan = False,
+        activation: Module | None = None,
         default_model_kwargs: dict = dict(
             depth = 2
         )
@@ -225,11 +234,11 @@ class NeuralMemory(Module):
 
         # queries for retrieving from the model
 
-        self.to_queries = LinearNoBias(dim, dim_inner)
+        self.to_queries = Sequential(LinearNoBias(dim, dim_inner), activation)
 
         # keys and values for storing to the model
 
-        self.to_keys_values = LinearNoBias(dim, dim_inner * 2)
+        self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
 
         self.store_memory_loss_fn = store_memory_loss_fn
 
         # empty memory embed
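Together with the new `activation` keyword, the query and key/value projections above become `Linear -> activation` when an activation module is supplied, and stay a bare `Linear` otherwise. A hedged usage sketch follows; only the `activation` keyword is confirmed by this diff, while `dim`, `chunk_size`, and the call convention follow the project README and are assumptions here.

```python
import torch
from torch import nn

from titans_pytorch import NeuralMemory

mem = NeuralMemory(
    dim = 384,
    chunk_size = 64,          # assumed constructor argument, per the project README
    activation = nn.SiLU()    # new in this release; defaults to None (no activation)
)

seq = torch.randn(2, 1024, 384)
retrieved = mem(seq)          # retrieval output, expected to match seq's shape

print(retrieved.shape)
```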
{titans_pytorch-0.0.50.dist-info → titans_pytorch-0.0.52.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: titans-pytorch
-Version: 0.0.50
+Version: 0.0.52
 Summary: Titans
 Project-URL: Homepage, https://pypi.org/project/titans-pytorch/
 Project-URL: Repository, https://github.com/lucidrains/titans-pytorch
@@ -135,3 +135,12 @@ $ python train_mac.py
     year = {2024}
 }
 ```
+
+```bibtex
+@inproceedings{Yang2024GatedDN,
+    title = {Gated Delta Networks: Improving Mamba2 with Delta Rule},
+    author = {Songlin Yang and Jan Kautz and Ali Hatamizadeh},
+    year = {2024},
+    url = {https://api.semanticscholar.org/CorpusID:274598177}
+}
+```
titans_pytorch-0.0.52.dist-info/RECORD ADDED

@@ -0,0 +1,8 @@
+titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
+titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
+titans_pytorch/mac_transformer.py,sha256=khfjpbsy-uT9NIG3dZLsLOG_XSEi7EqcyfbPr7EQc2Q,13192
+titans_pytorch/titans.py,sha256=T04onF0xhcrosS-Qkx7fcx-Cqgh0TdU5JLdq9l8ayGg,15911
+titans_pytorch-0.0.52.dist-info/METADATA,sha256=coC9ExIuNvmab0BktSE1NwUgxRaBUV7h_cTHeoJkRJo,4484
+titans_pytorch-0.0.52.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+titans_pytorch-0.0.52.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
+titans_pytorch-0.0.52.dist-info/RECORD,,
titans_pytorch-0.0.50.dist-info/RECORD DELETED

@@ -1,8 +0,0 @@
-titans_pytorch/__init__.py,sha256=wnv_Cjdjqh_h5IqLkQ8xrTtA2K663ITEn-1JeeHofTo,150
-titans_pytorch/associative_scan.py,sha256=Y-iYqmFuG-NoCKu6kgql1mhowXTeJfyawi3eUIXamp0,2650
-titans_pytorch/mac_transformer.py,sha256=EMhxPt86Vr6LFvPm0OLMFYLaIY19khU9yIHkIhl2EMA,10316
-titans_pytorch/titans.py,sha256=TklMAxNDxgFBpJZFJa8hEhqA_DITmT6EM0p0ueE1jo8,15712
-titans_pytorch-0.0.50.dist-info/METADATA,sha256=KU7TTrH89eNVPP10NKKTDKnW-ik344_kVQkAXW7NRL8,4210
-titans_pytorch-0.0.50.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-titans_pytorch-0.0.50.dist-info/licenses/LICENSE,sha256=1yCiA9b5nhslTavxPjsQAO-wpOnwJR9-l8LTVi7GJuk,1066
-titans_pytorch-0.0.50.dist-info/RECORD,,
{titans_pytorch-0.0.50.dist-info → titans_pytorch-0.0.52.dist-info}/WHEEL RENAMED
File without changes

{titans_pytorch-0.0.50.dist-info → titans_pytorch-0.0.52.dist-info}/licenses/LICENSE RENAMED
File without changes