x-transformers 1.44.8__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,7 +9,7 @@ from packaging import version
  import torch
  from torch.amp import autocast
  import torch.nn.functional as F
- from torch import nn, einsum, Tensor
+ from torch import nn, einsum, Tensor, cat, stack, arange
  from torch.utils._pytree import tree_flatten, tree_unflatten
  from torch.nn import Module, ModuleList, ModuleDict

@@ -18,14 +18,22 @@ from collections import namedtuple
  from contextlib import nullcontext
  from dataclasses import dataclass

+ from loguru import logger
+
+ from x_transformers.attend import Attend, Intermediates
+ from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+
  import einx
  from einops.layers.torch import Rearrange
  from einops import rearrange, repeat, reduce, pack, unpack

- from loguru import logger
+ # einstein notation

- from x_transformers.attend import Attend, Intermediates
- from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+ # b - batch
+ # n - sequence
+ # d - feature dimension
+ # h - attention heads
+ # i, j - sequence (source, target)

  # constants

@@ -220,7 +228,7 @@ def dropout_seq(seq, mask, dropout):
  num_keep = max(1, int(keep_prob * n))
  keep_indices = logits.topk(num_keep, dim = 1).indices

- batch_indices = torch.arange(b, device = device)
+ batch_indices = arange(b, device = device)
  batch_indices = rearrange(batch_indices, 'b -> b 1')

  seq = seq[batch_indices, keep_indices]
@@ -228,7 +236,7 @@ def dropout_seq(seq, mask, dropout):
  if exists(mask):
  seq_counts = mask.sum(dim = -1)
  seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
- keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
+ keep_mask = arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')

  mask = mask[batch_indices, keep_indices] & keep_mask

@@ -274,7 +282,7 @@ class AbsolutePositionalEmbedding(Module):
  assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

  if not exists(pos):
- pos = torch.arange(seq_len, device = device)
+ pos = arange(seq_len, device = device)

  if exists(seq_start_pos):
  pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
@@ -290,7 +298,7 @@ class ScaledSinusoidalEmbedding(Module):
  self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

  half_dim = dim // 2
- freq_seq = torch.arange(half_dim).float() / half_dim
+ freq_seq = arange(half_dim).float() / half_dim
  inv_freq = theta ** -freq_seq
  self.register_buffer('inv_freq', inv_freq, persistent = False)

@@ -298,13 +306,13 @@ class ScaledSinusoidalEmbedding(Module):
  seq_len, device = x.shape[1], x.device

  if not exists(pos):
- pos = torch.arange(seq_len, device = device)
+ pos = arange(seq_len, device = device)

  if exists(seq_start_pos):
  pos = pos - seq_start_pos[..., None]

  emb = einsum('i, j -> i j', pos, self.inv_freq)
- emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
+ emb = cat((emb.sin(), emb.cos()), dim = -1)
  return emb * self.scale

  class RelativePositionBias(Module):
@@ -344,8 +352,8 @@ class RelativePositionBias(Module):

  def forward(self, i, j):
  device = self.device
- q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
- k_pos = torch.arange(j, dtype = torch.long, device = device)
+ q_pos = arange(j - i, j, dtype = torch.long, device = device)
+ k_pos = arange(j, dtype = torch.long, device = device)
  rel_pos = einx.subtract('j, i -> i j', k_pos, q_pos)
  rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
  values = self.relative_attention_bias(rp_bucket)
@@ -376,7 +384,7 @@ class CoPE(Module):
  if not soft_onehot:
  return

- self.register_buffer('positions', torch.arange(max_pos))
+ self.register_buffer('positions', arange(max_pos))

  def forward(self, query, attn_logits):

@@ -445,13 +453,13 @@ class DynamicPositionBias(Module):
  n, device = j, self.device

  # get the (n x n) matrix of distances
- seq_arange = torch.arange(n, device = device)
- context_arange = torch.arange(n, device = device)
+ seq_arange = arange(n, device = device)
+ context_arange = arange(n, device = device)
  indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
  indices += (n - 1)

  # input to continuous positions MLP
- pos = torch.arange(-n + 1, n, device = device).float()
+ pos = arange(-n + 1, n, device = device).float()
  pos = rearrange(pos, '... -> ... 1')

  if self.log_distance:
@@ -525,8 +533,8 @@ class AlibiPositionalBias(Module):
  if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
  return self.bias[..., -i:, -j:]

- seq_arange = torch.arange(j - i, j, device = device)
- context_arange = torch.arange(j, device = device)
+ seq_arange = arange(j - i, j, device = device)
+ context_arange = arange(j, device = device)
  bias = -einx.subtract('j, i -> 1 i j', context_arange, seq_arange).abs()

  bias = bias * self.slopes
@@ -642,7 +650,7 @@ class RotaryEmbedding(Module):
  # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
  base *= base_rescale_factor ** (dim / (dim - 2))

- inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+ inv_freq = 1. / (base ** (arange(0, dim, 2).float() / dim))
  self.register_buffer('inv_freq', inv_freq)

  assert interpolation_factor >= 1.
@@ -652,7 +660,7 @@ class RotaryEmbedding(Module):
  self.register_buffer('scale', None)
  return

- scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
+ scale = (arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

  self.scale_base = scale_base
  self.register_buffer('scale', scale)
@@ -660,7 +668,7 @@ class RotaryEmbedding(Module):
  def forward_from_seq_len(self, seq_len):
  device = self.inv_freq.device

- t = torch.arange(seq_len, device = device)
+ t = arange(seq_len, device = device)
  return self.forward(t)

  @autocast('cuda', enabled = False)
@@ -671,7 +679,7 @@ class RotaryEmbedding(Module):
  t = rearrange(t, 'n -> 1 n')

  freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
- freqs = torch.stack((freqs, freqs), dim = -1)
+ freqs = stack((freqs, freqs), dim = -1)
  freqs = rearrange(freqs, '... d r -> ... (d r)')

  if not exists(self.scale):
@@ -679,7 +687,7 @@ class RotaryEmbedding(Module):

  power = (t - (max_pos // 2)) / self.scale_base
  scale = self.scale ** rearrange(power, '... n -> ... n 1')
- scale = torch.stack((scale, scale), dim = -1)
+ scale = stack((scale, scale), dim = -1)
  scale = rearrange(scale, '... d r -> ... (d r)')

  return freqs, scale
@@ -687,7 +695,7 @@ class RotaryEmbedding(Module):
  def rotate_half(x):
  x = rearrange(x, '... (d r) -> ... d r', r = 2)
  x1, x2 = x.unbind(dim = -1)
- x = torch.stack((-x2, x1), dim = -1)
+ x = stack((-x2, x1), dim = -1)
  return rearrange(x, '... d r -> ... (d r)')

  @autocast('cuda', enabled = False)
@@ -703,7 +711,7 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):
  # partial rotary embeddings, Wang et al. GPT-J
  t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
  t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
- out = torch.cat((t, t_unrotated), dim = -1)
+ out = cat((t, t_unrotated), dim = -1)

  return out.type(orig_dtype)

@@ -833,6 +841,15 @@ class SimpleRMSNorm(Module):
  def forward(self, x):
  return F.normalize(x, dim = -1) * self.scale

+ class MultiheadRMSNorm(Module):
+ def __init__(self, dim, heads):
+ super().__init__()
+ self.rmsnorm = SimpleRMSNorm(dim)
+ self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
+
+ def forward(self, x):
+ return self.rmsnorm(x) * (self.gamma + 1.)
+
  # residual and residual gates

  class Residual(Module):
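
Note: the new MultiheadRMSNorm normalizes per-head tensors and rescales them with a learned per-head gamma (zero-initialized, so it starts out as a plain RMS norm). A minimal shape sketch; the import path below is assumed from this file's location in the package:

import torch
from x_transformers.x_transformers import MultiheadRMSNorm  # import path assumed

norm = MultiheadRMSNorm(dim = 64, heads = 8)
x = torch.randn(2, 8, 128, 64)   # (batch, heads, seq, dim_head) - the per-head layout used inside Attention
out = norm(x)                    # gamma has shape (heads, 1, dim_head) and broadcasts over batch and seq
assert out.shape == x.shape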
@@ -904,7 +921,7 @@ class HyperConnection(Module):
  init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
  init_alpha0[layer_index % num_residual_streams, :] = 1.

- self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+ self.static_alpha = nn.Parameter(cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))

  self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
  self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
@@ -973,7 +990,7 @@ class ShiftTokens(Module):
  splitted = x.split(feats_per_shift, dim = -1)
  segments_to_shift, rest = splitted[:segments], splitted[segments:]
  segments_to_shift = [shift(*args, mask = mask) for args in zip(segments_to_shift, shifts)]
- x = torch.cat((*segments_to_shift, *rest), dim = -1)
+ x = cat((*segments_to_shift, *rest), dim = -1)
  return self.fn(x, **kwargs)

  class FoldAxially(Module):
@@ -1080,7 +1097,7 @@ class ConcatCombine(Module):

  def forward(self, x, prev_layers: list[Tensor]):
  skip = prev_layers[self.prev_layer_ind]
- concatted_skip = torch.cat((skip, x), dim = -1)
+ concatted_skip = cat((skip, x), dim = -1)
  return self.combine(concatted_skip)

  # feedforward
@@ -1189,12 +1206,10 @@ class Attention(Module):
  hybrid_fold_axial_dim: int | None = None,
  one_kv_head = False,
  kv_heads = None,
- shared_kv = False,
  value_dim_head = None,
  dim_out = None,
- tensor_product = False, # https://arxiv.org/abs/2208.06061
  add_zero_kv = False, # same as add_zero_attn in pytorch
- rotary_embed_values = False,
+ rotate_num_heads = None,
  data_dependent_alibi = False,
  data_dependent_alibi_per_row = False,
  data_dependent_alibi_per_row_dim_head = 8,
@@ -1205,12 +1220,15 @@ class Attention(Module):
  cope_talking_heads = False,
  softclamp_logits = False,
  logit_softclamp_value = 50.,
- neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
- neutreno_alpha = 0.4,
  learned_value_residual_mix = False,
- laser = False, # https://arxiv.org/abs/2411.03493v1
+ laser = False, # https://arxiv.org/abs/2411.03493v1
  laser_softclamp_value = 15.,
  qkv_receive_diff_residuals = False,
+ use_latent_q = False,
+ dim_latent_q = None,
+ use_latent_kv = False,
+ dim_latent_kv = None,
+ latent_rope_subheads = None,
  onnxable = False,
  attend_sdp_kwargs: dict = dict(
  enable_flash = True,
@@ -1242,13 +1260,51 @@ class Attention(Module):
  v_dim = value_dim_head * kv_heads
  out_dim = value_dim_head * heads

- self.to_q = LinearNoBias(dim, q_dim)
- self.to_k = LinearNoBias(dim_kv, k_dim)
+ # determine input dimensions to qkv based on whether intermediate latent q and kv are being used
+ # for eventually supporting multi-latent attention (MLA)
+
+ self.to_latent_q = None
+ self.to_latent_kv = None
+ self.to_rotateable_k = None # for their "decoupled rope", subheads of keys that comes directly from base sequence (does not go through latents)
+
+ dim_q_input = dim
+ dim_kv_input = dim_kv
+
+ if use_latent_q:
+ assert exists(dim_latent_q)
+ self.to_latent_q = LinearNoBias(dim, dim_latent_q)
+ dim_q_input = dim_latent_q
+
+ if use_latent_kv:
+ assert exists(dim_latent_kv)
+ self.to_latent_kv = LinearNoBias(dim, dim_latent_kv)
+ dim_kv_input = dim_latent_kv
+
+ if exists(latent_rope_subheads):
+ assert not exists(rotate_num_heads)
+ rotate_num_heads = latent_rope_subheads
+
+ k_dim = dim_head * (kv_heads - latent_rope_subheads)

- # shared key / values, for further memory savings during inference
+ self.to_rotateable_k = LinearNoBias(dim, dim_head * latent_rope_subheads)
+ self.split_rotateable_k_heads = Rearrange('b n (h d) -> b h n d', h = latent_rope_subheads)

- assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
- self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None
+ self.use_latent_q = use_latent_q
+ self.use_latent_kv = use_latent_kv
+
+ # query key projection
+
+ self.to_q = LinearNoBias(dim_q_input, q_dim)
+ self.to_k = LinearNoBias(dim_kv_input, k_dim)
+ self.to_v = LinearNoBias(dim_kv_input, v_dim)
+
+ # split and merge of attention heads
+
+ self.split_q_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+ self.split_k_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
+ self.split_v_heads = Rearrange('b n (h d) -> b h n d', d = value_dim_head)
+
+ self.merge_heads = Rearrange('b h n d -> b n (h d)')

  # whether qkv receives different residual stream combinations from hyper connections

@@ -1259,15 +1315,6 @@ class Attention(Module):
  self.laser = laser
  self.laser_softclamp_value = laser_softclamp_value

- # relations projection from tp-attention
-
- self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None
-
- # the value residual used by Nguyen et al. in https://arxiv.org/abs/2312.00751 for countering oversmoothing
-
- self.neutreno_value_residual = neutreno_value_residual
- self.neutreno_alpha = neutreno_alpha
-
  # add GLU gating for aggregated values, from alphafold2

  self.to_v_gate = None
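
Note: 2.0 adds multi-latent-attention (MLA) style projections to Attention: queries and/or keys and values can be compressed through a low-rank latent before the per-head projections, and latent_rope_subheads carves out key subheads that bypass the latent and receive rotary positions ("decoupled rope"). A sketch with illustrative hyperparameters, assuming the usual attn_-prefixed routing of Decoder keyword arguments into Attention (the routing itself is not shown in this diff):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rotary_pos_emb = True,           # decoupled rope only has an effect when rotary embeddings are on
        attn_use_latent_q = True,
        attn_dim_latent_q = 128,         # queries are first projected into a 128-dim latent
        attn_use_latent_kv = True,
        attn_dim_latent_kv = 128,        # keys / values share a compressed latent (this is what gets cached)
        attn_latent_rope_subheads = 2    # 2 key subheads come straight from the input and get rotated
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)                        # (1, 1024, 20000)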
@@ -1393,12 +1440,22 @@ class Attention(Module):

  # hybrid module, in same vein as hymba https://www.arxiv.org/abs/2411.13676

+ hybrid_mix = None
+ hybrid_norms = None
  hybrid_module = maybe(deepcopy)(hybrid_module)

  if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
  hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
+ hybrid_mix = LinearNoBias(dim, heads)
+
+ hybrid_norms = ModuleList([
+ MultiheadRMSNorm(dim_head, heads = heads),
+ MultiheadRMSNorm(dim_head, heads = heads)
+ ])

  self.hybrid_module = hybrid_module
+ self.hybrid_norms = hybrid_norms
+ self.hybrid_mix = hybrid_mix
  self.hybrid_mask_kwarg = hybrid_mask_kwarg # for bidirectional, can forward `mask` into the hybrid module and let it handle variable lengths

  # output dimension by default same as input, but can be overridden
@@ -1406,9 +1463,15 @@ class Attention(Module):
  dim_out = default(dim_out, dim)
  self.to_out = nn.Sequential(LinearNoBias(out_dim, dim_out * 2), nn.GLU()) if on_attn else LinearNoBias(out_dim, dim_out)

- # whether to rotate positions into values, for absolute positions in addition to relative
+ # the number of attention heads to rotate, for decoupled rope in multi-latent attention
+
+ rotate_num_heads = default(rotate_num_heads, heads)

- self.rotary_embed_values = rotary_embed_values
+ assert 0 < rotate_num_heads <= heads
+ is_partial_rotate_heads = rotate_num_heads < heads
+ assert not (is_partial_rotate_heads and kv_heads < heads), 'grouped query attention not compatible with partial rotate heads (decoupled rope for multi-latent attention), yet'
+
+ self.rotate_num_heads = rotate_num_heads

  # whether parent can kv cache

@@ -1438,47 +1501,79 @@ class Attention(Module):
  cache: Intermediates | None = None,
  value_residual = None
  ):
- b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals
+ b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv
+
+ # an interesting possibility with hyper connections
+ # having queries, keys, values be routed from different layers

  assert not (qkv_receive_diff_residuals and has_context), 'qkv receiving different sequences can only be used for self attention'

  if qkv_receive_diff_residuals:
- assert not exists(self.to_r)
+ assert x.ndim == 4 and x.shape[0] == 3

  q_input, k_input, v_input = x
  else:
  kv_input = default(context, x)
-
- q_input = x
- k_input = kv_input
- v_input = kv_input
- r_input = x
+ q_input, k_input, v_input = x, kv_input, kv_input

  if exists(mem):
  k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
  v_input, _ = pack([mem, v_input], 'b * d')

+ # multi-latent attention logic
+ # https://arxiv.org/abs/2405.04434 - Deepseek-AI team
+
+ k_sub_heads = None # the rotateable subheads of keys derived from base sequence
+
+ if self.use_latent_q:
+ q_input = self.to_latent_q(q_input)
+
+ if is_multi_latent_attn:
+ assert not qkv_receive_diff_residuals
+ needs_k_sub_heads = exists(self.to_rotateable_k)
+
+ latent_kv_input = self.to_latent_kv(k_input)
+
+ if needs_k_sub_heads:
+ rotateable_k = self.to_rotateable_k(k_input)
+ k_sub_heads = self.split_rotateable_k_heads(rotateable_k)
+
+ if exists(cache):
+ cached_latent_kv, maybe_cached_k_sub_heads = cache.cached_kv
+ latent_kv_input = cat((cached_latent_kv, latent_kv_input), dim = -2)
+
+ if exists(maybe_cached_k_sub_heads):
+ k_sub_heads = cat((maybe_cached_k_sub_heads, k_sub_heads), dim = -2)
+
+ if return_intermediates:
+ cached_kv = (latent_kv_input, k_sub_heads)
+
+ k_input = v_input = latent_kv_input
+
+ # query, key, value projection
+
  q = self.to_q(q_input)
  k = self.to_k(k_input)
- v = self.to_v(v_input) if exists(self.to_v) else k
- r = self.to_r(r_input) if exists(self.to_r) else None
+ v = self.to_v(v_input)
+
+ q = self.split_q_heads(q)
+ k = self.split_k_heads(k)
+ v = self.split_v_heads(v)

- q = rearrange(q, 'b n (h d) -> b h n d', h = h)
+ # take care of decoupled rope from multi-latent attention

- k, v, r = tuple(maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h) for t in (k, v, r))
+ if exists(k_sub_heads):
+ k = cat((k, k_sub_heads), dim = 1)

- # if previous values passed in for residual, either invoke resformer or neutreno
+ # if previous values passed in for residual, either invoke resformer

  orig_values = v

+ # https://arxiv.org/abs/2410.17897v1
+
  if exists(value_residual):
- if self.neutreno_value_residual:
- diff_values = (value_residual - v) * self.neutreno_alpha
- diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
- else:
- # https://arxiv.org/abs/2410.17897v1
- value_residual_mix = self.to_value_residual_mix(q_input)
- v = v * value_residual_mix + value_residual * (1. - value_residual_mix)
+ value_residual_mix = self.to_value_residual_mix(q_input)
+ v = value_residual.lerp(v, value_residual_mix)

  # qk normalization

@@ -1492,28 +1587,36 @@ class Attention(Module):

  # take care of caching

- if exists(cache):
- ck, cv = cache.cached_kv
+ if not is_multi_latent_attn:
+ if exists(cache):
+ ck, cv = cache.cached_kv

- if exists(mem):
- mk, k = unpack(k, mem_packed_shape, 'b h * d')
- mv, v = unpack(v, mem_packed_shape, 'b h * d')
+ if exists(mem):
+ mk, k = unpack(k, mem_packed_shape, 'b h * d')
+ mv, v = unpack(v, mem_packed_shape, 'b h * d')

- k = torch.cat((ck, k), dim = -2)
- v = torch.cat((cv, v), dim = -2)
+ k = cat((ck, k), dim = -2)
+ v = cat((cv, v), dim = -2)

- if exists(mem):
- k = torch.cat((mk, k), dim = -2)
- v = torch.cat((mv, v), dim = -2)
+ if exists(mem):
+ k = cat((mk, k), dim = -2)
+ v = cat((mv, v), dim = -2)

- if return_intermediates:
- mem_len = mem.shape[-2] if exists(mem) else 0
- cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
+ if return_intermediates:
+ mem_len = mem.shape[-2] if exists(mem) else 0
+ cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])

  if exists(rotary_pos_emb):
+ rotate_num_heads = self.rotate_num_heads
+ partial_rotate_heads = rotate_num_heads < h
+
  freqs, xpos_scale = rotary_pos_emb
  q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)

+ if partial_rotate_heads:
+ q_rest, q = q[:, :-rotate_num_heads], q[:, -rotate_num_heads:]
+ k_rest, k = k[:, :-rotate_num_heads], k[:, -rotate_num_heads:]
+
  q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)

  if has_context:
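
Note: the learned value-residual mix (resformer, arXiv 2410.17897) shown a couple of hunks above was rewritten from an explicit convex combination to Tensor.lerp; the two forms are numerically equivalent. A quick check with illustrative shapes (the (b, h, n, 1) mix mirrors what to_value_residual_mix produces):

import torch

v = torch.randn(2, 8, 16, 64)                  # (batch, heads, seq, dim_head)
value_residual = torch.randn(2, 8, 16, 64)     # values passed down from the first layer
mix = torch.rand(2, 8, 16, 1)                  # a mix in [0, 1], per head and position

old = v * mix + value_residual * (1. - mix)    # 1.44.x formulation
new = value_residual.lerp(v, mix)              # 2.x formulation: start + weight * (end - start)
assert torch.allclose(old, new, atol = 1e-6)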
@@ -1524,8 +1627,9 @@ class Attention(Module):

  k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)

- if self.rotary_embed_values:
- v = apply_rotary_pos_emb(v, freqs, k_xpos_scale)
+ if partial_rotate_heads:
+ q = cat((q_rest, q), dim = 1)
+ k = cat((k_rest, k), dim = 1)

  input_mask = context_mask

@@ -1540,7 +1644,7 @@ class Attention(Module):
  elif not exists(input_mask):
  input_mask = pad_at_dim(mem_mask, (0, seq_len), dim = -1, value = True)
  else:
- input_mask = torch.cat((mem_mask, input_mask), dim = -1)
+ input_mask = cat((mem_mask, input_mask), dim = -1)

  # i, j determined for relative positional bias, excluding memory key / values

@@ -1555,8 +1659,8 @@ class Attention(Module):
  mem_k = l2norm(mem_k)
  mem_k = mem_k * self.qk_norm_k_scale

- k = torch.cat((mem_k, k), dim = -2)
- v = torch.cat((mem_v, v), dim = -2)
+ k = cat((mem_k, k), dim = -2)
+ v = cat((mem_v, v), dim = -2)

  if exists(input_mask):
  input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
@@ -1580,8 +1684,8 @@ class Attention(Module):
  masks.append(~attn_mask)

  if exists(self.max_attend_past):
- range_q = torch.arange(j - i, j, device = device)
- range_k = torch.arange(j, device = device)
+ range_q = arange(j - i, j, device = device)
+ range_k = arange(j, device = device)
  dist = einx.subtract('i, j -> 1 1 i j', range_q, range_k)
  max_attend_past_mask = dist > self.max_attend_past
  max_attend_past_mask = pad_at_dim(max_attend_past_mask, (num_mem_kv, 0), value = False, dim = -1) # handle memory key / values
@@ -1629,18 +1733,10 @@ class Attention(Module):
  if self.laser:
  out = log(out)

- # store the values for resformer or Neutreno
+ # store the values for resformer

  intermediates.values = orig_values

- if exists(value_residual) and self.neutreno_value_residual:
- out = out + diff_values
-
- # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
-
- if exists(r):
- out = out * r + out
-
  # normformer scaling of heads

  if head_scale:
@@ -1652,11 +1748,9 @@ class Attention(Module):
  head_gate = self.to_v_head_gate(x)
  out = einx.multiply('b n h, b h n d ->b h n d', head_gate.sigmoid(), out)

- # merge heads
-
- out = rearrange(out, 'b h n d -> b n (h d)')
+ # if exists hybrid module, must do a normalization

- # hybrid module
+ # hybrid module

  if exists(self.hybrid_module):

@@ -1674,8 +1768,23 @@ class Attention(Module):
  # handle hybrid out

  (hybrid_out, *rest_hybrid_outs), _ = tree_flatten(hybrid_outputs)
+
+ # handle variable hybrid output and multi rmsnorm before summing to main attention output (also normed)
+
+ if hybrid_out.ndim == 3:
+ hybrid_out = rearrange(hybrid_out, 'b n (h d) -> b h n d', h = h)
+
+ out_norm, hybrid_out_norm = self.hybrid_norms
+
+ out = out_norm(out)
+ hybrid_out = hybrid_out_norm(hybrid_out)
+
  out = 0.5 * (out + hybrid_out)

+ # merge heads
+
+ out = self.merge_heads(out)
+
  # alphafold2 styled gating of the values

  if exists(self.to_v_gate):
@@ -1747,8 +1856,6 @@ class AttentionLayers(Module):
  sandwich_norm = False,
  softclamp_output = False,
  softclamp_output_value = 30.,
- resi_dual = False,
- resi_dual_scale = 1.,
  zero_init_branch_output = False,
  layer_dropout = 0.,
  cross_attn_tokens_dropout = 0.,
@@ -1775,12 +1882,9 @@ class AttentionLayers(Module):

  dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
  data_dependent_alibi = attn_kwargs.get('data_dependent_alibi', False)
- neutreno_value_residual = attn_kwargs.get('neutreno_value_residual', False)

  assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'

- add_value_residual |= neutreno_value_residual
-
  self.dim = dim
  self.causal = causal
  self.layers = ModuleList([])
@@ -1831,19 +1935,11 @@ class AttentionLayers(Module):
  assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
  self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads, **rel_pos_kwargs)

- assert at_most_one_of(sandwich_norm, resi_dual), 'either sandwich norm or resiDual is selected, but not both'
  assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'

- if resi_dual:
- pre_norm = False
-
  self.pre_norm = pre_norm
  self.sandwich_norm = sandwich_norm

- self.resi_dual = resi_dual
- assert 0 < resi_dual_scale <= 1., 'resiDual prenorm residual must be scaled by a factor greater than 0 and less than or equal to 1.'
- self.resi_dual_scale = resi_dual_scale
-
  self.residual_attn = residual_attn
  self.cross_residual_attn = cross_residual_attn
  assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention'
@@ -2002,7 +2098,7 @@ class AttentionLayers(Module):

  # whether it has post norm

- self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()
+ self.final_norm = norm_fn() if pre_norm else nn.Identity()

  # whether unet or not

@@ -2175,7 +2271,7 @@ class AttentionLayers(Module):
  # handle left padded sequences

  if exists(seq_start_pos):
- seq_arange = torch.arange(x.shape[-2], device = x.device, dtype = torch.long)
+ seq_arange = arange(x.shape[-2], device = x.device, dtype = torch.long)
  left_pad_mask = seq_arange >= seq_start_pos[..., None]

  if exists(self_attn_kv_mask):
@@ -2193,7 +2289,7 @@ class AttentionLayers(Module):
  mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0

  if not exists(pos):
- pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+ pos = arange(x.shape[1] + mem_len, device = x.device) - mem_len

  rotary_pos_emb = self.rotary_pos_emb(pos)

@@ -2213,7 +2309,7 @@ class AttentionLayers(Module):
  attn_cache = []

  if exists(cache):
- assert not self.training and self.causal and not any([*map(exists, (mask, attn_mask))])
+ assert self.causal and not any([*map(exists, (mask, attn_mask))])

  if exists(context):
  context = context[:, :0]
@@ -2231,13 +2327,7 @@ class AttentionLayers(Module):
  is_multistream = streams > 1

  if is_multistream:
- x = repeat(x, 'b n d -> b n s d', s = streams)
- x = x + self.stream_emb
- x = rearrange(x, 'b n s d -> (b s) n d')
-
- # outer residual - for resiDual paper
-
- outer_residual = x * self.resi_dual_scale
+ x = einx.add('b n d, s d -> (b s) n d', x, self.stream_emb)

  # get layers to be executed

@@ -2359,9 +2449,6 @@ class AttentionLayers(Module):
  if not exists(first_cross_attn_inter) and layer_type == 'c':
  first_cross_attn_inter = inter

- if self.resi_dual:
- outer_residual = outer_residual + out * self.resi_dual_scale
-
  if exists(post_branch_norm):
  out = post_branch_norm(out)

@@ -2395,10 +2482,7 @@ class AttentionLayers(Module):
  if is_multistream:
  x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)

- if self.resi_dual:
- x = x + final_norm(outer_residual)
- else:
- x = final_norm(x)
+ x = final_norm(x)

  if not return_hiddens:
  return x
@@ -2444,7 +2528,7 @@ class PrefixDecoder(AttentionLayers):
  if isinstance(prefix_attn_len, int):
  prefix_attn_len = torch.full((b,), prefix_attn_len, device = device)

- prefix_mask = torch.arange(n, device = device) < rearrange(prefix_attn_len, 'b -> b 1 1 1')
+ prefix_mask = arange(n, device = device) < rearrange(prefix_attn_len, 'b -> b 1 1 1')
  forwarded_mask = forwarded_mask | prefix_mask

  if exists(attn_mask):
@@ -2773,13 +2857,13 @@ class TransformerWrapper(Module):
  prepend_seq, prepend_dim = prepend_embeds.shape[1:]
  assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'

- x = torch.cat((prepend_embeds, x), dim = -2)
+ x = cat((prepend_embeds, x), dim = -2)

  if exists(prepend_mask) or exists(mask):
  mask = default(mask, lambda: torch.ones((b, n), device = device, dtype = torch.bool))
  prepend_mask = default(prepend_mask, lambda: torch.ones((b, prepend_seq), device = device, dtype = torch.bool))

- mask = torch.cat((prepend_mask, mask), dim = -1)
+ mask = cat((prepend_mask, mask), dim = -1)

  # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model

@@ -2945,7 +3029,7 @@ class TransformerWrapper(Module):

  if return_mems:
  hiddens = intermediates.hiddens
- new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
+ new_mems = [cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
  new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]

  if not return_intermediates: