x-transformers 1.44.8__py3-none-any.whl → 2.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +221 -137
- x_transformers-2.0.1.dist-info/METADATA +2419 -0
- {x_transformers-1.44.8.dist-info → x_transformers-2.0.1.dist-info}/RECORD +5 -6
- {x_transformers-1.44.8.dist-info → x_transformers-2.0.1.dist-info}/WHEEL +1 -2
- x_transformers-1.44.8.dist-info/METADATA +0 -30
- x_transformers-1.44.8.dist-info/top_level.txt +0 -1
- {x_transformers-1.44.8.dist-info → x_transformers-2.0.1.dist-info/licenses}/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -9,7 +9,7 @@ from packaging import version
 import torch
 from torch.amp import autocast
 import torch.nn.functional as F
-from torch import nn, einsum, Tensor
+from torch import nn, einsum, Tensor, cat, stack, arange
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torch.nn import Module, ModuleList, ModuleDict

@@ -18,14 +18,22 @@ from collections import namedtuple
 from contextlib import nullcontext
 from dataclasses import dataclass

+from loguru import logger
+
+from x_transformers.attend import Attend, Intermediates
+from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+
 import einx
 from einops.layers.torch import Rearrange
 from einops import rearrange, repeat, reduce, pack, unpack

-
+# einstein notation

-
-
+# b - batch
+# n - sequence
+# d - feature dimension
+# h - attention heads
+# i, j - sequence (source, target)

 # constants

@@ -220,7 +228,7 @@ def dropout_seq(seq, mask, dropout):
     num_keep = max(1, int(keep_prob * n))
     keep_indices = logits.topk(num_keep, dim = 1).indices

-    batch_indices = torch.arange(b, device = device)
+    batch_indices = arange(b, device = device)
     batch_indices = rearrange(batch_indices, 'b -> b 1')

     seq = seq[batch_indices, keep_indices]
@@ -228,7 +236,7 @@ def dropout_seq(seq, mask, dropout):
     if exists(mask):
         seq_counts = mask.sum(dim = -1)
         seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
-        keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
+        keep_mask = arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')

         mask = mask[batch_indices, keep_indices] & keep_mask

@@ -274,7 +282,7 @@ class AbsolutePositionalEmbedding(Module):
         assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

         if not exists(pos):
-            pos = torch.arange(seq_len, device = device)
+            pos = arange(seq_len, device = device)

         if exists(seq_start_pos):
             pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
@@ -290,7 +298,7 @@ class ScaledSinusoidalEmbedding(Module):
         self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

         half_dim = dim // 2
-        freq_seq = torch.arange(half_dim).float() / half_dim
+        freq_seq = arange(half_dim).float() / half_dim
         inv_freq = theta ** -freq_seq
         self.register_buffer('inv_freq', inv_freq, persistent = False)

@@ -298,13 +306,13 @@ class ScaledSinusoidalEmbedding(Module):
         seq_len, device = x.shape[1], x.device

         if not exists(pos):
-            pos = torch.arange(seq_len, device = device)
+            pos = arange(seq_len, device = device)

         if exists(seq_start_pos):
             pos = pos - seq_start_pos[..., None]

         emb = einsum('i, j -> i j', pos, self.inv_freq)
-        emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
+        emb = cat((emb.sin(), emb.cos()), dim = -1)
         return emb * self.scale

 class RelativePositionBias(Module):
@@ -344,8 +352,8 @@ class RelativePositionBias(Module):

     def forward(self, i, j):
         device = self.device
-        q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
-        k_pos = torch.arange(j, dtype = torch.long, device = device)
+        q_pos = arange(j - i, j, dtype = torch.long, device = device)
+        k_pos = arange(j, dtype = torch.long, device = device)
         rel_pos = einx.subtract('j, i -> i j', k_pos, q_pos)
         rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
         values = self.relative_attention_bias(rp_bucket)
@@ -376,7 +384,7 @@ class CoPE(Module):
         if not soft_onehot:
             return

-        self.register_buffer('positions', torch.arange(max_pos))
+        self.register_buffer('positions', arange(max_pos))

     def forward(self, query, attn_logits):

@@ -445,13 +453,13 @@ class DynamicPositionBias(Module):
         n, device = j, self.device

         # get the (n x n) matrix of distances
-        seq_arange = torch.arange(n, device = device)
-        context_arange = torch.arange(n, device = device)
+        seq_arange = arange(n, device = device)
+        context_arange = arange(n, device = device)
         indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
         indices += (n - 1)

         # input to continuous positions MLP
-        pos = torch.arange(-n + 1, n, device = device).float()
+        pos = arange(-n + 1, n, device = device).float()
         pos = rearrange(pos, '... -> ... 1')

         if self.log_distance:
@@ -525,8 +533,8 @@ class AlibiPositionalBias(Module):
         if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
             return self.bias[..., -i:, -j:]

-        seq_arange = torch.arange(j - i, j, device = device)
-        context_arange = torch.arange(j, device = device)
+        seq_arange = arange(j - i, j, device = device)
+        context_arange = arange(j, device = device)
         bias = -einx.subtract('j, i -> 1 i j', context_arange, seq_arange).abs()

         bias = bias * self.slopes
@@ -642,7 +650,7 @@ class RotaryEmbedding(Module):
         # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
         base *= base_rescale_factor ** (dim / (dim - 2))

-        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+        inv_freq = 1. / (base ** (arange(0, dim, 2).float() / dim))
         self.register_buffer('inv_freq', inv_freq)

         assert interpolation_factor >= 1.
@@ -652,7 +660,7 @@ class RotaryEmbedding(Module):
             self.register_buffer('scale', None)
             return

-        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
+        scale = (arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

         self.scale_base = scale_base
         self.register_buffer('scale', scale)
@@ -660,7 +668,7 @@ class RotaryEmbedding(Module):
     def forward_from_seq_len(self, seq_len):
         device = self.inv_freq.device

-        t = torch.arange(seq_len, device = device)
+        t = arange(seq_len, device = device)
         return self.forward(t)

     @autocast('cuda', enabled = False)
@@ -671,7 +679,7 @@ class RotaryEmbedding(Module):
             t = rearrange(t, 'n -> 1 n')

         freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
-        freqs = torch.stack((freqs, freqs), dim = -1)
+        freqs = stack((freqs, freqs), dim = -1)
         freqs = rearrange(freqs, '... d r -> ... (d r)')

         if not exists(self.scale):
@@ -679,7 +687,7 @@ class RotaryEmbedding(Module):

         power = (t - (max_pos // 2)) / self.scale_base
         scale = self.scale ** rearrange(power, '... n -> ... n 1')
-        scale = torch.stack((scale, scale), dim = -1)
+        scale = stack((scale, scale), dim = -1)
         scale = rearrange(scale, '... d r -> ... (d r)')

         return freqs, scale
@@ -687,7 +695,7 @@ class RotaryEmbedding(Module):
 def rotate_half(x):
     x = rearrange(x, '... (d r) -> ... d r', r = 2)
     x1, x2 = x.unbind(dim = -1)
-    x = torch.stack((-x2, x1), dim = -1)
+    x = stack((-x2, x1), dim = -1)
     return rearrange(x, '... d r -> ... (d r)')

 @autocast('cuda', enabled = False)
@@ -703,7 +711,7 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):
     # partial rotary embeddings, Wang et al. GPT-J
     t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
     t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
-    out = torch.cat((t, t_unrotated), dim = -1)
+    out = cat((t, t_unrotated), dim = -1)

     return out.type(orig_dtype)

@@ -833,6 +841,15 @@ class SimpleRMSNorm(Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale

+class MultiheadRMSNorm(Module):
+    def __init__(self, dim, heads):
+        super().__init__()
+        self.rmsnorm = SimpleRMSNorm(dim)
+        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+    def forward(self, x):
+        return self.rmsnorm(x) * (self.gamma + 1.)
+
 # residual and residual gates

 class Residual(Module):
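The new `MultiheadRMSNorm` is used further down to put the attention output and the hybrid-module output on a comparable per-head scale before they are averaged. A minimal standalone sketch of its behaviour follows; `SimpleRMSNorm.__init__` is not shown in this hunk, so the `dim ** 0.5` scale below is an assumption.

```python
# Minimal sketch (not taken verbatim from the package): exercising MultiheadRMSNorm.
# SimpleRMSNorm is re-declared here only so the snippet runs standalone; its
# dim ** 0.5 scale is assumed, since only its forward appears in this diff.
import torch
from torch import nn
import torch.nn.functional as F

class SimpleRMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale

class MultiheadRMSNorm(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.rmsnorm = SimpleRMSNorm(dim)
        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))  # zero-init: starts as a plain per-head RMSNorm

    def forward(self, x):
        return self.rmsnorm(x) * (self.gamma + 1.)

norm = MultiheadRMSNorm(dim = 64, heads = 8)
x = torch.randn(2, 8, 16, 64)        # (batch, heads, seq, dim_head)
assert norm(x).shape == x.shape      # per-head RMS norm with a learned per-head, per-channel gain
```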
@@ -904,7 +921,7 @@ class HyperConnection(Module):
         init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
         init_alpha0[layer_index % num_residual_streams, :] = 1.

-        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+        self.static_alpha = nn.Parameter(cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))

         self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
         self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
@@ -973,7 +990,7 @@ class ShiftTokens(Module):
         splitted = x.split(feats_per_shift, dim = -1)
         segments_to_shift, rest = splitted[:segments], splitted[segments:]
         segments_to_shift = [shift(*args, mask = mask) for args in zip(segments_to_shift, shifts)]
-        x = torch.cat((*segments_to_shift, *rest), dim = -1)
+        x = cat((*segments_to_shift, *rest), dim = -1)
         return self.fn(x, **kwargs)

 class FoldAxially(Module):
@@ -1080,7 +1097,7 @@ class ConcatCombine(Module):

     def forward(self, x, prev_layers: list[Tensor]):
         skip = prev_layers[self.prev_layer_ind]
-        concatted_skip = torch.cat((skip, x), dim = -1)
+        concatted_skip = cat((skip, x), dim = -1)
         return self.combine(concatted_skip)

 # feedforward
@@ -1189,12 +1206,10 @@ class Attention(Module):
         hybrid_fold_axial_dim: int | None = None,
         one_kv_head = False,
         kv_heads = None,
-        shared_kv = False,
         value_dim_head = None,
         dim_out = None,
-        tensor_product = False, # https://arxiv.org/abs/2208.06061
         add_zero_kv = False, # same as add_zero_attn in pytorch
-
+        rotate_num_heads = None,
         data_dependent_alibi = False,
         data_dependent_alibi_per_row = False,
         data_dependent_alibi_per_row_dim_head = 8,
@@ -1205,12 +1220,15 @@ class Attention(Module):
         cope_talking_heads = False,
         softclamp_logits = False,
         logit_softclamp_value = 50.,
-        neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
-        neutreno_alpha = 0.4,
         learned_value_residual_mix = False,
-        laser = False,
+        laser = False, # https://arxiv.org/abs/2411.03493v1
         laser_softclamp_value = 15.,
         qkv_receive_diff_residuals = False,
+        use_latent_q = False,
+        dim_latent_q = None,
+        use_latent_kv = False,
+        dim_latent_kv = None,
+        latent_rope_subheads = None,
         onnxable = False,
         attend_sdp_kwargs: dict = dict(
             enable_flash = True,
@@ -1242,13 +1260,51 @@ class Attention(Module):
         v_dim = value_dim_head * kv_heads
         out_dim = value_dim_head * heads

-
-
+        # determine input dimensions to qkv based on whether intermediate latent q and kv are being used
+        # for eventually supporting multi-latent attention (MLA)
+
+        self.to_latent_q = None
+        self.to_latent_kv = None
+        self.to_rotateable_k = None # for their "decoupled rope", subheads of keys that comes directly from base sequence (does not go through latents)
+
+        dim_q_input = dim
+        dim_kv_input = dim_kv
+
+        if use_latent_q:
+            assert exists(dim_latent_q)
+            self.to_latent_q = LinearNoBias(dim, dim_latent_q)
+            dim_q_input = dim_latent_q
+
+        if use_latent_kv:
+            assert exists(dim_latent_kv)
+            self.to_latent_kv = LinearNoBias(dim, dim_latent_kv)
+            dim_kv_input = dim_latent_kv
+
+        if exists(latent_rope_subheads):
+            assert not exists(rotate_num_heads)
+            rotate_num_heads = latent_rope_subheads
+
+            k_dim = dim_head * (kv_heads - latent_rope_subheads)

-
+            self.to_rotateable_k = LinearNoBias(dim, dim_head * latent_rope_subheads)
+            self.split_rotateable_k_heads = Rearrange('b n (h d) -> b h n d', h = latent_rope_subheads)

-
-        self.
+        self.use_latent_q = use_latent_q
+        self.use_latent_kv = use_latent_kv
+
+        # query key projection
+
+        self.to_q = LinearNoBias(dim_q_input, q_dim)
+        self.to_k = LinearNoBias(dim_kv_input, k_dim)
+        self.to_v = LinearNoBias(dim_kv_input, v_dim)
+
+        # split and merge of attention heads
+
+        self.split_q_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+        self.split_k_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
+        self.split_v_heads = Rearrange('b n (h d) -> b h n d', d = value_dim_head)
+
+        self.merge_heads = Rearrange('b h n d -> b n (h d)')

         # whether qkv receives different residual stream combinations from hyper connections

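Taken together, `use_latent_q`, `use_latent_kv`, `dim_latent_q`, `dim_latent_kv` and `latent_rope_subheads` wire a DeepSeek-style multi-latent attention path into `Attention`. A hedged usage sketch follows; it assumes the package's usual `attn_`-prefixed kwarg forwarding from `Decoder`/`AttentionLayers` into `Attention`, and all dimensions are illustrative rather than recommended settings.

```python
# Hedged sketch: exercising the new multi-latent attention arguments through the
# attn_-prefixed kwarg forwarding convention. Values below are illustrative only.
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        rotary_pos_emb = True,
        attn_use_latent_kv = True,        # keys / values projected through a shared low-rank latent
        attn_dim_latent_kv = 128,
        attn_latent_rope_subheads = 2     # decoupled rope: 2 key subheads come straight from the input sequence
    )
)

tokens = torch.randint(0, 256, (1, 128))
logits = model(tokens)                    # expected shape: (1, 128, 256)
```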
@@ -1259,15 +1315,6 @@ class Attention(Module):
         self.laser = laser
         self.laser_softclamp_value = laser_softclamp_value

-        # relations projection from tp-attention
-
-        self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None
-
-        # the value residual used by Nguyen et al. in https://arxiv.org/abs/2312.00751 for countering oversmoothing
-
-        self.neutreno_value_residual = neutreno_value_residual
-        self.neutreno_alpha = neutreno_alpha
-
         # add GLU gating for aggregated values, from alphafold2

         self.to_v_gate = None
@@ -1393,12 +1440,22 @@ class Attention(Module):

         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676

+        hybrid_mix = None
+        hybrid_norms = None
         hybrid_module = maybe(deepcopy)(hybrid_module)

         if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
             hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
+            hybrid_mix = LinearNoBias(dim, heads)
+
+            hybrid_norms = ModuleList([
+                MultiheadRMSNorm(dim_head, heads = heads),
+                MultiheadRMSNorm(dim_head, heads = heads)
+            ])

         self.hybrid_module = hybrid_module
+        self.hybrid_norms = hybrid_norms
+        self.hybrid_mix = hybrid_mix
         self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths

         # output dimension by default same as input, but can be overridden
@@ -1406,9 +1463,15 @@ class Attention(Module):
         dim_out = default(dim_out, dim)
         self.to_out = nn.Sequential(LinearNoBias(out_dim, dim_out * 2), nn.GLU()) if on_attn else LinearNoBias(out_dim, dim_out)

-        #
+        # the number of attention heads to rotate, for decoupled rope in multi-latent attention
+
+        rotate_num_heads = default(rotate_num_heads, heads)

-
+        assert 0 < rotate_num_heads <= heads
+        is_partial_rotate_heads = rotate_num_heads < heads
+        assert not (is_partial_rotate_heads and kv_heads < heads), 'grouped query attention not compatible with partial rotate heads (decoupled rope for multi-latent attention), yet'
+
+        self.rotate_num_heads = rotate_num_heads

         # whether parent can kv cache

@@ -1438,47 +1501,79 @@ class Attention(Module):
         cache: Intermediates | None = None,
         value_residual = None
     ):
-        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals
+        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv
+
+        # an interesting possibility with hyper connections
+        # having queries, keys, values be routed from different layers

         assert not (qkv_receive_diff_residuals and has_context), 'qkv receiving different sequences can only be used for self attention'

         if qkv_receive_diff_residuals:
-            assert
+            assert x.ndim == 4 and x.shape[0] == 3

             q_input, k_input, v_input = x
         else:
             kv_input = default(context, x)
-
-            q_input = x
-            k_input = kv_input
-            v_input = kv_input
-            r_input = x
+            q_input, k_input, v_input = x, kv_input, kv_input

         if exists(mem):
             k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
             v_input, _ = pack([mem, v_input], 'b * d')

+        # multi-latent attention logic
+        # https://arxiv.org/abs/2405.04434 - Deepseek-AI team
+
+        k_sub_heads = None # the rotateable subheads of keys derived from base sequence
+
+        if self.use_latent_q:
+            q_input = self.to_latent_q(q_input)
+
+        if is_multi_latent_attn:
+            assert not qkv_receive_diff_residuals
+            needs_k_sub_heads = exists(self.to_rotateable_k)
+
+            latent_kv_input = self.to_latent_kv(k_input)
+
+            if needs_k_sub_heads:
+                rotateable_k = self.to_rotateable_k(k_input)
+                k_sub_heads = self.split_rotateable_k_heads(rotateable_k)
+
+            if exists(cache):
+                cached_latent_kv, maybe_cached_k_sub_heads = cache.cached_kv
+                latent_kv_input = cat((cached_latent_kv, latent_kv_input), dim = -2)
+
+                if exists(maybe_cached_k_sub_heads):
+                    k_sub_heads = cat((maybe_cached_k_sub_heads, k_sub_heads), dim = -2)
+
+            if return_intermediates:
+                cached_kv = (latent_kv_input, k_sub_heads)
+
+            k_input = v_input = latent_kv_input
+
+        # query, key, value projection
+
         q = self.to_q(q_input)
         k = self.to_k(k_input)
-        v = self.to_v(v_input)
-
+        v = self.to_v(v_input)
+
+        q = self.split_q_heads(q)
+        k = self.split_k_heads(k)
+        v = self.split_v_heads(v)

-
+        # take care of decoupled rope from multi-latent attention

-
+        if exists(k_sub_heads):
+            k = cat((k, k_sub_heads), dim = 1)

-        # if previous values passed in for residual, either invoke resformer
+        # if previous values passed in for residual, either invoke resformer

         orig_values = v

+        # https://arxiv.org/abs/2410.17897v1
+
         if exists(value_residual):
-
-
-                diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
-            else:
-                # https://arxiv.org/abs/2410.17897v1
-                value_residual_mix = self.to_value_residual_mix(q_input)
-                v = v * value_residual_mix + value_residual * (1. - value_residual_mix)
+            value_residual_mix = self.to_value_residual_mix(q_input)
+            v = value_residual.lerp(v, value_residual_mix)

         # qk normalization

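The learned value-residual mix is now applied with `Tensor.lerp`, which computes the same interpolation as the removed expression `v * mix + value_residual * (1. - mix)`. A quick numerical check under made-up shapes (in the module the mix comes from `to_value_residual_mix` and is sigmoid-gated; the shapes here are illustrative):

```python
# Illustrative check: torch.lerp(start, end, weight) = start + weight * (end - start),
# which equals value_residual * (1 - mix) + v * mix, i.e. the removed expression.
import torch

v = torch.randn(2, 8, 16, 64)
value_residual = torch.randn(2, 8, 16, 64)
mix = torch.rand(2, 8, 16, 1)    # stand-in for the sigmoid-gated output of to_value_residual_mix

old = v * mix + value_residual * (1. - mix)
new = value_residual.lerp(v, mix)

assert torch.allclose(old, new, atol = 1e-5)
```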
@@ -1492,28 +1587,36 @@ class Attention(Module):

         # take care of caching

-        if
-
+        if not is_multi_latent_attn:
+            if exists(cache):
+                ck, cv = cache.cached_kv

-
-
-
+                if exists(mem):
+                    mk, k = unpack(k, mem_packed_shape, 'b h * d')
+                    mv, v = unpack(v, mem_packed_shape, 'b h * d')

-
-
+                k = cat((ck, k), dim = -2)
+                v = cat((cv, v), dim = -2)

-
-
-
+                if exists(mem):
+                    k = cat((mk, k), dim = -2)
+                    v = cat((mv, v), dim = -2)

-
-
-
+            if return_intermediates:
+                mem_len = mem.shape[-2] if exists(mem) else 0
+                cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])

         if exists(rotary_pos_emb):
+            rotate_num_heads = self.rotate_num_heads
+            partial_rotate_heads = rotate_num_heads < h
+
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)

+            if partial_rotate_heads:
+                q_rest, q = q[:, :-rotate_num_heads], q[:, -rotate_num_heads:]
+                k_rest, k = k[:, :-rotate_num_heads], k[:, -rotate_num_heads:]
+
             q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)

             if has_context:
@@ -1524,8 +1627,9 @@ class Attention(Module):

             k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)

-            if
-
+            if partial_rotate_heads:
+                q = cat((q_rest, q), dim = 1)
+                k = cat((k_rest, k), dim = 1)

         input_mask = context_mask

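For decoupled rope, only the trailing `rotate_num_heads` heads of `q` and `k` receive rotary embeddings; the other heads are sliced off beforehand and concatenated back along the head dimension afterwards. A toy illustration with a stand-in rotation function (not the package's `apply_rotary_pos_emb`):

```python
# Illustration only: partial head rotation with a placeholder rotation,
# mirroring the slicing and cat(..., dim = 1) reassembly in the hunk above.
import torch
from torch import cat

def fake_rotate(t):                      # stand-in for apply_rotary_pos_emb
    return t * 2.

heads, rotate_num_heads = 8, 2
q = torch.randn(1, heads, 16, 64)        # (batch, heads, seq, dim_head)

q_rest, q_rot = q[:, :-rotate_num_heads], q[:, -rotate_num_heads:]
q_out = cat((q_rest, fake_rotate(q_rot)), dim = 1)

assert q_out.shape == (1, heads, 16, 64)
assert torch.equal(q_out[:, :-rotate_num_heads], q[:, :-rotate_num_heads])  # untouched heads pass through unchanged
```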
@@ -1540,7 +1644,7 @@ class Attention(Module):
             elif not exists(input_mask):
                 input_mask = pad_at_dim(mem_mask, (0, seq_len), dim = -1, value = True)
             else:
-                input_mask = torch.cat((mem_mask, input_mask), dim = -1)
+                input_mask = cat((mem_mask, input_mask), dim = -1)

         # i, j determined for relative positional bias, excluding memory key / values

@@ -1555,8 +1659,8 @@ class Attention(Module):
                 mem_k = l2norm(mem_k)
                 mem_k = mem_k * self.qk_norm_k_scale

-            k = torch.cat((mem_k, k), dim = -2)
-            v = torch.cat((mem_v, v), dim = -2)
+            k = cat((mem_k, k), dim = -2)
+            v = cat((mem_v, v), dim = -2)

             if exists(input_mask):
                 input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
@@ -1580,8 +1684,8 @@ class Attention(Module):
             masks.append(~attn_mask)

         if exists(self.max_attend_past):
-            range_q = torch.arange(j - i, j, device = device)
-            range_k = torch.arange(j, device = device)
+            range_q = arange(j - i, j, device = device)
+            range_k = arange(j, device = device)
             dist = einx.subtract('i, j -> 1 1 i j', range_q, range_k)
             max_attend_past_mask = dist > self.max_attend_past
             max_attend_past_mask = pad_at_dim(max_attend_past_mask, (num_mem_kv, 0), value = False, dim = -1) # handle memory key / values
@@ -1629,18 +1733,10 @@ class Attention(Module):
         if self.laser:
             out = log(out)

-        # store the values for resformer
+        # store the values for resformer

         intermediates.values = orig_values

-        if exists(value_residual) and self.neutreno_value_residual:
-            out = out + diff_values
-
-        # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
-
-        if exists(r):
-            out = out * r + out
-

         # normformer scaling of heads

@@ -1652,11 +1748,9 @@ class Attention(Module):
             head_gate = self.to_v_head_gate(x)
             out = einx.multiply('b n h, b h n d ->b h n d', head_gate.sigmoid(), out)

-        #
-
-        out = rearrange(out, 'b h n d -> b n (h d)')
+        # if exists hybrid module, must do a normalization

-
+        # hybrid module

         if exists(self.hybrid_module):

@@ -1674,8 +1768,23 @@ class Attention(Module):
             # handle hybrid out

             (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
+
+            # handle variable hybrid output and multi rmsnorm before summing to main attention output (also normed)
+
+            if hybrid_out.ndim == 3:
+                hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)
+
+            out_norm, hybrid_out_norm = self.hybrid_norms
+
+            out = out_norm(out)
+            hybrid_out = hybrid_out_norm(hybrid_out)
+
             out = 0.5 * (out + hybrid_out)

+        # merge heads
+
+        out = self.merge_heads(out)
+
         # alphafold2 styled gating of the values

         if exists(self.to_v_gate):
@@ -1747,8 +1856,6 @@ class AttentionLayers(Module):
         sandwich_norm = False,
         softclamp_output = False,
         softclamp_output_value = 30.,
-        resi_dual = False,
-        resi_dual_scale = 1.,
         zero_init_branch_output = False,
         layer_dropout = 0.,
         cross_attn_tokens_dropout = 0.,
@@ -1775,12 +1882,9 @@ class AttentionLayers(Module):

         dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
         data_dependent_alibi = attn_kwargs.get('data_dependent_alibi', False)
-        neutreno_value_residual = attn_kwargs.get('neutreno_value_residual', False)

         assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'

-        add_value_residual |= neutreno_value_residual
-
         self.dim = dim
         self.causal = causal
         self.layers = ModuleList([])
@@ -1831,19 +1935,11 @@ class AttentionLayers(Module):
             assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
             self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads, **rel_pos_kwargs)

-        assert at_most_one_of(sandwich_norm, resi_dual), 'either sandwich norm or resiDual is selected, but not both'
         assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'

-        if resi_dual:
-            pre_norm = False
-
         self.pre_norm = pre_norm
         self.sandwich_norm = sandwich_norm

-        self.resi_dual = resi_dual
-        assert 0 < resi_dual_scale <= 1., 'resiDual prenorm residual must be scaled by a factor greater than 0 and less than or equal to 1.'
-        self.resi_dual_scale = resi_dual_scale
-
         self.residual_attn = residual_attn
         self.cross_residual_attn = cross_residual_attn
         assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention'
@@ -2002,7 +2098,7 @@ class AttentionLayers(Module):

         # whether it has post norm

-        self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()
+        self.final_norm = norm_fn() if pre_norm else nn.Identity()

         # whether unet or not

@@ -2175,7 +2271,7 @@ class AttentionLayers(Module):
         # handle left padded sequences

         if exists(seq_start_pos):
-            seq_arange = torch.arange(x.shape[-2], device = x.device, dtype = torch.long)
+            seq_arange = arange(x.shape[-2], device = x.device, dtype = torch.long)
             left_pad_mask = seq_arange >= seq_start_pos[..., None]

             if exists(self_attn_kv_mask):
@@ -2193,7 +2289,7 @@ class AttentionLayers(Module):
             mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0

             if not exists(pos):
-                pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+                pos = arange(x.shape[1] + mem_len, device = x.device) - mem_len

             rotary_pos_emb = self.rotary_pos_emb(pos)

@@ -2213,7 +2309,7 @@ class AttentionLayers(Module):
         attn_cache = []

         if exists(cache):
-            assert
+            assert self.causal and not any([*map(exists, (mask, attn_mask))])

             if exists(context):
                 context = context[:, :0]
@@ -2231,13 +2327,7 @@ class AttentionLayers(Module):
         is_multistream = streams > 1

         if is_multistream:
-            x =
-            x = x + self.stream_emb
-            x = rearrange(x, 'b n s d -> (b s) n d')
-
-        # outer residual - for resiDual paper
-
-        outer_residual = x * self.resi_dual_scale
+            x = einx.add('b n d, s d -> (b s) n d', x, self.stream_emb)

         # get layers to be executed

@@ -2359,9 +2449,6 @@ class AttentionLayers(Module):
             if not exists(first_cross_attn_inter) and layer_type == 'c':
                 first_cross_attn_inter = inter

-            if self.resi_dual:
-                outer_residual = outer_residual + out * self.resi_dual_scale
-
             if exists(post_branch_norm):
                 out = post_branch_norm(out)

@@ -2395,10 +2482,7 @@ class AttentionLayers(Module):
         if is_multistream:
             x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)

-        if self.resi_dual:
-            x = x + final_norm(outer_residual)
-        else:
-            x = final_norm(x)
+        x = final_norm(x)

         if not return_hiddens:
             return x
@@ -2444,7 +2528,7 @@ class PrefixDecoder(AttentionLayers):
             if isinstance(prefix_attn_len, int):
                 prefix_attn_len = torch.full((b,), prefix_attn_len, device = device)

-            prefix_mask = torch.arange(n, device = device) < rearrange(prefix_attn_len, 'b -> b 1 1 1')
+            prefix_mask = arange(n, device = device) < rearrange(prefix_attn_len, 'b -> b 1 1 1')
             forwarded_mask = forwarded_mask | prefix_mask

         if exists(attn_mask):
@@ -2773,13 +2857,13 @@ class TransformerWrapper(Module):
             prepend_seq, prepend_dim = prepend_embeds.shape[1:]
             assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'

-            x = torch.cat((prepend_embeds, x), dim = -2)
+            x = cat((prepend_embeds, x), dim = -2)

             if exists(prepend_mask) or exists(mask):
                 mask = default(mask, lambda: torch.ones((b, n), device = device, dtype = torch.bool))
                 prepend_mask = default(prepend_mask, lambda: torch.ones((b, prepend_seq), device = device, dtype = torch.bool))

-                mask = torch.cat((prepend_mask, mask), dim = -1)
+                mask = cat((prepend_mask, mask), dim = -1)

         # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model

@@ -2945,7 +3029,7 @@ class TransformerWrapper(Module):

         if return_mems:
             hiddens = intermediates.hiddens
-            new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
+            new_mems = [cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
             new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]

         if not return_intermediates: