x-transformers 1.44.8__py3-none-any.whl → 2.0.0__py3-none-any.whl

x_transformers/x_transformers.py

@@ -9,7 +9,7 @@ from packaging import version
  import torch
  from torch.amp import autocast
  import torch.nn.functional as F
- from torch import nn, einsum, Tensor
+ from torch import nn, einsum, Tensor, cat, stack, arange
  from torch.utils._pytree import tree_flatten, tree_unflatten
  from torch.nn import Module, ModuleList, ModuleDict

@@ -18,14 +18,22 @@ from collections import namedtuple
  from contextlib import nullcontext
  from dataclasses import dataclass

+ from loguru import logger
+
+ from x_transformers.attend import Attend, Intermediates
+ from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+
  import einx
  from einops.layers.torch import Rearrange
  from einops import rearrange, repeat, reduce, pack, unpack

- from loguru import logger
+ # einstein notation

- from x_transformers.attend import Attend, Intermediates
- from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
+ # b - batch
+ # n - sequence
+ # d - feature dimension
+ # h - attention heads
+ # i, j - sequence (source, target)

  # constants

@@ -220,7 +228,7 @@ def dropout_seq(seq, mask, dropout):
  num_keep = max(1, int(keep_prob * n))
  keep_indices = logits.topk(num_keep, dim = 1).indices

- batch_indices = torch.arange(b, device = device)
+ batch_indices = arange(b, device = device)
  batch_indices = rearrange(batch_indices, 'b -> b 1')

  seq = seq[batch_indices, keep_indices]
@@ -228,7 +236,7 @@ def dropout_seq(seq, mask, dropout):
  if exists(mask):
  seq_counts = mask.sum(dim = -1)
  seq_keep_counts = torch.ceil(seq_counts * keep_prob).int()
- keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')
+ keep_mask = arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1')

  mask = mask[batch_indices, keep_indices] & keep_mask

@@ -274,7 +282,7 @@ class AbsolutePositionalEmbedding(Module):
  assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

  if not exists(pos):
- pos = torch.arange(seq_len, device = device)
+ pos = arange(seq_len, device = device)

  if exists(seq_start_pos):
  pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
@@ -290,7 +298,7 @@ class ScaledSinusoidalEmbedding(Module):
  self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

  half_dim = dim // 2
- freq_seq = torch.arange(half_dim).float() / half_dim
+ freq_seq = arange(half_dim).float() / half_dim
  inv_freq = theta ** -freq_seq
  self.register_buffer('inv_freq', inv_freq, persistent = False)

@@ -298,13 +306,13 @@ class ScaledSinusoidalEmbedding(Module):
  seq_len, device = x.shape[1], x.device

  if not exists(pos):
- pos = torch.arange(seq_len, device = device)
+ pos = arange(seq_len, device = device)

  if exists(seq_start_pos):
  pos = pos - seq_start_pos[..., None]

  emb = einsum('i, j -> i j', pos, self.inv_freq)
- emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
+ emb = cat((emb.sin(), emb.cos()), dim = -1)
  return emb * self.scale

  class RelativePositionBias(Module):
@@ -344,8 +352,8 @@ class RelativePositionBias(Module):

  def forward(self, i, j):
  device = self.device
- q_pos = torch.arange(j - i, j, dtype = torch.long, device = device)
- k_pos = torch.arange(j, dtype = torch.long, device = device)
+ q_pos = arange(j - i, j, dtype = torch.long, device = device)
+ k_pos = arange(j, dtype = torch.long, device = device)
  rel_pos = einx.subtract('j, i -> i j', k_pos, q_pos)
  rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance)
  values = self.relative_attention_bias(rp_bucket)
@@ -376,7 +384,7 @@ class CoPE(Module):
  if not soft_onehot:
  return

- self.register_buffer('positions', torch.arange(max_pos))
+ self.register_buffer('positions', arange(max_pos))

  def forward(self, query, attn_logits):

@@ -445,13 +453,13 @@ class DynamicPositionBias(Module):
  n, device = j, self.device

  # get the (n x n) matrix of distances
- seq_arange = torch.arange(n, device = device)
- context_arange = torch.arange(n, device = device)
+ seq_arange = arange(n, device = device)
+ context_arange = arange(n, device = device)
  indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
  indices += (n - 1)

  # input to continuous positions MLP
- pos = torch.arange(-n + 1, n, device = device).float()
+ pos = arange(-n + 1, n, device = device).float()
  pos = rearrange(pos, '... -> ... 1')

  if self.log_distance:
@@ -525,8 +533,8 @@ class AlibiPositionalBias(Module):
  if exists(self.bias) and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i:
  return self.bias[..., -i:, -j:]

- seq_arange = torch.arange(j - i, j, device = device)
- context_arange = torch.arange(j, device = device)
+ seq_arange = arange(j - i, j, device = device)
+ context_arange = arange(j, device = device)
  bias = -einx.subtract('j, i -> 1 i j', context_arange, seq_arange).abs()

  bias = bias * self.slopes
@@ -642,7 +650,7 @@ class RotaryEmbedding(Module):
  # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
  base *= base_rescale_factor ** (dim / (dim - 2))

- inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+ inv_freq = 1. / (base ** (arange(0, dim, 2).float() / dim))
  self.register_buffer('inv_freq', inv_freq)

  assert interpolation_factor >= 1.
@@ -652,7 +660,7 @@ class RotaryEmbedding(Module):
  self.register_buffer('scale', None)
  return

- scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
+ scale = (arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

  self.scale_base = scale_base
  self.register_buffer('scale', scale)
@@ -660,7 +668,7 @@ class RotaryEmbedding(Module):
  def forward_from_seq_len(self, seq_len):
  device = self.inv_freq.device

- t = torch.arange(seq_len, device = device)
+ t = arange(seq_len, device = device)
  return self.forward(t)

  @autocast('cuda', enabled = False)
@@ -671,7 +679,7 @@ class RotaryEmbedding(Module):
  t = rearrange(t, 'n -> 1 n')

  freqs = torch.einsum('b i , j -> b i j', t.type_as(self.inv_freq), self.inv_freq) / self.interpolation_factor
- freqs = torch.stack((freqs, freqs), dim = -1)
+ freqs = stack((freqs, freqs), dim = -1)
  freqs = rearrange(freqs, '... d r -> ... (d r)')

  if not exists(self.scale):
@@ -679,7 +687,7 @@ class RotaryEmbedding(Module):

  power = (t - (max_pos // 2)) / self.scale_base
  scale = self.scale ** rearrange(power, '... n -> ... n 1')
- scale = torch.stack((scale, scale), dim = -1)
+ scale = stack((scale, scale), dim = -1)
  scale = rearrange(scale, '... d r -> ... (d r)')

  return freqs, scale
@@ -687,7 +695,7 @@ class RotaryEmbedding(Module):
  def rotate_half(x):
  x = rearrange(x, '... (d r) -> ... d r', r = 2)
  x1, x2 = x.unbind(dim = -1)
- x = torch.stack((-x2, x1), dim = -1)
+ x = stack((-x2, x1), dim = -1)
  return rearrange(x, '... d r -> ... (d r)')

  @autocast('cuda', enabled = False)
@@ -703,7 +711,7 @@ def apply_rotary_pos_emb(t, freqs, scale = 1):
  # partial rotary embeddings, Wang et al. GPT-J
  t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
  t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
- out = torch.cat((t, t_unrotated), dim = -1)
+ out = cat((t, t_unrotated), dim = -1)

  return out.type(orig_dtype)

@@ -904,7 +912,7 @@ class HyperConnection(Module):
  init_alpha0 = torch.zeros((num_residual_streams, num_input_views))
  init_alpha0[layer_index % num_residual_streams, :] = 1.

- self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
+ self.static_alpha = nn.Parameter(cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))

  self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + num_input_views))
  self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
@@ -973,7 +981,7 @@ class ShiftTokens(Module):
  splitted = x.split(feats_per_shift, dim = -1)
  segments_to_shift, rest = splitted[:segments], splitted[segments:]
  segments_to_shift = [shift(*args, mask = mask) for args in zip(segments_to_shift, shifts)]
- x = torch.cat((*segments_to_shift, *rest), dim = -1)
+ x = cat((*segments_to_shift, *rest), dim = -1)
  return self.fn(x, **kwargs)

  class FoldAxially(Module):
@@ -1080,7 +1088,7 @@ class ConcatCombine(Module):

  def forward(self, x, prev_layers: list[Tensor]):
  skip = prev_layers[self.prev_layer_ind]
- concatted_skip = torch.cat((skip, x), dim = -1)
+ concatted_skip = cat((skip, x), dim = -1)
  return self.combine(concatted_skip)

  # feedforward
@@ -1189,12 +1197,10 @@ class Attention(Module):
  hybrid_fold_axial_dim: int | None = None,
  one_kv_head = False,
  kv_heads = None,
- shared_kv = False,
  value_dim_head = None,
  dim_out = None,
- tensor_product = False, # https://arxiv.org/abs/2208.06061
  add_zero_kv = False, # same as add_zero_attn in pytorch
- rotary_embed_values = False,
+ rotate_num_heads = None,
  data_dependent_alibi = False,
  data_dependent_alibi_per_row = False,
  data_dependent_alibi_per_row_dim_head = 8,
@@ -1205,12 +1211,15 @@ class Attention(Module):
  cope_talking_heads = False,
  softclamp_logits = False,
  logit_softclamp_value = 50.,
- neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
- neutreno_alpha = 0.4,
  learned_value_residual_mix = False,
- laser = False, # https://arxiv.org/abs/2411.03493v1
+ laser = False, # https://arxiv.org/abs/2411.03493v1
  laser_softclamp_value = 15.,
  qkv_receive_diff_residuals = False,
+ use_latent_q = False,
+ dim_latent_q = None,
+ use_latent_kv = False,
+ dim_latent_kv = None,
+ latent_rope_subheads = None,
  onnxable = False,
  attend_sdp_kwargs: dict = dict(
  enable_flash = True,
@@ -1242,13 +1251,51 @@ class Attention(Module):
  v_dim = value_dim_head * kv_heads
  out_dim = value_dim_head * heads

- self.to_q = LinearNoBias(dim, q_dim)
- self.to_k = LinearNoBias(dim_kv, k_dim)
+ # determine input dimensions to qkv based on whether intermediate latent q and kv are being used
+ # for eventually supporting multi-latent attention (MLA)
+
+ self.to_latent_q = None
+ self.to_latent_kv = None
+ self.to_rotateable_k = None # for their "decoupled rope", subheads of keys that comes directly from base sequence (does not go through latents)
+
+ dim_q_input = dim
+ dim_kv_input = dim_kv
+
+ if use_latent_q:
+ assert exists(dim_latent_q)
+ self.to_latent_q = LinearNoBias(dim, dim_latent_q)
+ dim_q_input = dim_latent_q
+
+ if use_latent_kv:
+ assert exists(dim_latent_kv)
+ self.to_latent_kv = LinearNoBias(dim, dim_latent_kv)
+ dim_kv_input = dim_latent_kv
+
+ if exists(latent_rope_subheads):
+ assert not exists(rotate_num_heads)
+ rotate_num_heads = latent_rope_subheads
+
+ k_dim = dim_head * (kv_heads - latent_rope_subheads)

- # shared key / values, for further memory savings during inference
+ self.to_rotateable_k = LinearNoBias(dim, dim_head * latent_rope_subheads)
+ self.split_rotateable_k_heads = Rearrange('b n (h d) -> b h n d', h = latent_rope_subheads)

- assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
- self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None
+ self.use_latent_q = use_latent_q
+ self.use_latent_kv = use_latent_kv
+
+ # query key projection
+
+ self.to_q = LinearNoBias(dim_q_input, q_dim)
+ self.to_k = LinearNoBias(dim_kv_input, k_dim)
+ self.to_v = LinearNoBias(dim_kv_input, v_dim)
+
+ # split and merge of attention heads
+
+ self.split_q_heads = Rearrange('b n (h d) -> b h n d', h = heads)
+ self.split_k_heads = Rearrange('b n (h d) -> b h n d', d = dim_head)
+ self.split_v_heads = Rearrange('b n (h d) -> b h n d', d = value_dim_head)
+
+ self.merge_heads = Rearrange('b h n d -> b n (h d)')

  # whether qkv receives different residual stream combinations from hyper connections
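The block above wires in the latent projections for multi-latent attention. Below is a minimal usage sketch, assuming the new flags are forwarded to Attention through the usual attn_-prefixed kwargs of the attention layer classes; the dimensions and token counts are illustrative only, not taken from the diff.

    import torch
    from x_transformers import TransformerWrapper, Decoder

    # hypothetical configuration exercising use_latent_q / use_latent_kv / latent_rope_subheads
    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            rotary_pos_emb = True,          # so the decoupled-rope path is actually exercised
            attn_use_latent_q = True,
            attn_dim_latent_q = 256,        # queries projected through a 256-dim latent
            attn_use_latent_kv = True,
            attn_dim_latent_kv = 256,       # keys / values projected through a shared 256-dim latent
            attn_latent_rope_subheads = 2   # 2 of the 8 key heads come straight from the base sequence and receive rope
        )
    )

    tokens = torch.randint(0, 256, (1, 1024))
    logits = model(tokens)                  # (1, 1024, 256)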

@@ -1259,15 +1306,6 @@ class Attention(Module):
  self.laser = laser
  self.laser_softclamp_value = laser_softclamp_value

- # relations projection from tp-attention
-
- self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None
-
- # the value residual used by Nguyen et al. in https://arxiv.org/abs/2312.00751 for countering oversmoothing
-
- self.neutreno_value_residual = neutreno_value_residual
- self.neutreno_alpha = neutreno_alpha
-
  # add GLU gating for aggregated values, from alphafold2

  self.to_v_gate = None
@@ -1406,9 +1444,15 @@ class Attention(Module):
  dim_out = default(dim_out, dim)
  self.to_out = nn.Sequential(LinearNoBias(out_dim, dim_out * 2), nn.GLU()) if on_attn else LinearNoBias(out_dim, dim_out)

- # whether to rotate positions into values, for absolute positions in addition to relative
+ # the number of attention heads to rotate, for decoupled rope in multi-latent attention
+
+ rotate_num_heads = default(rotate_num_heads, heads)
+
+ assert 0 < rotate_num_heads <= heads
+ is_partial_rotate_heads = rotate_num_heads < heads
+ assert not (is_partial_rotate_heads and kv_heads < heads), 'grouped query attention not compatible with partial rotate heads (decoupled rope for multi-latent attention), yet'

- self.rotary_embed_values = rotary_embed_values
+ self.rotate_num_heads = rotate_num_heads

  # whether parent can kv cache

@@ -1438,47 +1482,79 @@ class Attention(Module):
  cache: Intermediates | None = None,
  value_residual = None
  ):
- b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals
+ b, n, h, kv_h, head_scale, num_mem_kv, device, has_context, qkv_receive_diff_residuals, is_multi_latent_attn = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context), self.qkv_receive_diff_residuals, self.use_latent_kv
+
+ # an interesting possibility with hyper connections
+ # having queries, keys, values be routed from different layers

  assert not (qkv_receive_diff_residuals and has_context), 'qkv receiving different sequences can only be used for self attention'

  if qkv_receive_diff_residuals:
- assert not exists(self.to_r)
+ assert x.ndim == 4 and x.shape[0] == 3

  q_input, k_input, v_input = x
  else:
  kv_input = default(context, x)
-
- q_input = x
- k_input = kv_input
- v_input = kv_input
- r_input = x
+ q_input, k_input, v_input = x, kv_input, kv_input

  if exists(mem):
  k_input, mem_packed_shape = pack([mem, k_input], 'b * d')
  v_input, _ = pack([mem, v_input], 'b * d')

+ # multi-latent attention logic
+ # https://arxiv.org/abs/2405.04434 - Deepseek-AI team
+
+ k_sub_heads = None # the rotateable subheads of keys derived from base sequence
+
+ if self.use_latent_q:
+ q_input = self.to_latent_q(q_input)
+
+ if is_multi_latent_attn:
+ assert not qkv_receive_diff_residuals
+ needs_k_sub_heads = exists(self.to_rotateable_k)
+
+ latent_kv_input = self.to_latent_kv(k_input)
+
+ if needs_k_sub_heads:
+ rotateable_k = self.to_rotateable_k(k_input)
+ k_sub_heads = self.split_rotateable_k_heads(rotateable_k)
+
+ if exists(cache):
+ cached_latent_kv, maybe_cached_k_sub_heads = cache.cached_kv
+ latent_kv_input = cat((cached_latent_kv, latent_kv_input), dim = -2)
+
+ if exists(maybe_cached_k_sub_heads):
+ k_sub_heads = cat((maybe_cached_k_sub_heads, k_sub_heads), dim = -2)
+
+ if return_intermediates:
+ cached_kv = (latent_kv_input, k_sub_heads)
+
+ k_input = v_input = latent_kv_input
+
+ # query, key, value projection
+
  q = self.to_q(q_input)
  k = self.to_k(k_input)
- v = self.to_v(v_input) if exists(self.to_v) else k
- r = self.to_r(r_input) if exists(self.to_r) else None
+ v = self.to_v(v_input)
+
+ q = self.split_q_heads(q)
+ k = self.split_k_heads(k)
+ v = self.split_v_heads(v)

- q = rearrange(q, 'b n (h d) -> b h n d', h = h)
+ # take care of decoupled rope from multi-latent attention

- k, v, r = tuple(maybe(rearrange)(t, 'b n (h d) -> b h n d', h = kv_h) for t in (k, v, r))
+ if exists(k_sub_heads):
+ k = cat((k, k_sub_heads), dim = 1)

- # if previous values passed in for residual, either invoke resformer or neutreno
+ # if previous values passed in for residual, either invoke resformer

  orig_values = v

+ # https://arxiv.org/abs/2410.17897v1
+
  if exists(value_residual):
- if self.neutreno_value_residual:
- diff_values = (value_residual - v) * self.neutreno_alpha
- diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
- else:
- # https://arxiv.org/abs/2410.17897v1
- value_residual_mix = self.to_value_residual_mix(q_input)
- v = v * value_residual_mix + value_residual * (1. - value_residual_mix)
+ value_residual_mix = self.to_value_residual_mix(q_input)
+ v = value_residual.lerp(v, value_residual_mix)

  # qk normalization
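For the value-residual mix above, the removed expression and the new Tensor.lerp call compute the same thing, since lerp(start, end, w) = start + w * (end - start). A quick sketch of the equivalence (shapes are illustrative, not taken from the diff):

    import torch

    v = torch.randn(2, 8, 16, 64)               # (batch, heads, seq, dim_head)
    value_residual = torch.randn(2, 8, 16, 64)  # values carried over from the first layer
    mix = torch.rand(2, 8, 16, 1)               # stand-in for to_value_residual_mix(q_input)

    old = v * mix + value_residual * (1. - mix)
    new = value_residual.lerp(v, mix)

    assert torch.allclose(old, new, atol = 1e-6)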

@@ -1492,28 +1568,36 @@ class Attention(Module):

  # take care of caching

- if exists(cache):
- ck, cv = cache.cached_kv
+ if not is_multi_latent_attn:
+ if exists(cache):
+ ck, cv = cache.cached_kv

- if exists(mem):
- mk, k = unpack(k, mem_packed_shape, 'b h * d')
- mv, v = unpack(v, mem_packed_shape, 'b h * d')
+ if exists(mem):
+ mk, k = unpack(k, mem_packed_shape, 'b h * d')
+ mv, v = unpack(v, mem_packed_shape, 'b h * d')

- k = torch.cat((ck, k), dim = -2)
- v = torch.cat((cv, v), dim = -2)
+ k = cat((ck, k), dim = -2)
+ v = cat((cv, v), dim = -2)

- if exists(mem):
- k = torch.cat((mk, k), dim = -2)
- v = torch.cat((mv, v), dim = -2)
+ if exists(mem):
+ k = cat((mk, k), dim = -2)
+ v = cat((mv, v), dim = -2)

- if return_intermediates:
- mem_len = mem.shape[-2] if exists(mem) else 0
- cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])
+ if return_intermediates:
+ mem_len = mem.shape[-2] if exists(mem) else 0
+ cached_kv = (k[..., mem_len:, :], v[..., mem_len:, :])

  if exists(rotary_pos_emb):
+ rotate_num_heads = self.rotate_num_heads
+ partial_rotate_heads = rotate_num_heads < h
+
  freqs, xpos_scale = rotary_pos_emb
  q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.)

+ if partial_rotate_heads:
+ q_rest, q = q[:, :-rotate_num_heads], q[:, -rotate_num_heads:]
+ k_rest, k = k[:, :-rotate_num_heads], k[:, -rotate_num_heads:]
+
  q = apply_rotary_pos_emb(q, freqs, q_xpos_scale)

  if has_context:
@@ -1524,8 +1608,9 @@ class Attention(Module):

  k = apply_rotary_pos_emb(k, freqs, k_xpos_scale)

- if self.rotary_embed_values:
- v = apply_rotary_pos_emb(v, freqs, k_xpos_scale)
+ if partial_rotate_heads:
+ q = cat((q_rest, q), dim = 1)
+ k = cat((k_rest, k), dim = 1)

  input_mask = context_mask
@@ -1540,7 +1625,7 @@ class Attention(Module):
  elif not exists(input_mask):
  input_mask = pad_at_dim(mem_mask, (0, seq_len), dim = -1, value = True)
  else:
- input_mask = torch.cat((mem_mask, input_mask), dim = -1)
+ input_mask = cat((mem_mask, input_mask), dim = -1)

  # i, j determined for relative positional bias, excluding memory key / values

@@ -1555,8 +1640,8 @@ class Attention(Module):
  mem_k = l2norm(mem_k)
  mem_k = mem_k * self.qk_norm_k_scale

- k = torch.cat((mem_k, k), dim = -2)
- v = torch.cat((mem_v, v), dim = -2)
+ k = cat((mem_k, k), dim = -2)
+ v = cat((mem_v, v), dim = -2)

  if exists(input_mask):
  input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)
@@ -1580,8 +1665,8 @@ class Attention(Module):
  masks.append(~attn_mask)

  if exists(self.max_attend_past):
- range_q = torch.arange(j - i, j, device = device)
- range_k = torch.arange(j, device = device)
+ range_q = arange(j - i, j, device = device)
+ range_k = arange(j, device = device)
  dist = einx.subtract('i, j -> 1 1 i j', range_q, range_k)
  max_attend_past_mask = dist > self.max_attend_past
  max_attend_past_mask = pad_at_dim(max_attend_past_mask, (num_mem_kv, 0), value = False, dim = -1) # handle memory key / values
@@ -1629,18 +1714,10 @@ class Attention(Module):
  if self.laser:
  out = log(out)

- # store the values for resformer or Neutreno
+ # store the values for resformer

  intermediates.values = orig_values

- if exists(value_residual) and self.neutreno_value_residual:
- out = out + diff_values
-
- # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
-
- if exists(r):
- out = out * r + out
-
  # normformer scaling of heads

  if head_scale:
@@ -1654,7 +1731,7 @@ class Attention(Module):

  # merge heads

- out = rearrange(out, 'b h n d -> b n (h d)')
+ out = self.merge_heads(out)

  # hybrid module

@@ -1747,8 +1824,6 @@ class AttentionLayers(Module):
  sandwich_norm = False,
  softclamp_output = False,
  softclamp_output_value = 30.,
- resi_dual = False,
- resi_dual_scale = 1.,
  zero_init_branch_output = False,
  layer_dropout = 0.,
  cross_attn_tokens_dropout = 0.,
@@ -1775,12 +1850,9 @@ class AttentionLayers(Module):

  dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
  data_dependent_alibi = attn_kwargs.get('data_dependent_alibi', False)
- neutreno_value_residual = attn_kwargs.get('neutreno_value_residual', False)

  assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'

- add_value_residual |= neutreno_value_residual
-
  self.dim = dim
  self.causal = causal
  self.layers = ModuleList([])
@@ -1831,19 +1903,11 @@ class AttentionLayers(Module):
  assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
  self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads, **rel_pos_kwargs)

- assert at_most_one_of(sandwich_norm, resi_dual), 'either sandwich norm or resiDual is selected, but not both'
  assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'

- if resi_dual:
- pre_norm = False
-
  self.pre_norm = pre_norm
  self.sandwich_norm = sandwich_norm

- self.resi_dual = resi_dual
- assert 0 < resi_dual_scale <= 1., 'resiDual prenorm residual must be scaled by a factor greater than 0 and less than or equal to 1.'
- self.resi_dual_scale = resi_dual_scale
-
  self.residual_attn = residual_attn
  self.cross_residual_attn = cross_residual_attn
  assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention'
@@ -2002,7 +2066,7 @@ class AttentionLayers(Module):

  # whether it has post norm

- self.final_norm = norm_fn() if pre_norm or resi_dual else nn.Identity()
+ self.final_norm = norm_fn() if pre_norm else nn.Identity()

  # whether unet or not

@@ -2175,7 +2239,7 @@ class AttentionLayers(Module):
  # handle left padded sequences

  if exists(seq_start_pos):
- seq_arange = torch.arange(x.shape[-2], device = x.device, dtype = torch.long)
+ seq_arange = arange(x.shape[-2], device = x.device, dtype = torch.long)
  left_pad_mask = seq_arange >= seq_start_pos[..., None]

  if exists(self_attn_kv_mask):
@@ -2193,7 +2257,7 @@ class AttentionLayers(Module):
  mem_len = maybe_mem.shape[1] if exists(maybe_mem) else 0

  if not exists(pos):
- pos = torch.arange(x.shape[1] + mem_len, device = x.device) - mem_len
+ pos = arange(x.shape[1] + mem_len, device = x.device) - mem_len

  rotary_pos_emb = self.rotary_pos_emb(pos)

@@ -2213,7 +2277,7 @@ class AttentionLayers(Module):
  attn_cache = []

  if exists(cache):
- assert not self.training and self.causal and not any([*map(exists, (mask, attn_mask))])
+ assert self.causal and not any([*map(exists, (mask, attn_mask))])

  if exists(context):
  context = context[:, :0]
@@ -2231,13 +2295,7 @@ class AttentionLayers(Module):
  is_multistream = streams > 1

  if is_multistream:
- x = repeat(x, 'b n d -> b n s d', s = streams)
- x = x + self.stream_emb
- x = rearrange(x, 'b n s d -> (b s) n d')
-
- # outer residual - for resiDual paper
-
- outer_residual = x * self.resi_dual_scale
+ x = einx.add('b n d, s d -> (b s) n d', x, self.stream_emb)

  # get layers to be executed
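The einx.add one-liner above folds the removed repeat / add / rearrange sequence for the multistream (hyper-connection) embedding into a single expression. A small sketch checking the equivalence, assuming the same axis layout as the diff (sizes are illustrative):

    import torch
    import einx
    from einops import rearrange, repeat

    b, n, d, streams = 2, 16, 64, 4
    x = torch.randn(b, n, d)
    stream_emb = torch.randn(streams, d)

    old = repeat(x, 'b n d -> b n s d', s = streams) + stream_emb   # broadcast add of (s, d)
    old = rearrange(old, 'b n s d -> (b s) n d')

    new = einx.add('b n d, s d -> (b s) n d', x, stream_emb)

    assert torch.allclose(old, new, atol = 1e-6)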

@@ -2359,9 +2417,6 @@ class AttentionLayers(Module):
  if not exists(first_cross_attn_inter) and layer_type == 'c':
  first_cross_attn_inter = inter

- if self.resi_dual:
- outer_residual = outer_residual + out * self.resi_dual_scale
-
  if exists(post_branch_norm):
  out = post_branch_norm(out)

@@ -2395,10 +2450,7 @@ class AttentionLayers(Module):
  if is_multistream:
  x = reduce(x, '(b s) n d -> b n d', 'sum', s = streams)

- if self.resi_dual:
- x = x + final_norm(outer_residual)
- else:
- x = final_norm(x)
+ x = final_norm(x)

  if not return_hiddens:
  return x
@@ -2444,7 +2496,7 @@ class PrefixDecoder(AttentionLayers):
  if isinstance(prefix_attn_len, int):
  prefix_attn_len = torch.full((b,), prefix_attn_len, device = device)

- prefix_mask = torch.arange(n, device = device) < rearrange(prefix_attn_len, 'b -> b 1 1 1')
+ prefix_mask = arange(n, device = device) < rearrange(prefix_attn_len, 'b -> b 1 1 1')
  forwarded_mask = forwarded_mask | prefix_mask

  if exists(attn_mask):
@@ -2773,13 +2825,13 @@ class TransformerWrapper(Module):
  prepend_seq, prepend_dim = prepend_embeds.shape[1:]
  assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions'

- x = torch.cat((prepend_embeds, x), dim = -2)
+ x = cat((prepend_embeds, x), dim = -2)

  if exists(prepend_mask) or exists(mask):
  mask = default(mask, lambda: torch.ones((b, n), device = device, dtype = torch.bool))
  prepend_mask = default(prepend_mask, lambda: torch.ones((b, prepend_seq), device = device, dtype = torch.bool))

- mask = torch.cat((prepend_mask, mask), dim = -1)
+ mask = cat((prepend_mask, mask), dim = -1)

  # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model

@@ -2945,7 +2997,7 @@ class TransformerWrapper(Module):

  if return_mems:
  hiddens = intermediates.hiddens
- new_mems = [torch.cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
+ new_mems = [cat(pair, dim = -2) for pair in zip(mems, hiddens)] if exists(mems) else hiddens
  new_mems = [t[..., -self.max_mem_len:, :].detach() for t in new_mems]

  if not return_intermediates:

x_transformers-1.44.8.dist-info/METADATA → x_transformers-2.0.0.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.2
  Name: x-transformers
- Version: 1.44.8
+ Version: 2.0.0
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang

x_transformers-1.44.8.dist-info/RECORD → x_transformers-2.0.0.dist-info/RECORD

@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
- x_transformers/x_transformers.py,sha256=JKeK_639w-6KRojiLLmjRiPPUjitC1TIqAqeJodB0qo,104726
+ x_transformers/x_transformers.py,sha256=6pDFK-WzsW1ay75AcaHsFYOMpNX0CjDiq7Y0-xSpM0s,106174
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
- x_transformers-1.44.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-1.44.8.dist-info/METADATA,sha256=x61kzlTQZsj8BDp2wWvS31h2AI7s6jhnyTjxHZkg62I,924
- x_transformers-1.44.8.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
- x_transformers-1.44.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
- x_transformers-1.44.8.dist-info/RECORD,,
+ x_transformers-2.0.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-2.0.0.dist-info/METADATA,sha256=FkS_HmL6RPLni_t-Gg9t_8BA703rNv2CJppxule1W8A,923
+ x_transformers-2.0.0.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+ x_transformers-2.0.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+ x_transformers-2.0.0.dist-info/RECORD,,