x-transformers 1.44.8__py3-none-any.whl → 2.0.1__py3-none-any.whl

@@ -9,7 +9,7 @@ from packaging import version
  import torch
  from torch.amp import autocast
  import torch.nn.functional as F
- from torch import nn, einsum, Tensor
+ from torch import nn, einsum, Tensor, cat, stack, arange
  from torch.utils._pytree import tree_flatten, tree_unflatten
  from torch.nn import Module, ModuleList, ModuleDict

@@ -18,14 +18,22 @@ from collections import namedtuple
  from contextlib import nullcontext
  from dataclasses import dataclass

+ from loguru import logger
+
+ from x_transformers.attend import Attend, Intermediates
+ from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+
  import einx
  from einops.layers.torch import Rearrange
  from einops import rearrange, repeat, reduce, pack, unpack

- from loguru import logger
+ # einstein notation

- from x_transformers.attend import Attend, Intermediates
- from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+ # b - batch
+ # n - sequence
+ # d - feature dimension
+ # h - attention heads
+ # i, j - sequence (source, target)

  # constants

@@ -220,7 +228,7 @@ def dropout_seq(seq, mask, dropout):
      num_keep = max(1, int(keep_prob * n))
      keep_indices = logits.topk(num_keep, dim = 1).indices

-     batch_indices = torch.arange(b, device = device)
+     batch_indices = arange(b, device = device)
      batch_indices = rearrange(batch_indices, 'b -> b 1')

      seq = seq[batch_indices, keep_indices]
@@ -228,7 +236,7 @@ def dropout_seq(seq, mask, dropout):
      if exists(mask):
          seq_counts = mask.sum(dim = -1)
          seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
-         keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
+         keep_mask = arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')

          mask = mask[batch_indices, keep_indices] & keep_mask

@@ -274,7 +282,7 @@ class AbsolutePositionalEmbedding(Module):
          assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

          if not exists(pos):
-             pos = torch.arange(seq_len, device = device)
+             pos = arange(seq_len, device = device)

          if exists(seq_start_pos):
              pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
@@ -290,7 +298,7 @@ class ScaledSinusoidalEmbedding(Module):
          self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

          half_dim = dim // 2
-         freq_seq = torch.arange(half_dim).float() / half_dim
+         freq_seq = arange(half_dim).float() / half_dim
          inv_freq = theta ** -freq_seq
          self.register_buffer('inv_freq', inv_freq, persistent = False)

@@ -298,13 +306,13 @@ class ScaledSinusoidalEmbedding(Module):
          seq_len, device = x.shape[1], x.device

          if not exists(pos):
-             pos = torch.arange(seq_len, device = device)
+             pos = arange(seq_len, device = device)

          if exists(seq_start_pos):
              pos = pos - seq_start_pos[..., None]

          emb = einsum('i, j -> i j', pos, self.inv_freq)
-         emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
+         emb = cat((emb.sin(), emb.cos()), dim = -1)
          return emb * self.scale

  class RelativePositionBias(Module):
@@ -344,8 +352,8 @@ class RelativePositionBias(Module):

      def forward(self, i, j):
          device = self.device
-         q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
-         k_pos = torch.arange(j, dtype = torch.long, device = device)
+         q_pos = arange(j - i, j, dtype = torch.long, device = device)
+         k_pos = arange(j, dtype = torch.long, device = device)
          rel_pos = einx.subtract('j, i -> i j', k_pos, q_pos)
          rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
          values = self.relative_attention_bias(rp_bucket)
@@ -376,7 +384,7 @@ class CoPE(Module):
          if not soft_onehot:
              return

-         self.register_buffer('positions', torch.arange(max_pos))
+         self.register_buffer('positions', arange(max_pos))

      def forward(self, query, attn_logits):

@@ -445,13 +453,13 @@ class DynamicPositionBias(Module):
          n, device = j, self.device

          # get the (n x n) matrix of distances
-         seq_arange = torch.arange(n, device = device)
-         context_arange = torch.arange(n, device = device)
+         seq_arange = arange(n, device = device)
+         context_arange = arange(n, device = device)
          indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
          indices += (n - 1)

          # input to continuous positions MLP
-         pos = torch.arange(-n + 1, n, device = device).float()
+         pos = arange(-n + 1, n, device = device).float()
          pos = rearrange(pos, '... -> ... 1')

          if self.log_distance:
@@ -525,8 +533,8 @@ class AlibiPositionalBias(Module):
          if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
              return self.bias[..., -i:, -j:]

-         seq_arange = torch.arange(j - i, j, device = device)
-         context_arange = torch.arange(j, device = device)
+         seq_arange = arange(j - i, j, device = device)
+         context_arange = arange(j, device = device)
          bias = -einx.subtract('j, i -> 1 i j', context_arange, seq_arange).abs()

          bias = bias * self.slopes
@@ -642,7 +650,7 @@ class RotaryEmbedding(Module):
          # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
          base *= base_rescale_factor ** (dim / (dim - 2))

-         inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+         inv_freq = 1. / (base ** (arange(0, dim, 2).float() / dim))
          self.register_buffer('inv_freq', inv_freq)

          assert interpolation_factor >= 1.
@@ -652,7 +660,7 @@ class RotaryEmbedding(Module):
              self.register_buffer('scale', None)
              return

-         scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
+         scale = (arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

          self.scale_base = scale_base
          self.register_buffer('scale', scale)
@@ -660,7 +668,7 @@ class RotaryEmbedding(Module):
      def forward_from_seq_len(self, seq_len):
          device = self.inv_freq.device

-         t = torch.arange(seq_len, device = device)
+         t = arange(seq_len, device = device)
          return self.forward(t)

      @autocast('cuda', enabled = False)
@@ -671,7 +679,7 @@ class RotaryEmbedding(Module):
              t = rearrange(t, 'n -> 1 n')

          freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
-         freqs = torch.stack((freqs, freqs), dim = -1)
+         freqs = stack((freqs, freqs), dim = -1)
          freqs = rearrange(freqs, '... d r -> ... (d r)')

          if not exists(self.scale):
@@ -679,7 +687,7 @@ class RotaryEmbedding(Module):

          power = (t - (max_pos // 2)) / self.scale_base
          scale = self.scale ** rearrange(power, '... n -> ... n 1')
-         scale = torch.stack((scale, scale), dim = -1)
+         scale = stack((scale, scale), dim = -1)
          scale = rearrange(scale, '... d r -> ... (d r)')

          return freqs, scale
@@ -687,7 +695,7 @@ class RotaryEmbedding(Module):
  def rotate_half(x):
      x = rearrange(x, '... (d r) -> ... d r', r = 2)
      x1, x2 = x.unbind(dim = -1)
-     x = torch.stack((-x2, x1), dim = -1)
+     x = stack((-x2, x1), dim = -1)
      return rearrange(x, '... d r -> ... (d r)')

  @autocast('cuda', enabled = False)
@@ -703,7 +711,7 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):
      # partial rotary embeddings, Wang et al. GPT-J
      t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
      t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
-     out = torch.cat((t, t_unrotated), dim = -1)
+     out = cat((t, t_unrotated), dim = -1)

      return out.type(orig_dtype)

@@ -833,6 +841,15 @@ class SimpleRMSNorm(Module):
      def forward(self, x):
          return F.normalize(x, dim = -1) * self.scale

+ class MultiheadRMSNorm(Module):
+     def __init__(self, dim, heads):
+         super().__init__()
+         self.rmsnorm = SimpleRMSNorm(dim)
+         self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+     def forward(self, x):
+         return self.rmsnorm(x) * (self.gamma + 1.)
+
  # residual and residual gates

  class Residual(Module):
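Note: the new `MultiheadRMSNorm` normalizes the feature dimension of a per-head tensor and rescales it with a learned per-head gain that starts at 1 (since `gamma` is initialized to zeros). A standalone sketch of the same idea, assuming `SimpleRMSNorm` scales by `sqrt(dim)` and an input shaped `(batch, heads, seq, dim)`:

import torch
import torch.nn.functional as F
from torch import nn

class MultiheadRMSNorm(nn.Module):
    # mirrors the class added in the hunk above, with SimpleRMSNorm inlined
    def __init__(self, dim, heads):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))

    def forward(self, x):
        # x: (batch, heads, seq, dim) -> unit-normalize features, then per-head gain
        normed = F.normalize(x, dim = -1) * self.scale
        return normed * (self.gamma + 1.)

norm = MultiheadRMSNorm(dim = 64, heads = 8)
out = norm(torch.randn(2, 8, 128, 64))   # shape preserved: (2, 8, 128, 64)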
@@ -904,7 +921,7 @@ class HyperConnection(Module):
          init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
          init_alpha0[layer_index % num_residual_streams, :] = 1.

-         self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+         self.static_alpha = nn.Parameter(cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))

          self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
          self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
@@ -973,7 +990,7 @@ class ShiftTokens(Module):
          splitted = x.split(feats_per_shift, dim = -1)
          segments_to_shift, rest = splitted[:segments], splitted[segments:]
          segments_to_shift = [shift(*args, mask = mask) for args in zip(segments_to_shift, shifts)]
-         x = torch.cat((*segments_to_shift, *rest), dim = -1)
+         x = cat((*segments_to_shift, *rest), dim = -1)
          return self.fn(x, **kwargs)

  class FoldAxially(Module):
@@ -1080,7 +1097,7 @@ class ConcatCombine(Module):

      def forward(self, x, prev_layers: list[Tensor]):
          skip = prev_layers[self.prev_layer_ind]
-         concatted_skip = torch.cat((skip, x), dim = -1)
+         concatted_skip = cat((skip, x), dim = -1)
          return self.combine(concatted_skip)

  # feedforward
@@ -1189,12 +1206,10 @@ class Attention(Module):
          hybrid_fold_axial_dim: int | None = None,
          one_kv_head = False,
          kv_heads = None,
-         shared_kv = False,
          value_dim_head = None,
          dim_out = None,
-         tensor_product = False, # https://arxiv.org/abs/2208.06061
          add_zero_kv = False, # same as add_zero_attn in pytorch
-         rotary_embed_values = False,
+         rotate_num_heads = None,
          data_dependent_alibi = False,
          data_dependent_alibi_per_row = False,
          data_dependent_alibi_per_row_dim_head = 8,
@@ -1205,12 +1220,15 @@ class Attention(Module):
          cope_talking_heads = False,
          softclamp_logits = False,
          logit_softclamp_value = 50.,
-         neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
-         neutreno_alpha = 0.4,
          learned_value_residual_mix = False,
-         laser = False, # https://arxiv.org/abs/2411.03493v1
+         laser = False, # https://arxiv.org/abs/2411.03493v1
          laser_softclamp_value = 15.,
          qkv_receive_diff_residuals = False,
+         use_latent_q = False,
+         dim_latent_q = None,
+         use_latent_kv = False,
+         dim_latent_kv = None,
+         latent_rope_subheads = None,
          onnxable = False,
          attend_sdp_kwargs: dict = dict(
              enable_flash = True,
@@ -1242,13 +1260,51 @@ class Attention(Module):
          v_dim = value_dim_head * kv_heads
          out_dim = value_dim_head * heads

-         self.to_q = LinearNoBias(dim, q_dim)
-         self.to_k = LinearNoBias(dim_kv, k_dim)
+         # determine input dimensions to qkv based on whether intermediate latent q and kv are being used
+         # for eventually supporting multi-latent attention (MLA)
+
+         self.to_latent_q = None
+         self.to_latent_kv = None
+         self.to_rotateable_k = None # for their "decoupled rope", subheads of keys that comes directly from base sequence (does not go through latents)
+
+         dim_q_input = dim
+         dim_kv_input = dim_kv
+
+         if use_latent_q:
+             assert exists(dim_latent_q)
+             self.to_latent_q = LinearNoBias(dim, dim_latent_q)
+             dim_q_input = dim_latent_q
+
+         if use_latent_kv:
+             assert exists(dim_latent_kv)
+             self.to_latent_kv = LinearNoBias(dim, dim_latent_kv)
+             dim_kv_input = dim_latent_kv
+
+         if exists(latent_rope_subheads):
+             assert not exists(rotate_num_heads)
+             rotate_num_heads = latent_rope_subheads
+
+             k_dim = dim_head * (kv_heads - latent_rope_subheads)

-         # shared key / values, for further memory savings during inference
+             self.to_rotateable_k = LinearNoBias(dim, dim_head * latent_rope_subheads)
+             self.split_rotateable_k_heads = Rearrange('b n (h d) -> b h n d', h = latent_rope_subheads)

-         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
-         self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None
+         self.use_latent_q = use_latent_q
+         self.use_latent_kv = use_latent_kv
+
+         # query key projection
+
+         self.to_q = LinearNoBias(dim_q_input, q_dim)
+         self.to_k = LinearNoBias(dim_kv_input, k_dim)
+         self.to_v = LinearNoBias(dim_kv_input, v_dim)
+
+         # split and merge of attention heads
+
+         self.split_q_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+         self.split_k_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
+         self.split_v_heads = Rearrange('b n (h d) -> b h n d', d = value_dim_head)
+
+         self.merge_heads = Rearrange('b h n d -> b n (h d)')

          # whether qkv receives different residual stream combinations from hyper connections

@@ -1259,15 +1315,6 @@ class Attention(Module):
          self.laser = laser
          self.laser_softclamp_value = laser_softclamp_value

-         # relations projection from tp-attention
-
-         self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None
-
-         # the value residual used by Nguyen et al. in https://arxiv.org/abs/2312.00751 for countering oversmoothing
-
-         self.neutreno_value_residual = neutreno_value_residual
-         self.neutreno_alpha = neutreno_alpha
-
          # add GLU gating for aggregated values, from alphafold2

          self.to_v_gate = None
@@ -1393,12 +1440,22 @@ class Attention(Module):

          # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676

+         hybrid_mix = None
+         hybrid_norms = None
          hybrid_module = maybe(deepcopy)(hybrid_module)

          if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
              hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
+             hybrid_mix = LinearNoBias(dim, heads)
+
+             hybrid_norms = ModuleList([
+                 MultiheadRMSNorm(dim_head, heads = heads),
+                 MultiheadRMSNorm(dim_head, heads = heads)
+             ])

          self.hybrid_module = hybrid_module
+         self.hybrid_norms = hybrid_norms
+         self.hybrid_mix = hybrid_mix
          self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths

          # output dimension by default same as input, but can be overridden
@@ -1406,9 +1463,15 @@ class Attention(Module):
          dim_out = default(dim_out, dim)
          self.to_out = nn.Sequential(LinearNoBias(out_dim, dim_out * 2), nn.GLU()) if on_attn else LinearNoBias(out_dim, dim_out)

-         # whether to rotate positions into values, for absolute positions in addition to relative
+         # the number of attention heads to rotate, for decoupled rope in multi-latent attention
+
+         rotate_num_heads = default(rotate_num_heads, heads)

-         self.rotary_embed_values = rotary_embed_values
+         assert 0 < rotate_num_heads <= heads
+         is_partial_rotate_heads = rotate_num_heads < heads
+         assert not (is_partial_rotate_heads and kv_heads < heads), 'grouped query attention not compatible with partial rotate heads (decoupled rope for multi-latent attention), yet'
+
+         self.rotate_num_heads = rotate_num_heads

          # whether parent can kv cache

@@ -1438,47 +1501,79 @@ class Attention(Module):
          cache: Intermediates | None = None,
          value_residual = None
      ):
-         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals
+         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv
+
+         # an interesting possibility with hyper connections
+         # having queries, keys, values be routed from different layers

          assert not (qkv_receive_diff_residuals and has_context), 'qkv receiving different sequences can only be used for self attention'

          if qkv_receive_diff_residuals:
-             assert not exists(self.to_r)
+             assert x.ndim == 4 and x.shape[0] == 3

              q_input, k_input, v_input = x
          else:
              kv_input = default(context, x)
-
-             q_input = x
-             k_input = kv_input
-             v_input = kv_input
-             r_input = x
+             q_input, k_input, v_input = x, kv_input, kv_input

          if exists(mem):
              k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
              v_input, _ = pack([mem, v_input], 'b * d')

+         # multi-latent attention logic
+         # https://arxiv.org/abs/2405.04434 - Deepseek-AI team
+
+         k_sub_heads = None # the rotateable subheads of keys derived from base sequence
+
+         if self.use_latent_q:
+             q_input = self.to_latent_q(q_input)
+
+         if is_multi_latent_attn:
+             assert not qkv_receive_diff_residuals
+             needs_k_sub_heads = exists(self.to_rotateable_k)
+
+             latent_kv_input = self.to_latent_kv(k_input)
+
+             if needs_k_sub_heads:
+                 rotateable_k = self.to_rotateable_k(k_input)
+                 k_sub_heads = self.split_rotateable_k_heads(rotateable_k)
+
+             if exists(cache):
+                 cached_latent_kv, maybe_cached_k_sub_heads = cache.cached_kv
+                 latent_kv_input = cat((cached_latent_kv, latent_kv_input), dim = -2)
+
+                 if exists(maybe_cached_k_sub_heads):
+                     k_sub_heads = cat((maybe_cached_k_sub_heads, k_sub_heads), dim = -2)
+
+             if return_intermediates:
+                 cached_kv = (latent_kv_input, k_sub_heads)
+
+             k_input = v_input = latent_kv_input
+
+         # query, key, value projection
+
          q = self.to_q(q_input)
          k = self.to_k(k_input)
-         v = self.to_v(v_input) if exists(self.to_v) else k
-         r = self.to_r(r_input) if exists(self.to_r) else None
+         v = self.to_v(v_input)
+
+         q = self.split_q_heads(q)
+         k = self.split_k_heads(k)
+         v = self.split_v_heads(v)

-         q = rearrange(q, 'b n (h d) -> b h n d', h = h)
+         # take care of decoupled rope from multi-latent attention

-         k, v, r = tuple(maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h) for t in (k, v, r))
+         if exists(k_sub_heads):
+             k = cat((k, k_sub_heads), dim = 1)

-         # if previous values passed in for residual, either invoke resformer or neutreno
+         # if previous values passed in for residual, either invoke resformer

          orig_values = v

+         # https://arxiv.org/abs/2410.17897v1
+
          if exists(value_residual):
-             if self.neutreno_value_residual:
-                 diff_values = (value_residual - v) * self.neutreno_alpha
-                 diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
-             else:
-                 # https://arxiv.org/abs/2410.17897v1
-                 value_residual_mix = self.to_value_residual_mix(q_input)
-                 v = v * value_residual_mix + value_residual * (1. - value_residual_mix)
+             value_residual_mix = self.to_value_residual_mix(q_input)
+             v = value_residual.lerp(v, value_residual_mix)

          # qk normalization

@@ -1492,28 +1587,36 @@ class Attention(Module):

          # take care of caching

-         if exists(cache):
-             ck, cv = cache.cached_kv
+         if not is_multi_latent_attn:
+             if exists(cache):
+                 ck, cv = cache.cached_kv

-             if exists(mem):
-                 mk, k = unpack(k, mem_packed_shape, 'b h * d')
-                 mv, v = unpack(v, mem_packed_shape, 'b h * d')
+                 if exists(mem):
+                     mk, k = unpack(k, mem_packed_shape, 'b h * d')
+                     mv, v = unpack(v, mem_packed_shape, 'b h * d')

-             k = torch.cat((ck, k), dim = -2)
-             v = torch.cat((cv, v), dim = -2)
+                 k = cat((ck, k), dim = -2)
+                 v = cat((cv, v), dim = -2)

-             if exists(mem):
-                 k = torch.cat((mk, k), dim = -2)
-                 v = torch.cat((mv, v), dim = -2)
+                 if exists(mem):
+                     k = cat((mk, k), dim = -2)
+                     v = cat((mv, v), dim = -2)

-         if return_intermediates:
-             mem_len = mem.shape[-2] if exists(mem) else 0
-             cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
+             if return_intermediates:
+                 mem_len = mem.shape[-2] if exists(mem) else 0
+                 cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])

          if exists(rotary_pos_emb):
+             rotate_num_heads = self.rotate_num_heads
+             partial_rotate_heads = rotate_num_heads < h
+
              freqs, xpos_scale = rotary_pos_emb
              q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)

+             if partial_rotate_heads:
+                 q_rest, q = q[:, :-rotate_num_heads], q[:, -rotate_num_heads:]
+                 k_rest, k = k[:, :-rotate_num_heads], k[:, -rotate_num_heads:]
+
              q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)

              if has_context:
@@ -1524,8 +1627,9 @@ class Attention(Module):

              k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)

-             if self.rotary_embed_values:
-                 v = apply_rotary_pos_emb(v, freqs, k_xpos_scale)
+             if partial_rotate_heads:
+                 q = cat((q_rest, q), dim = 1)
+                 k = cat((k_rest, k), dim = 1)

          input_mask = context_mask

@@ -1540,7 +1644,7 @@ class Attention(Module):
              elif not exists(input_mask):
                  input_mask = pad_at_dim(mem_mask, (0, seq_len), dim = -1, value = True)
              else:
-                 input_mask = torch.cat((mem_mask, input_mask), dim = -1)
+                 input_mask = cat((mem_mask, input_mask), dim = -1)

          # i, j determined for relative positional bias, excluding memory key / values

@@ -1555,8 +1659,8 @@ class Attention(Module):
                  mem_k = l2norm(mem_k)
                  mem_k = mem_k * self.qk_norm_k_scale

-             k = torch.cat((mem_k, k), dim = -2)
-             v = torch.cat((mem_v, v), dim = -2)
+             k = cat((mem_k, k), dim = -2)
+             v = cat((mem_v, v), dim = -2)

              if exists(input_mask):
                  input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
@@ -1580,8 +1684,8 @@ class Attention(Module):
              masks.append(~attn_mask)

          if exists(self.max_attend_past):
-             range_q = torch.arange(j - i, j, device = device)
-             range_k = torch.arange(j, device = device)
+             range_q = arange(j - i, j, device = device)
+             range_k = arange(j, device = device)
              dist = einx.subtract('i, j -> 1 1 i j', range_q, range_k)
              max_attend_past_mask = dist > self.max_attend_past
              max_attend_past_mask = pad_at_dim(max_attend_past_mask, (num_mem_kv, 0), value = False, dim = -1) # handle memory key / values
@@ -1629,18 +1733,10 @@ class Attention(Module):
          if self.laser:
              out = log(out)

-         # store the values for resformer or Neutreno
+         # store the values for resformer

          intermediates.values = orig_values

-         if exists(value_residual) and self.neutreno_value_residual:
-             out = out + diff_values
-
-         # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
-
-         if exists(r):
-             out = out * r + out
-
          # normformer scaling of heads

          if head_scale:
@@ -1652,11 +1748,9 @@ class Attention(Module):
              head_gate = self.to_v_head_gate(x)
              out = einx.multiply('b n h, b h n d ->b h n d', head_gate.sigmoid(), out)

-         # merge heads
-
-         out = rearrange(out, 'b h n d -> b n (h d)')
+         # if exists hybrid module, must do a normalization

-         # hybrid module
+         # hybrid module

          if exists(self.hybrid_module):

@@ -1674,8 +1768,23 @@ class Attention(Module):
              # handle hybrid out

              (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
+
+             # handle variable hybrid output and multi rmsnorm before summing to main attention output (also normed)
+
+             if hybrid_out.ndim == 3:
+                 hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)
+
+             out_norm, hybrid_out_norm = self.hybrid_norms
+
+             out = out_norm(out)
+             hybrid_out = hybrid_out_norm(hybrid_out)
+
              out = 0.5 * (out + hybrid_out)

+         # merge heads
+
+         out = self.merge_heads(out)
+
          # alphafold2 styled gating of the values

          if exists(self.to_v_gate):
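Note: with `hybrid_norms` in place, the attention branch and the hybrid branch are each RMS-normalized per head before being averaged, so neither branch dominates purely through scale. A condensed standalone sketch of that merge (the local `multihead_rmsnorm` mirrors the `MultiheadRMSNorm` added earlier; shapes are illustrative):

import torch
import torch.nn.functional as F

def multihead_rmsnorm(x, gamma):
    # x: (batch, heads, seq, dim_head), gamma: (heads, 1, dim_head), zero-initialized
    return F.normalize(x, dim = -1) * (x.shape[-1] ** 0.5) * (gamma + 1.)

b, h, n, d = 2, 8, 128, 64
out = torch.randn(b, h, n, d)             # attention branch, already split into heads
hybrid_out = torch.randn(b, n, h * d)     # e.g. an RNN / SSM branch returning (b, n, h * d)

hybrid_out = hybrid_out.reshape(b, n, h, d).transpose(1, 2)   # -> (b, h, n, d)

gamma_attn, gamma_hybrid = torch.zeros(h, 1, d), torch.zeros(h, 1, d)

merged = 0.5 * (multihead_rmsnorm(out, gamma_attn) + multihead_rmsnorm(hybrid_out, gamma_hybrid))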
@@ -1747,8 +1856,6 @@ class AttentionLayers(Module):
          sandwich_norm = False,
          softclamp_output = False,
          softclamp_output_value = 30.,
-         resi_dual = False,
-         resi_dual_scale = 1.,
          zero_init_branch_output = False,
          layer_dropout = 0.,
          cross_attn_tokens_dropout = 0.,
@@ -1775,12 +1882,9 @@ class AttentionLayers(Module):

          dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
          data_dependent_alibi = attn_kwargs.get('data_dependent_alibi', False)
-         neutreno_value_residual = attn_kwargs.get('neutreno_value_residual', False)

          assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'

-         add_value_residual |= neutreno_value_residual
-
          self.dim = dim
          self.causal = causal
          self.layers = ModuleList([])
@@ -1831,19 +1935,11 @@ class AttentionLayers(Module):
              assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
              self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads, **rel_pos_kwargs)

-         assert at_most_one_of(sandwich_norm, resi_dual), 'either sandwich norm or resiDual is selected, but not both'
          assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'

-         if resi_dual:
-             pre_norm = False
-
          self.pre_norm = pre_norm
          self.sandwich_norm = sandwich_norm

-         self.resi_dual = resi_dual
-         assert 0 < resi_dual_scale <= 1., 'resiDual prenorm residual must be scaled by a factor greater than 0 and less than or equal to 1.'
-         self.resi_dual_scale = resi_dual_scale
-
          self.residual_attn = residual_attn
          self.cross_residual_attn = cross_residual_attn
          assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention'
@@ -2002,7 +2098,7 @@ class AttentionLayers(Module):

          # whether it has post norm

-         self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()
+         self.final_norm = norm_fn() if pre_norm else nn.Identity()

          # whether unet or not

@@ -2175,7 +2271,7 @@ class AttentionLayers(Module):
          # handle left padded sequences

          if exists(seq_start_pos):
-             seq_arange = torch.arange(x.shape[-2], device = x.device, dtype = torch.long)
+             seq_arange = arange(x.shape[-2], device = x.device, dtype = torch.long)
              left_pad_mask = seq_arange >= seq_start_pos[..., None]

              if exists(self_attn_kv_mask):
@@ -2193,7 +2289,7 @@ class AttentionLayers(Module):
              mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0

              if not exists(pos):
-                 pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+                 pos = arange(x.shape[1] + mem_len, device = x.device) - mem_len

              rotary_pos_emb = self.rotary_pos_emb(pos)

@@ -2213,7 +2309,7 @@ class AttentionLayers(Module):
          attn_cache = []

          if exists(cache):
-             assert not self.training and self.causal and not any([*map(exists, (mask, attn_mask))])
+             assert self.causal and not any([*map(exists, (mask, attn_mask))])

              if exists(context):
                  context = context[:, :0]
@@ -2231,13 +2327,7 @@ class AttentionLayers(Module):
          is_multistream = streams > 1

          if is_multistream:
-             x = repeat(x, 'b n d -> b n s d', s = streams)
-             x = x + self.stream_emb
-             x = rearrange(x, 'b n s d -> (b s) n d')
-
-         # outer residual - for resiDual paper
-
-         outer_residual = x * self.resi_dual_scale
+             x = einx.add('b n d, s d -> (b s) n d', x, self.stream_emb)

          # get layers to be executed

@@ -2359,9 +2449,6 @@ class AttentionLayers(Module):
                  if not exists(first_cross_attn_inter) and layer_type == 'c':
                      first_cross_attn_inter = inter

-             if self.resi_dual:
-                 outer_residual = outer_residual + out * self.resi_dual_scale
-
              if exists(post_branch_norm):
                  out = post_branch_norm(out)

@@ -2395,10 +2482,7 @@ class AttentionLayers(Module):
          if is_multistream:
              x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)

-         if self.resi_dual:
-             x = x + final_norm(outer_residual)
-         else:
-             x = final_norm(x)
+         x = final_norm(x)

          if not return_hiddens:
              return x
@@ -2444,7 +2528,7 @@ class PrefixDecoder(AttentionLayers):
              if isinstance(prefix_attn_len, int):
                  prefix_attn_len = torch.full((b,), prefix_attn_len, device = device)

-             prefix_mask = torch.arange(n, device = device) < rearrange(prefix_attn_len, 'b -> b 1 1 1')
+             prefix_mask = arange(n, device = device) < rearrange(prefix_attn_len, 'b -> b 1 1 1')
              forwarded_mask = forwarded_mask | prefix_mask

          if exists(attn_mask):
@@ -2773,13 +2857,13 @@ class TransformerWrapper(Module):
              prepend_seq, prepend_dim = prepend_embeds.shape[1:]
              assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'

-             x = torch.cat((prepend_embeds, x), dim = -2)
+             x = cat((prepend_embeds, x), dim = -2)

              if exists(prepend_mask) or exists(mask):
                  mask = default(mask, lambda: torch.ones((b, n), device = device, dtype = torch.bool))
                  prepend_mask = default(prepend_mask, lambda: torch.ones((b, prepend_seq), device = device, dtype = torch.bool))

-                 mask = torch.cat((prepend_mask, mask), dim = -1)
+                 mask = cat((prepend_mask, mask), dim = -1)

          # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model

@@ -2945,7 +3029,7 @@ class TransformerWrapper(Module):

          if return_mems:
              hiddens = intermediates.hiddens
-             new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
+             new_mems = [cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
              new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]

          if not return_intermediates: