x_transformers-1.42.8-py3-none-any.whl → x_transformers-1.42.10-py3-none-any.whl

@@ -48,7 +48,7 @@ def align_right(t, lens, pad_id = 0):
     batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
     prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
 
-    t = F.pad(t, (max_pad_len, 0), value = 0)
+    t = F.pad(t, (max_pad_len, 0), value = pad_id)
     offset = max_pad_len - pad_lens
 
     aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
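The fix above makes align_right left-pad with the caller-supplied pad_id instead of a hard-coded 0, so prompts that use a non-zero pad token are aligned without injecting token id 0. A standalone sketch of the intended behaviour (not the library code; pad_id, token ids and lengths below are made up):

import torch
import torch.nn.functional as F

pad_id = -1                                    # illustrative pad token id
t = torch.tensor([[5, 6, 7], [8, 9, pad_id]])  # second prompt has only 2 real tokens
lens = torch.tensor([3, 2])                    # true prompt lengths

batch, seq_len = t.shape
pad_lens = seq_len - lens
max_pad_len = int(pad_lens.amax())

# before the fix this used value = 0, silently padding with token id 0
t = F.pad(t, (max_pad_len, 0), value = pad_id)

offset = max_pad_len - pad_lens
batch_arange = torch.arange(batch)[..., None]
prompt_len_arange = torch.arange(seq_len)

aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
print(aligned)  # tensor([[ 5,  6,  7], [-1,  8,  9]])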
@@ -452,13 +452,20 @@ class DynamicPositionBias(Module):
         return bias
 
 class AlibiPositionalBias(Module):
-    def __init__(self, heads, total_heads = None, **kwargs):
+    def __init__(
+        self,
+        heads,
+        total_heads = None,
+        slopes: list[int] | None = None,
+        **kwargs
+    ):
         super().__init__()
         self.heads = heads
         self.total_heads = default(total_heads, heads)
 
-        slopes = Tensor(self._get_slopes(heads))
+        slopes = Tensor(default(slopes, self._get_slopes(heads)))
         slopes = rearrange(slopes, 'h -> h 1 1')
+
         self.register_buffer('slopes', slopes, persistent = False)
         self.register_buffer('bias', None, persistent = False)
 
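With the widened constructor, per-head ALiBi slopes can be supplied explicitly instead of always being derived from the head count via _get_slopes. A usage sketch, assuming the class is imported from x_transformers.x_transformers as in previous releases; the slope values are arbitrary and purely illustrative:

import torch
from x_transformers.x_transformers import AlibiPositionalBias

alibi = AlibiPositionalBias(
    heads = 4,              # heads that receive an ALiBi bias
    total_heads = 8,        # total attention heads in the model
    slopes = [1, 2, 4, 8]   # custom per-head slopes (illustrative values)
)

# the supplied slopes are stored as a (heads, 1, 1) buffer
print(alibi.slopes.flatten())   # tensor([1., 2., 4., 8.])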
@@ -487,7 +494,10 @@ class AlibiPositionalBias(Module):
         h, device = self.total_heads, self.device
 
         pos_j = default(pos_j, pos_i)
-        bias = -einx.subtract('... j, ... i -> ... 1 i j', pos_j, pos_i).abs()
+        bias = -einx.subtract('... j, ... i -> ... i j', pos_j, pos_i).abs()
+
+        if bias.ndim == 3:
+            bias = rearrange(bias, 'b i j -> b 1 i j')
 
         bias = bias * self.slopes
         num_heads_unalibied = h - bias.shape[-3]
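One reading of this change: the einx expression now yields a plain (..., i, j) distance matrix, and a singleton head axis is inserted explicitly only when the positions carry a batch dimension, so the subsequent multiply against the (heads, 1, 1) slopes broadcasts the same way in both cases. A shape walk-through under that reading (standalone, with made-up positions and slopes):

import torch
import einx
from einops import rearrange

slopes = rearrange(torch.tensor([1., 2., 4., 8.]), 'h -> h 1 1')   # (4, 1, 1)

# unbatched positions: (i, j) distances broadcast with slopes to (heads, i, j)
pos = torch.arange(6)
bias = -einx.subtract('... j, ... i -> ... i j', pos, pos).abs()
print((bias * slopes).shape)      # torch.Size([4, 6, 6])

# batched custom positions: (b, i, j) distances need an explicit head axis
pos_b = torch.arange(6).repeat(2, 1)
bias_b = -einx.subtract('... j, ... i -> ... i j', pos_b, pos_b).abs()
if bias_b.ndim == 3:
    bias_b = rearrange(bias_b, 'b i j -> b 1 i j')
print((bias_b * slopes).shape)    # torch.Size([2, 4, 6, 6])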
@@ -1531,8 +1541,9 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
-        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
-        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        reinject_input = False,        # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        add_value_residual = False,    # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        rel_pos_kwargs: dict = dict(),
         **kwargs
     ):
         super().__init__()
@@ -1573,14 +1584,14 @@ class AttentionLayers(Module):
 
         if rel_pos_bias:
             assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
-            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
+            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance, **rel_pos_kwargs)
         elif dynamic_pos_bias:
             assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
-            self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
+            self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm, **rel_pos_kwargs)
         elif alibi_pos_bias:
             alibi_num_heads = default(alibi_num_heads, heads)
             assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
-            self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
+            self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads, **rel_pos_kwargs)
 
         assert at_most_one_of(sandwich_norm, resi_dual), 'either sandwich norm or resiDual is selected, but not both'
         assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
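Together with the new slopes argument on AlibiPositionalBias, rel_pos_kwargs lets the chosen relative-position module be configured from the AttentionLayers constructor, for example to override the default ALiBi slopes. A usage sketch with arbitrary dimensions and slope values, assuming Decoder forwards its keyword arguments to AttentionLayers as in earlier releases:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        alibi_pos_bias = True,                       # use AlibiPositionalBias
        rel_pos_kwargs = dict(
            slopes = [1, 2, 4, 8, 16, 32, 64, 128]   # illustrative custom per-head slopes
        )
    )
)

x = torch.randint(0, 256, (1, 128))
logits = model(x)   # (1, 128, 256)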
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.8
+Version: 1.42.10
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -1,16 +1,16 @@
 x_transformers/__init__.py,sha256=l0dom8ZYkRzFvnDdgzDboXqrI1tKav3beVE7TN2nHko,844
 x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
-x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
+x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=275B_yDHePxUvlLcMNgnCUmZ1qZEkwBrpk6IA8n-pnY,93550
+x_transformers/x_transformers.py,sha256=VxdA44EYQhVH1Rp7wreJ83I2e0Ea7VN_bFRE-iDXOI8,93833
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.42.8.dist-info/METADATA,sha256=1d2BVA6iHKpT4UzbYxw16ijAFGJT-u29zTnYtV6Lp3w,689
-x_transformers-1.42.8.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-x_transformers-1.42.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.42.8.dist-info/RECORD,,
+x_transformers-1.42.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.10.dist-info/METADATA,sha256=UlOzLgz1fhOIaz3eKDhcHgVI81EZySt4Ko3BiKJ7Jok,690
+x_transformers-1.42.10.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.42.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.10.dist-info/RECORD,,