x-transformers 1.42.8__py3-none-any.whl → 1.42.10__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
@@ -48,7 +48,7 @@ def align_right(t, lens, pad_id = 0):
     batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
     prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)
 
-    t = F.pad(t, (max_pad_len, 0), value = 0)
+    t = F.pad(t, (max_pad_len, 0), value = pad_id)
     offset = max_pad_len - pad_lens
 
     aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
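This one-line change only matters when a non-zero pad token is in play: align_right previously left-padded with a hard-coded 0 rather than the caller's pad_id. A minimal sketch of the padding step before and after the fix (values are made up for illustration):

import torch
import torch.nn.functional as F

t = torch.tensor([[5, 6, 7]])   # hypothetical token ids
pad_id = -1                     # a non-zero pad token
max_pad_len = 2

# before: always padded with 0, ignoring pad_id
before = F.pad(t, (max_pad_len, 0), value = 0)       # tensor([[ 0,  0,  5,  6,  7]])
# after: pads with the caller's pad_id
after = F.pad(t, (max_pad_len, 0), value = pad_id)   # tensor([[-1, -1,  5,  6,  7]])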
@@ -452,13 +452,20 @@ class DynamicPositionBias(Module):
         return bias
 
 class AlibiPositionalBias(Module):
-    def __init__(self, heads, total_heads = None, **kwargs):
+    def __init__(
+        self,
+        heads,
+        total_heads = None,
+        slopes: list[int] | None = None,
+        **kwargs
+    ):
         super().__init__()
         self.heads = heads
         self.total_heads = default(total_heads, heads)
 
-        slopes = Tensor(self._get_slopes(heads))
+        slopes = Tensor(default(slopes, self._get_slopes(heads)))
         slopes = rearrange(slopes, 'h -> h 1 1')
+
         self.register_buffer('slopes', slopes, persistent = False)
         self.register_buffer('bias', None, persistent = False)
 
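The constructor now accepts an explicit slopes list, where previously the slopes were always derived from the head count via _get_slopes. A hedged usage sketch, assuming AlibiPositionalBias is imported from x_transformers.x_transformers; the slope values are arbitrary and for illustration only:

from x_transformers.x_transformers import AlibiPositionalBias

# previous (still default) behaviour: slopes computed from the number of heads
alibi_default = AlibiPositionalBias(heads = 8)

# new in 1.42.10: override the per-head slopes explicitly (illustrative values)
alibi_custom = AlibiPositionalBias(heads = 8, slopes = [1, 2, 4, 8, 16, 32, 64, 128])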
@@ -487,7 +494,10 @@ class AlibiPositionalBias(Module):
         h, device = self.total_heads, self.device
 
         pos_j = default(pos_j, pos_i)
-        bias = -einx.subtract('... j, ... i -> ... 1 i j', pos_j, pos_i).abs()
+        bias = -einx.subtract('... j, ... i -> ... i j', pos_j, pos_i).abs()
+
+        if bias.ndim == 3:
+            bias = rearrange(bias, 'b i j -> b 1 i j')
 
         bias = bias * self.slopes
         num_heads_unalibied = h - bias.shape[-3]
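Dropping the hard-coded '1' from the einx pattern and adding the ndim check lets the same code handle both unbatched (1-D) and batched (2-D) position tensors, inserting the head dimension only when it is missing so the bias still broadcasts against slopes of shape (h, 1, 1). A plain-torch sketch of the intended shapes, not the library's implementation:

import torch

# 1-D positions -> pairwise distance bias of shape (i, j)
pos = torch.arange(4)
bias = -(pos[None, :] - pos[:, None]).abs()               # (4, 4)

# batched 2-D positions -> (b, i, j), then add a head dim as the ndim == 3 branch does
pos_b = torch.arange(4).expand(2, 4)
bias_b = -(pos_b[:, None, :] - pos_b[:, :, None]).abs()   # (2, 4, 4)
bias_b = bias_b[:, None]                                   # (2, 1, 4, 4), broadcasts with (h, 1, 1) slopes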
@@ -1531,8 +1541,9 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
-        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
-        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        rel_pos_kwargs: dict = dict(),
         **kwargs
     ):
         super().__init__()
@@ -1573,14 +1584,14 @@ class AttentionLayers(Module):
 
         if rel_pos_bias:
             assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
-            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
+            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance, **rel_pos_kwargs)
         elif dynamic_pos_bias:
             assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
-            self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
+            self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm, **rel_pos_kwargs)
         elif alibi_pos_bias:
             alibi_num_heads = default(alibi_num_heads, heads)
             assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
-            self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
+            self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads, **rel_pos_kwargs)
 
         assert at_most_one_of(sandwich_norm, resi_dual), 'either sandwich norm or resiDual is selected, but not both'
         assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
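The new rel_pos_kwargs dict is forwarded unchanged to whichever relative position module is constructed (RelativePositionBias, DynamicPositionBias, or AlibiPositionalBias), so the custom ALiBi slopes added above can be set from AttentionLayers. A hedged usage sketch; arguments not visible in this diff (such as depth) are assumed from the library's usual interface, and the slope values are illustrative:

from x_transformers.x_transformers import AttentionLayers

layers = AttentionLayers(
    dim = 512,
    depth = 6,            # assumed standard argument, not shown in this diff
    heads = 8,
    causal = True,
    alibi_pos_bias = True,
    rel_pos_kwargs = dict(slopes = [1, 2, 4, 8, 16, 32, 64, 128]),
)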
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.8
+Version: 1.42.10
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -1,16 +1,16 @@
 x_transformers/__init__.py,sha256=l0dom8ZYkRzFvnDdgzDboXqrI1tKav3beVE7TN2nHko,844
 x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
-x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
+x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=275B_yDHePxUvlLcMNgnCUmZ1qZEkwBrpk6IA8n-pnY,93550
+x_transformers/x_transformers.py,sha256=VxdA44EYQhVH1Rp7wreJ83I2e0Ea7VN_bFRE-iDXOI8,93833
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.42.8.dist-info/METADATA,sha256=1d2BVA6iHKpT4UzbYxw16ijAFGJT-u29zTnYtV6Lp3w,689
-x_transformers-1.42.8.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
-x_transformers-1.42.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.42.8.dist-info/RECORD,,
+x_transformers-1.42.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.10.dist-info/METADATA,sha256=UlOzLgz1fhOIaz3eKDhcHgVI81EZySt4Ko3BiKJ7Jok,690
+x_transformers-1.42.10.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.42.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.10.dist-info/RECORD,,