x-transformers 1.44.8__py3-none-any.whl → 2.0.1__py3-none-any.whl
- x_transformers/x_transformers.py +221 -137
- x_transformers-2.0.1.dist-info/METADATA +2419 -0
- {x_transformers-1.44.8.dist-info → x_transformers-2.0.1.dist-info}/RECORD +5 -6
- {x_transformers-1.44.8.dist-info → x_transformers-2.0.1.dist-info}/WHEEL +1 -2
- x_transformers-1.44.8.dist-info/METADATA +0 -30
- x_transformers-1.44.8.dist-info/top_level.txt +0 -1
- {x_transformers-1.44.8.dist-info → x_transformers-2.0.1.dist-info/licenses}/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -9,7 +9,7 @@ from packaging import version
 import torch
 from torch.amp import autocast
 import torch.nn.functional as F
-from torch import nn, einsum, Tensor
+from torch import nn, einsum, Tensor, cat, stack, arange
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torch.nn import Module, ModuleList, ModuleDict

@@ -18,14 +18,22 @@ from collections import namedtuple
 from contextlib import nullcontext
 from dataclasses import dataclass

+from loguru import logger
+
+from x_transformers.attend import Attend, Intermediates
+from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+
 import einx
 from einops.layers.torch import Rearrange
 from einops import rearrange, repeat, reduce, pack, unpack

-
+# einstein notation

-
-
+# b - batch
+# n - sequence
+# d - feature dimension
+# h - attention heads
+# i, j - sequence (source, target)

 # constants

@@ -220,7 +228,7 @@ def dropout_seq(seq, mask, dropout):
     num_keep = max(1, int(keep_prob * n))
     keep_indices = logits.topk(num_keep, dim = 1).indices

-    batch_indices = torch.arange(b, device = device)
+    batch_indices = arange(b, device = device)
    batch_indices = rearrange(batch_indices, 'b -> b 1')

    seq = seq[batch_indices, keep_indices]
@@ -228,7 +236,7 @@ def dropout_seq(seq, mask, dropout):
     if exists(mask):
         seq_counts = mask.sum(dim = -1)
         seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
-        keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
+        keep_mask = arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')

         mask = mask[batch_indices, keep_indices] & keep_mask

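The `dropout_seq` changes above only swap `torch.arange` for the bare `arange` import; the keep-top-k mechanics are unchanged. A minimal standalone sketch of that mechanic (illustrative shapes, not the library API):

import torch
from torch import arange
from einops import rearrange

b, n, d, keep_prob = 2, 10, 16, 0.7
seq = torch.randn(b, n, d)

logits = torch.randn(b, n)                        # random scores decide which positions survive
num_keep = max(1, int(keep_prob * n))
keep_indices = logits.topk(num_keep, dim = 1).indices

batch_indices = rearrange(arange(b), 'b -> b 1')  # (b, 1) so it broadcasts against (b, num_keep)
seq = seq[batch_indices, keep_indices]
assert seq.shape == (b, num_keep, d)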
@@ -274,7 +282,7 @@ class AbsolutePositionalEmbedding(Module):
         assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

         if not exists(pos):
-            pos = torch.arange(seq_len, device = device)
+            pos = arange(seq_len, device = device)

         if exists(seq_start_pos):
             pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
@@ -290,7 +298,7 @@ class ScaledSinusoidalEmbedding(Module):
         self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

         half_dim = dim // 2
-        freq_seq = torch.arange(half_dim).float() / half_dim
+        freq_seq = arange(half_dim).float() / half_dim
         inv_freq = theta ** -freq_seq
         self.register_buffer('inv_freq', inv_freq, persistent = False)

@@ -298,13 +306,13 @@ class ScaledSinusoidalEmbedding(Module):
         seq_len, device = x.shape[1], x.device

         if not exists(pos):
-            pos = torch.arange(seq_len, device = device)
+            pos = arange(seq_len, device = device)

         if exists(seq_start_pos):
             pos = pos - seq_start_pos[..., None]

         emb = einsum('i, j -> i j', pos, self.inv_freq)
-        emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
+        emb = cat((emb.sin(), emb.cos()), dim = -1)
         return emb * self.scale

 class RelativePositionBias(Module):
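For orientation, a standalone sketch of the fixed sinusoidal table that `ScaledSinusoidalEmbedding.forward` builds above (the learned scale is omitted; `theta` and the names here are illustrative assumptions):

import torch
from torch import arange, cat, einsum

def sinusoidal_table(seq_len, dim, theta = 10000):
    half_dim = dim // 2
    freq_seq = arange(half_dim).float() / half_dim
    inv_freq = theta ** -freq_seq                 # (half_dim,)

    pos = arange(seq_len).float()                 # (seq_len,)
    emb = einsum('i, j -> i j', pos, inv_freq)    # outer product -> (seq_len, half_dim)
    return cat((emb.sin(), emb.cos()), dim = -1)  # (seq_len, dim)

emb = sinusoidal_table(16, 64)
assert emb.shape == (16, 64)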
@@ -344,8 +352,8 @@ class RelativePositionBias(Module):

     def forward(self, i, j):
         device = self.device
-        q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
-        k_pos = torch.arange(j, dtype = torch.long, device = device)
+        q_pos = arange(j - i, j, dtype = torch.long, device = device)
+        k_pos = arange(j, dtype = torch.long, device = device)
         rel_pos = einx.subtract('j, i -> i j', k_pos, q_pos)
         rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
         values = self.relative_attention_bias(rp_bucket)
@@ -376,7 +384,7 @@ class CoPE(Module):
         if not soft_onehot:
             return

-        self.register_buffer('positions', torch.arange(max_pos))
+        self.register_buffer('positions', arange(max_pos))

     def forward(self, query, attn_logits):

@@ -445,13 +453,13 @@ class DynamicPositionBias(Module):
         n, device = j, self.device

         # get the (n x n) matrix of distances
-        seq_arange = torch.arange(n, device = device)
-        context_arange = torch.arange(n, device = device)
+        seq_arange = arange(n, device = device)
+        context_arange = arange(n, device = device)
         indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
         indices += (n - 1)

         # input to continuous positions MLP
-        pos = torch.arange(-n + 1, n, device = device).float()
+        pos = arange(-n + 1, n, device = device).float()
         pos = rearrange(pos, '... -> ... 1')

         if self.log_distance:
@@ -525,8 +533,8 @@ class AlibiPositionalBias(Module):
         if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
             return self.bias[..., -i:, -j:]

-        seq_arange = torch.arange(j - i, j, device = device)
-        context_arange = torch.arange(j, device = device)
+        seq_arange = arange(j - i, j, device = device)
+        context_arange = arange(j, device = device)
         bias = -einx.subtract('j, i -> 1 i j', context_arange, seq_arange).abs()

         bias = bias * self.slopes
@@ -642,7 +650,7 @@ class RotaryEmbedding(Module):
         # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
         base *= base_rescale_factor ** (dim / (dim - 2))

-        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+        inv_freq = 1. / (base ** (arange(0, dim, 2).float() / dim))
         self.register_buffer('inv_freq', inv_freq)

         assert interpolation_factor >= 1.
@@ -652,7 +660,7 @@ class RotaryEmbedding(Module):
             self.register_buffer('scale', None)
             return

-        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
+        scale = (arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

         self.scale_base = scale_base
         self.register_buffer('scale', scale)
@@ -660,7 +668,7 @@ class RotaryEmbedding(Module):
     def forward_from_seq_len(self, seq_len):
         device = self.inv_freq.device

-        t = torch.arange(seq_len, device = device)
+        t = arange(seq_len, device = device)
         return self.forward(t)

     @autocast('cuda', enabled = False)
@@ -671,7 +679,7 @@ class RotaryEmbedding(Module):
             t = rearrange(t, 'n -> 1 n')

         freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
-        freqs = torch.stack((freqs, freqs), dim = -1)
+        freqs = stack((freqs, freqs), dim = -1)
         freqs = rearrange(freqs, '... d r -> ... (d r)')

         if not exists(self.scale):
@@ -679,7 +687,7 @@ class RotaryEmbedding(Module):

         power = (t - (max_pos // 2)) / self.scale_base
         scale = self.scale ** rearrange(power, '... n -> ... n 1')
-        scale = torch.stack((scale, scale), dim = -1)
+        scale = stack((scale, scale), dim = -1)
         scale = rearrange(scale, '... d r -> ... (d r)')

         return freqs, scale
@@ -687,7 +695,7 @@ class RotaryEmbedding(Module):
 def rotate_half(x):
     x = rearrange(x, '... (d r) -> ... d r', r = 2)
     x1, x2 = x.unbind(dim = -1)
-    x = torch.stack((-x2, x1), dim = -1)
+    x = stack((-x2, x1), dim = -1)
     return rearrange(x, '... d r -> ... (d r)')

 @autocast('cuda', enabled = False)
@@ -703,7 +711,7 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):
     # partial rotary embeddings, Wang et al. GPT-J
     t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
     t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
-    out = torch.cat((t, t_unrotated), dim = -1)
+    out = cat((t, t_unrotated), dim = -1)

     return out.type(orig_dtype)

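A self-contained sketch of the rotary pattern the hunks above touch — `rotate_half` pairs up features, and only the first `rot_dim` features are rotated (GPT-J style partial rotary). This is a simplification that ignores the xpos scale; all names are illustrative:

import torch
from torch import arange, stack, cat
from einops import rearrange

def rotate_half(x):
    x = rearrange(x, '... (d r) -> ... d r', r = 2)
    x1, x2 = x.unbind(dim = -1)
    return rearrange(stack((-x2, x1), dim = -1), '... d r -> ... (d r)')

def apply_rotary(t, freqs):
    rot_dim = freqs.shape[-1]
    t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
    t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
    return cat((t, t_unrotated), dim = -1)

dim_head, seq_len = 64, 8
inv_freq = 1. / (10000 ** (arange(0, dim_head, 2).float() / dim_head))
freqs = torch.einsum('i, j -> i j', arange(seq_len).float(), inv_freq)
freqs = rearrange(stack((freqs, freqs), dim = -1), 'n d r -> n (d r)')

q = torch.randn(1, 4, seq_len, dim_head)          # (batch, heads, seq, dim_head)
assert apply_rotary(q, freqs).shape == q.shape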
@@ -833,6 +841,15 @@ class SimpleRMSNorm(Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale

+class MultiheadRMSNorm(Module):
+    def __init__(self, dim, heads):
+        super().__init__()
+        self.rmsnorm = SimpleRMSNorm(dim)
+        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+    def forward(self, x):
+        return self.rmsnorm(x) * (self.gamma + 1.)
+
 # residual and residual gates

 class Residual(Module):
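The new `MultiheadRMSNorm` normalizes each head's features and applies a per-head gain that starts at identity (gamma is zero-initialized). A standalone sketch of the same behaviour on a (batch, heads, seq, dim_head) tensor:

import torch
from torch import nn
import torch.nn.functional as F

class MultiheadRMSNorm(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))  # broadcasts over the sequence axis

    def forward(self, x):
        normed = F.normalize(x, dim = -1) * self.scale         # what SimpleRMSNorm does
        return normed * (self.gamma + 1.)                      # per-head gain, identity at init

norm = MultiheadRMSNorm(dim = 64, heads = 8)
x = torch.randn(2, 8, 16, 64)                                  # (batch, heads, seq, dim_head)
assert norm(x).shape == x.shape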
@@ -904,7 +921,7 @@ class HyperConnection(Module):
         init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
         init_alpha0[layer_index % num_residual_streams, :] = 1.

-        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+        self.static_alpha = nn.Parameter(cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))

         self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
         self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
@@ -973,7 +990,7 @@ class ShiftTokens(Module):
         splitted = x.split(feats_per_shift, dim = -1)
         segments_to_shift, rest = splitted[:segments], splitted[segments:]
         segments_to_shift = [shift(*args, mask = mask) for args in zip(segments_to_shift, shifts)]
-        x = torch.cat((*segments_to_shift, *rest), dim = -1)
+        x = cat((*segments_to_shift, *rest), dim = -1)
         return self.fn(x, **kwargs)

 class FoldAxially(Module):
@@ -1080,7 +1097,7 @@ class ConcatCombine(Module):

     def forward(self, x, prev_layers: list[Tensor]):
         skip = prev_layers[self.prev_layer_ind]
-        concatted_skip = torch.cat((skip, x), dim = -1)
+        concatted_skip = cat((skip, x), dim = -1)
         return self.combine(concatted_skip)

 # feedforward
@@ -1189,12 +1206,10 @@ class Attention(Module):
         hybrid_fold_axial_dim: int | None = None,
         one_kv_head = False,
         kv_heads = None,
-        shared_kv = False,
         value_dim_head = None,
         dim_out = None,
-        tensor_product = False, # https://arxiv.org/abs/2208.06061
         add_zero_kv = False, # same as add_zero_attn in pytorch
-
+        rotate_num_heads = None,
         data_dependent_alibi = False,
         data_dependent_alibi_per_row = False,
         data_dependent_alibi_per_row_dim_head = 8,
@@ -1205,12 +1220,15 @@ class Attention(Module):
         cope_talking_heads = False,
         softclamp_logits = False,
         logit_softclamp_value = 50.,
-        neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
-        neutreno_alpha = 0.4,
         learned_value_residual_mix = False,
-        laser = False,
+        laser = False, # https://arxiv.org/abs/2411.03493v1
         laser_softclamp_value = 15.,
         qkv_receive_diff_residuals = False,
+        use_latent_q = False,
+        dim_latent_q = None,
+        use_latent_kv = False,
+        dim_latent_kv = None,
+        latent_rope_subheads = None,
         onnxable = False,
         attend_sdp_kwargs: dict = dict(
             enable_flash = True,
@@ -1242,13 +1260,51 @@ class Attention(Module):
         v_dim = value_dim_head * kv_heads
         out_dim = value_dim_head * heads

-
-
+        # determine input dimensions to qkv based on whether intermediate latent q and kv are being used
+        # for eventually supporting multi-latent attention (MLA)
+
+        self.to_latent_q = None
+        self.to_latent_kv = None
+        self.to_rotateable_k = None # for their "decoupled rope", subheads of keys that comes directly from base sequence (does not go through latents)
+
+        dim_q_input = dim
+        dim_kv_input = dim_kv
+
+        if use_latent_q:
+            assert exists(dim_latent_q)
+            self.to_latent_q = LinearNoBias(dim, dim_latent_q)
+            dim_q_input = dim_latent_q
+
+        if use_latent_kv:
+            assert exists(dim_latent_kv)
+            self.to_latent_kv = LinearNoBias(dim, dim_latent_kv)
+            dim_kv_input = dim_latent_kv
+
+        if exists(latent_rope_subheads):
+            assert not exists(rotate_num_heads)
+            rotate_num_heads = latent_rope_subheads
+
+            k_dim = dim_head * (kv_heads - latent_rope_subheads)

-
+            self.to_rotateable_k = LinearNoBias(dim, dim_head * latent_rope_subheads)
+            self.split_rotateable_k_heads = Rearrange('b n (h d) -> b h n d', h = latent_rope_subheads)

-
-        self.
+        self.use_latent_q = use_latent_q
+        self.use_latent_kv = use_latent_kv
+
+        # query key projection
+
+        self.to_q = LinearNoBias(dim_q_input, q_dim)
+        self.to_k = LinearNoBias(dim_kv_input, k_dim)
+        self.to_v = LinearNoBias(dim_kv_input, v_dim)
+
+        # split and merge of attention heads
+
+        self.split_q_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+        self.split_k_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
+        self.split_v_heads = Rearrange('b n (h d) -> b h n d', d = value_dim_head)
+
+        self.merge_heads = Rearrange('b h n d -> b n (h d)')

         # whether qkv receives different residual stream combinations from hyper connections

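A rough standalone sketch of the projection layout the new latent options set up (multi-latent attention style: keys/values read from a small shared latent, with a few "decoupled rope" key subheads taken straight from the residual stream). Dimensions and names are illustrative assumptions, not library defaults:

import torch
from torch import nn

dim, heads, dim_head = 512, 8, 64
dim_latent_kv = 128                  # compressed kv latent, much narrower than heads * dim_head
latent_rope_subheads = 2             # key subheads that bypass the latent and carry rotary embedding

to_latent_kv = nn.Linear(dim, dim_latent_kv, bias = False)

k_dim = dim_head * (heads - latent_rope_subheads)
to_k = nn.Linear(dim_latent_kv, k_dim, bias = False)
to_v = nn.Linear(dim_latent_kv, dim_head * heads, bias = False)

to_rotateable_k = nn.Linear(dim, dim_head * latent_rope_subheads, bias = False)

x = torch.randn(2, 16, dim)
latent = to_latent_kv(x)             # (2, 16, 128) - the only kv state that needs caching
k, v = to_k(latent), to_v(latent)
rotateable_k = to_rotateable_k(x)    # rope is applied to these subheads only
assert k.shape[-1] + rotateable_k.shape[-1] == dim_head * heads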
@@ -1259,15 +1315,6 @@ class Attention(Module):
         self.laser = laser
         self.laser_softclamp_value = laser_softclamp_value

-        # relations projection from tp-attention
-
-        self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None
-
-        # the value residual used by Nguyen et al. in https://arxiv.org/abs/2312.00751 for countering oversmoothing
-
-        self.neutreno_value_residual = neutreno_value_residual
-        self.neutreno_alpha = neutreno_alpha
-
         # add GLU gating for aggregated values, from alphafold2

         self.to_v_gate = None
@@ -1393,12 +1440,22 @@ class Attention(Module):

         # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676

+        hybrid_mix = None
+        hybrid_norms = None
         hybrid_module = maybe(deepcopy)(hybrid_module)

         if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
             hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
+            hybrid_mix = LinearNoBias(dim, heads)
+
+            hybrid_norms = ModuleList([
+                MultiheadRMSNorm(dim_head, heads = heads),
+                MultiheadRMSNorm(dim_head, heads = heads)
+            ])

         self.hybrid_module = hybrid_module
+        self.hybrid_norms = hybrid_norms
+        self.hybrid_mix = hybrid_mix
         self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths

         # output dimension by default same as input, but can be overridden
|
|
1406
1463
|
dim_out = default(dim_out, dim)
|
1407
1464
|
self.to_out = nn.Sequential(LinearNoBias(out_dim, dim_out * 2), nn.GLU()) if on_attn else LinearNoBias(out_dim, dim_out)
|
1408
1465
|
|
1409
|
-
#
|
1466
|
+
# the number of attention heads to rotate, for decoupled rope in multi-latent attention
|
1467
|
+
|
1468
|
+
rotate_num_heads = default(rotate_num_heads, heads)
|
1410
1469
|
|
1411
|
-
|
1470
|
+
assert 0 < rotate_num_heads <= heads
|
1471
|
+
is_partial_rotate_heads = rotate_num_heads < heads
|
1472
|
+
assert not (is_partial_rotate_heads and kv_heads < heads), 'grouped query attention not compatible with partial rotate heads (decoupled rope for multi-latent attention), yet'
|
1473
|
+
|
1474
|
+
self.rotate_num_heads = rotate_num_heads
|
1412
1475
|
|
1413
1476
|
# whether parent can kv cache
|
1414
1477
|
|
@@ -1438,47 +1501,79 @@ class Attention(Module):
         cache: Intermediates | None = None,
         value_residual = None
     ):
-        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals
+        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv
+
+        # an interesting possibility with hyper connections
+        # having queries, keys, values be routed from different layers

         assert not (qkv_receive_diff_residuals and has_context), 'qkv receiving different sequences can only be used for self attention'

         if qkv_receive_diff_residuals:
-            assert
+            assert x.ndim == 4 and x.shape[0] == 3

             q_input, k_input, v_input = x
         else:
             kv_input = default(context, x)
-
-            q_input = x
-            k_input = kv_input
-            v_input = kv_input
-            r_input = x
+            q_input, k_input, v_input = x, kv_input, kv_input

         if exists(mem):
             k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
             v_input, _ = pack([mem, v_input], 'b * d')

+        # multi-latent attention logic
+        # https://arxiv.org/abs/2405.04434 - Deepseek-AI team
+
+        k_sub_heads = None # the rotateable subheads of keys derived from base sequence
+
+        if self.use_latent_q:
+            q_input = self.to_latent_q(q_input)
+
+        if is_multi_latent_attn:
+            assert not qkv_receive_diff_residuals
+            needs_k_sub_heads = exists(self.to_rotateable_k)
+
+            latent_kv_input = self.to_latent_kv(k_input)
+
+            if needs_k_sub_heads:
+                rotateable_k = self.to_rotateable_k(k_input)
+                k_sub_heads = self.split_rotateable_k_heads(rotateable_k)
+
+            if exists(cache):
+                cached_latent_kv, maybe_cached_k_sub_heads = cache.cached_kv
+                latent_kv_input = cat((cached_latent_kv, latent_kv_input), dim = -2)
+
+                if exists(maybe_cached_k_sub_heads):
+                    k_sub_heads = cat((maybe_cached_k_sub_heads, k_sub_heads), dim = -2)
+
+            if return_intermediates:
+                cached_kv = (latent_kv_input, k_sub_heads)
+
+            k_input = v_input = latent_kv_input
+
+        # query, key, value projection
+
         q = self.to_q(q_input)
         k = self.to_k(k_input)
-        v = self.to_v(v_input)
-
+        v = self.to_v(v_input)
+
+        q = self.split_q_heads(q)
+        k = self.split_k_heads(k)
+        v = self.split_v_heads(v)

-
+        # take care of decoupled rope from multi-latent attention

-
+        if exists(k_sub_heads):
+            k = cat((k, k_sub_heads), dim = 1)

-        # if previous values passed in for residual, either invoke resformer
+        # if previous values passed in for residual, either invoke resformer

         orig_values = v

+        # https://arxiv.org/abs/2410.17897v1
+
         if exists(value_residual):
-
-
-                diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
-            else:
-                # https://arxiv.org/abs/2410.17897v1
-                value_residual_mix = self.to_value_residual_mix(q_input)
-                v = v * value_residual_mix + value_residual * (1. - value_residual_mix)
+            value_residual_mix = self.to_value_residual_mix(q_input)
+            v = value_residual.lerp(v, value_residual_mix)

         # qk normalization

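The learned value-residual mix is rewritten above with `Tensor.lerp`; a quick standalone check that the two formulations agree (shapes are illustrative, with the mix gate broadcasting over the head dimension):

import torch

v = torch.randn(2, 8, 16, 64)           # current values (batch, heads, seq, dim_head)
value_residual = torch.randn_like(v)    # values carried over from an earlier layer
mix = torch.rand(2, 8, 16, 1)           # learned gate in [0, 1]

old = v * mix + value_residual * (1. - mix)
new = value_residual.lerp(v, mix)       # start + weight * (end - start)

assert torch.allclose(old, new, atol = 1e-6)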
@@ -1492,28 +1587,36 @@ class Attention(Module):

         # take care of caching

-        if
-
+        if not is_multi_latent_attn:
+            if exists(cache):
+                ck, cv = cache.cached_kv

-
-
-
+                if exists(mem):
+                    mk, k = unpack(k, mem_packed_shape, 'b h * d')
+                    mv, v = unpack(v, mem_packed_shape, 'b h * d')

-
-
+                k = cat((ck, k), dim = -2)
+                v = cat((cv, v), dim = -2)

-
-
-
+                if exists(mem):
+                    k = cat((mk, k), dim = -2)
+                    v = cat((mv, v), dim = -2)

-
-
-
+            if return_intermediates:
+                mem_len = mem.shape[-2] if exists(mem) else 0
+                cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])

         if exists(rotary_pos_emb):
+            rotate_num_heads = self.rotate_num_heads
+            partial_rotate_heads = rotate_num_heads < h
+
             freqs, xpos_scale = rotary_pos_emb
             q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)

+            if partial_rotate_heads:
+                q_rest, q = q[:, :-rotate_num_heads], q[:, -rotate_num_heads:]
+                k_rest, k = k[:, :-rotate_num_heads], k[:, -rotate_num_heads:]
+
             q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)

             if has_context:
|
|
1524
1627
|
|
1525
1628
|
k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)
|
1526
1629
|
|
1527
|
-
if
|
1528
|
-
|
1630
|
+
if partial_rotate_heads:
|
1631
|
+
q = cat((q_rest, q), dim = 1)
|
1632
|
+
k = cat((k_rest, k), dim = 1)
|
1529
1633
|
|
1530
1634
|
input_mask = context_mask
|
1531
1635
|
|
@@ -1540,7 +1644,7 @@ class Attention(Module):
        elif not exists(input_mask):
            input_mask = pad_at_dim(mem_mask, (0, seq_len), dim = -1, value = True)
        else:
-           input_mask = torch.cat((mem_mask, input_mask), dim = -1)
+           input_mask = cat((mem_mask, input_mask), dim = -1)

        # i, j determined for relative positional bias, excluding memory key / values

@@ -1555,8 +1659,8 @@ class Attention(Module):
                mem_k = l2norm(mem_k)
                mem_k = mem_k * self.qk_norm_k_scale

-            k = torch.cat((mem_k, k), dim = -2)
-            v = torch.cat((mem_v, v), dim = -2)
+            k = cat((mem_k, k), dim = -2)
+            v = cat((mem_v, v), dim = -2)

            if exists(input_mask):
                input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
@@ -1580,8 +1684,8 @@ class Attention(Module):
            masks.append(~attn_mask)

        if exists(self.max_attend_past):
-           range_q = torch.arange(j - i, j, device = device)
-           range_k = torch.arange(j, device = device)
+           range_q = arange(j - i, j, device = device)
+           range_k = arange(j, device = device)
            dist = einx.subtract('i, j -> 1 1 i j', range_q, range_k)
            max_attend_past_mask = dist > self.max_attend_past
            max_attend_past_mask = pad_at_dim(max_attend_past_mask, (num_mem_kv, 0), value = False, dim = -1) # handle memory key / values
@@ -1629,18 +1733,10 @@ class Attention(Module):
         if self.laser:
             out = log(out)

-        # store the values for resformer
+        # store the values for resformer

         intermediates.values = orig_values

-        if exists(value_residual) and self.neutreno_value_residual:
-            out = out + diff_values
-
-        # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
-
-        if exists(r):
-            out = out * r + out
-
         # normformer scaling of heads

         if head_scale:
@@ -1652,11 +1748,9 @@ class Attention(Module):
             head_gate = self.to_v_head_gate(x)
             out = einx.multiply('b n h, b h n d ->b h n d', head_gate.sigmoid(), out)

-        #
-
-        out = rearrange(out, 'b h n d -> b n (h d)')
+        # if exists hybrid module, must do a normalization

-
+        # hybrid module

         if exists(self.hybrid_module):

@@ -1674,8 +1768,23 @@ class Attention(Module):
             # handle hybrid out

             (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
+
+            # handle variable hybrid output and multi rmsnorm before summing to main attention output (also normed)
+
+            if hybrid_out.ndim == 3:
+                hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)
+
+            out_norm, hybrid_out_norm = self.hybrid_norms
+
+            out = out_norm(out)
+            hybrid_out = hybrid_out_norm(hybrid_out)
+
             out = 0.5 * (out + hybrid_out)

+        # merge heads
+
+        out = self.merge_heads(out)
+
         # alphafold2 styled gating of the values

         if exists(self.to_v_gate):
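A standalone sketch of the new hybrid combination: both the attention branch and the hybrid branch are per-head RMS-normed before being averaged, and heads are merged only after the sum. The norm mirrors the MultiheadRMSNorm sketched earlier; everything else here is illustrative:

import torch
import torch.nn.functional as F
from einops import rearrange

def multihead_rmsnorm(x, gamma):
    # x: (batch, heads, seq, dim_head), gamma: (heads, 1, dim_head)
    return F.normalize(x, dim = -1) * x.shape[-1] ** 0.5 * (gamma + 1.)

b, h, n, d = 2, 8, 16, 64
out = torch.randn(b, h, n, d)            # attention branch, already split into heads
hybrid_out = torch.randn(b, n, h * d)    # e.g. an RNN / SSM branch returning (batch, seq, dim)

if hybrid_out.ndim == 3:                 # reshape a flat hybrid output into heads
    hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)

out_gamma, hybrid_gamma = torch.zeros(h, 1, d), torch.zeros(h, 1, d)
out = 0.5 * (multihead_rmsnorm(out, out_gamma) + multihead_rmsnorm(hybrid_out, hybrid_gamma))

out = rearrange(out, 'b h n d -> b n (h d)')   # merge heads after combining
assert out.shape == (b, n, h * d)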
@@ -1747,8 +1856,6 @@ class AttentionLayers(Module):
         sandwich_norm = False,
         softclamp_output = False,
         softclamp_output_value = 30.,
-        resi_dual = False,
-        resi_dual_scale = 1.,
         zero_init_branch_output = False,
         layer_dropout = 0.,
         cross_attn_tokens_dropout = 0.,
@@ -1775,12 +1882,9 @@ class AttentionLayers(Module):

         dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
         data_dependent_alibi = attn_kwargs.get('data_dependent_alibi', False)
-        neutreno_value_residual = attn_kwargs.get('neutreno_value_residual', False)

         assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'

-        add_value_residual |= neutreno_value_residual
-
         self.dim = dim
         self.causal = causal
         self.layers = ModuleList([])
|
|
1831
1935
|
assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
|
1832
1936
|
self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads, **rel_pos_kwargs)
|
1833
1937
|
|
1834
|
-
assert at_most_one_of(sandwich_norm, resi_dual), 'either sandwich norm or resiDual is selected, but not both'
|
1835
1938
|
assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
|
1836
1939
|
|
1837
|
-
if resi_dual:
|
1838
|
-
pre_norm = False
|
1839
|
-
|
1840
1940
|
self.pre_norm = pre_norm
|
1841
1941
|
self.sandwich_norm = sandwich_norm
|
1842
1942
|
|
1843
|
-
self.resi_dual = resi_dual
|
1844
|
-
assert 0 < resi_dual_scale <= 1., 'resiDual prenorm residual must be scaled by a factor greater than 0 and less than or equal to 1.'
|
1845
|
-
self.resi_dual_scale = resi_dual_scale
|
1846
|
-
|
1847
1943
|
self.residual_attn = residual_attn
|
1848
1944
|
self.cross_residual_attn = cross_residual_attn
|
1849
1945
|
assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention'
|
@@ -2002,7 +2098,7 @@ class AttentionLayers(Module):

         # whether it has post norm

-        self.final_norm = norm_fn() if pre_norm
+        self.final_norm = norm_fn() if pre_norm else nn.Identity()

         # whether unet or not

@@ -2175,7 +2271,7 @@ class AttentionLayers(Module):
         # handle left padded sequences

         if exists(seq_start_pos):
-            seq_arange = torch.arange(x.shape[-2], device = x.device, dtype = torch.long)
+            seq_arange = arange(x.shape[-2], device = x.device, dtype = torch.long)
             left_pad_mask = seq_arange >= seq_start_pos[..., None]

             if exists(self_attn_kv_mask):
@@ -2193,7 +2289,7 @@ class AttentionLayers(Module):
             mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0

             if not exists(pos):
-                pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+                pos = arange(x.shape[1] + mem_len, device = x.device) - mem_len

             rotary_pos_emb = self.rotary_pos_emb(pos)

@@ -2213,7 +2309,7 @@ class AttentionLayers(Module):
         attn_cache = []

         if exists(cache):
-            assert
+            assert self.causal and not any([*map(exists, (mask, attn_mask))])

             if exists(context):
                 context = context[:, :0]
@@ -2231,13 +2327,7 @@ class AttentionLayers(Module):
         is_multistream = streams > 1

         if is_multistream:
-            x =
-            x = x + self.stream_emb
-            x = rearrange(x, 'b n s d -> (b s) n d')
-
-        # outer residual - for resiDual paper
-
-        outer_residual = x * self.resi_dual_scale
+            x = einx.add('b n d, s d -> (b s) n d', x, self.stream_emb)

         # get layers to be executed

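The fused `einx.add('b n d, s d -> (b s) n d', ...)` above replaces what the removed lines spelled out step by step (presumably a repeat of the sequence per stream, an add of the stream embedding, and a reshape). Roughly equivalent to:

import torch
from einops import rearrange, repeat

b, n, d, streams = 2, 16, 64, 4
x = torch.randn(b, n, d)
stream_emb = torch.randn(streams, d)

x = repeat(x, 'b n d -> b n s d', s = streams)  # duplicate the sequence per stream
x = x + stream_emb                              # each copy gets its own stream embedding
x = rearrange(x, 'b n s d -> (b s) n d')        # fold streams into the batch
assert x.shape == (b * streams, n, d)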
@@ -2359,9 +2449,6 @@ class AttentionLayers(Module):
             if not exists(first_cross_attn_inter) and layer_type == 'c':
                 first_cross_attn_inter = inter

-            if self.resi_dual:
-                outer_residual = outer_residual + out * self.resi_dual_scale
-
             if exists(post_branch_norm):
                 out = post_branch_norm(out)

@@ -2395,10 +2482,7 @@ class AttentionLayers(Module):
         if is_multistream:
             x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)

-
-            x = x + final_norm(outer_residual)
-        else:
-            x = final_norm(x)
+        x = final_norm(x)

         if not return_hiddens:
             return x
@@ -2444,7 +2528,7 @@ class PrefixDecoder(AttentionLayers):
         if isinstance(prefix_attn_len, int):
             prefix_attn_len = torch.full((b,), prefix_attn_len, device = device)

-        prefix_mask = torch.arange(n, device = device) < rearrange(prefix_attn_len, 'b -> b 1 1 1')
+        prefix_mask = arange(n, device = device) < rearrange(prefix_attn_len, 'b -> b 1 1 1')
         forwarded_mask = forwarded_mask | prefix_mask

         if exists(attn_mask):
@@ -2773,13 +2857,13 @@ class TransformerWrapper(Module):
             prepend_seq, prepend_dim = prepend_embeds.shape[1:]
             assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'

-            x = torch.cat((prepend_embeds, x), dim = -2)
+            x = cat((prepend_embeds, x), dim = -2)

         if exists(prepend_mask) or exists(mask):
             mask = default(mask, lambda: torch.ones((b, n), device = device, dtype = torch.bool))
             prepend_mask = default(prepend_mask, lambda: torch.ones((b, prepend_seq), device = device, dtype = torch.bool))

-            mask = torch.cat((prepend_mask, mask), dim = -1)
+            mask = cat((prepend_mask, mask), dim = -1)

         # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model

@@ -2945,7 +3029,7 @@ class TransformerWrapper(Module):

         if return_mems:
             hiddens = intermediates.hiddens
-            new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
+            new_mems = [cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
             new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]

         if not return_intermediates: