x-transformers 1.42.8__py3-none-any.whl → 1.42.10__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- x_transformers/autoregressive_wrapper.py +1 -1
- x_transformers/x_transformers.py +19 -8
- {x_transformers-1.42.8.dist-info → x_transformers-1.42.10.dist-info}/METADATA +1 -1
- {x_transformers-1.42.8.dist-info → x_transformers-1.42.10.dist-info}/RECORD +7 -7
- {x_transformers-1.42.8.dist-info → x_transformers-1.42.10.dist-info}/LICENSE +0 -0
- {x_transformers-1.42.8.dist-info → x_transformers-1.42.10.dist-info}/WHEEL +0 -0
- {x_transformers-1.42.8.dist-info → x_transformers-1.42.10.dist-info}/top_level.txt +0 -0
x_transformers/autoregressive_wrapper.py CHANGED

@@ -48,7 +48,7 @@ def align_right(t, lens, pad_id = 0):
     batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
     prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)

-    t = F.pad(t, (max_pad_len, 0), value =
+    t = F.pad(t, (max_pad_len, 0), value = pad_id)
     offset = max_pad_len - pad_lens

     aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
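For orientation, the hunk above changes `align_right` in the autoregressive wrapper so the left padding is filled with the caller's `pad_id` rather than a fixed value. Below is a minimal sketch of the patched helper: only the lines shown in the hunk are taken from the source, while the shape unpacking and `pad_lens` setup are assumptions about the surrounding function body.

# Hedged sketch of align_right after this change; code outside the diffed lines is assumed
import torch
import torch.nn.functional as F

def align_right(t, lens, pad_id = 0):
    # t: (batch, seq_len) token ids, lens: (batch,) prompt lengths
    batch, seq_len, device = *t.shape, t.device

    pad_lens = seq_len - lens
    max_pad_len = int(pad_lens.amax())

    batch_arange = torch.arange(batch, device = device, dtype = torch.long)[..., None]
    prompt_len_arange = torch.arange(seq_len, device = device, dtype = torch.long)

    # the fix: left-pad with pad_id instead of the previous fill value
    t = F.pad(t, (max_pad_len, 0), value = pad_id)
    offset = max_pad_len - pad_lens

    aligned = t[batch_arange, prompt_len_arange + offset[..., None]]
    return aligned

# example: align_right(torch.tensor([[5, 6, 0, 0], [7, 8, 9, 0]]), torch.tensor([2, 3]), pad_id = 0)
# returns tensor([[0, 0, 5, 6], [0, 7, 8, 9]]), i.e. prompts right-aligned and left-filled with pad_id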
x_transformers/x_transformers.py CHANGED
@@ -452,13 +452,20 @@ class DynamicPositionBias(Module):
         return bias

 class AlibiPositionalBias(Module):
-    def __init__(
+    def __init__(
+        self,
+        heads,
+        total_heads = None,
+        slopes: list[int] | None = None,
+        **kwargs
+    ):
         super().__init__()
         self.heads = heads
         self.total_heads = default(total_heads, heads)

-        slopes = Tensor(self._get_slopes(heads))
+        slopes = Tensor(default(slopes, self._get_slopes(heads)))
         slopes = rearrange(slopes, 'h -> h 1 1')
+
         self.register_buffer('slopes', slopes, persistent = False)
         self.register_buffer('bias', None, persistent = False)


@@ -487,7 +494,10 @@ class AlibiPositionalBias(Module):
         h, device = self.total_heads, self.device

         pos_j = default(pos_j, pos_i)
-        bias = -einx.subtract('... j, ... i -> ...
+        bias = -einx.subtract('... j, ... i -> ... i j', pos_j, pos_i).abs()
+
+        if bias.ndim == 3:
+            bias = rearrange(bias, 'b i j -> b 1 i j')

         bias = bias * self.slopes
         num_heads_unalibied = h - bias.shape[-3]
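Taken together, the two hunks above let callers override the ALiBi slopes instead of always deriving them from the head count, and let the bias computation handle batched custom positions (the `ndim == 3` branch inserts the singleton head dimension that the per-head slopes broadcast against). A hedged usage sketch, assuming the class is importable as shown and that one slope per head is expected, as the default `self._get_slopes(heads)` path implies; the slope values are illustrative only:

from x_transformers.x_transformers import AlibiPositionalBias

# default behaviour: slopes derived from the number of heads
default_bias = AlibiPositionalBias(heads = 8)

# new behaviour: caller-supplied slopes
custom_bias = AlibiPositionalBias(heads = 8, slopes = [1 / 2 ** k for k in range(1, 9)])

# the supplied slopes land in the non-persistent 'slopes' buffer,
# reshaped to (heads, 1, 1) by the rearrange in the hunk above
print(custom_bias.slopes.shape)  # torch.Size([8, 1, 1])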
@@ -1531,8 +1541,9 @@ class AttentionLayers(Module):
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
-        reinject_input = False,
-        add_value_residual = False,
+        reinject_input = False,   # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        rel_pos_kwargs: dict = dict(),
         **kwargs
     ):
         super().__init__()

@@ -1573,14 +1584,14 @@ class AttentionLayers(Module):

         if rel_pos_bias:
             assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
-            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
+            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance, **rel_pos_kwargs)
         elif dynamic_pos_bias:
             assert not flash_attn, 'flash attention not compatible with dynamic positional bias'
-            self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
+            self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm, **rel_pos_kwargs)
         elif alibi_pos_bias:
             alibi_num_heads = default(alibi_num_heads, heads)
             assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
-            self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
+            self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads, **rel_pos_kwargs)

         assert at_most_one_of(sandwich_norm, resi_dual), 'either sandwich norm or resiDual is selected, but not both'
         assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
{x_transformers-1.42.8.dist-info → x_transformers-1.42.10.dist-info}/RECORD CHANGED

@@ -1,16 +1,16 @@
 x_transformers/__init__.py,sha256=l0dom8ZYkRzFvnDdgzDboXqrI1tKav3beVE7TN2nHko,844
 x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
-x_transformers/autoregressive_wrapper.py,sha256=
+x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=VxdA44EYQhVH1Rp7wreJ83I2e0Ea7VN_bFRE-iDXOI8,93833
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
+x_transformers-1.42.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.10.dist-info/METADATA,sha256=UlOzLgz1fhOIaz3eKDhcHgVI81EZySt4Ko3BiKJ7Jok,690
+x_transformers-1.42.10.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
+x_transformers-1.42.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.10.dist-info/RECORD,,
{x_transformers-1.42.8.dist-info → x_transformers-1.42.10.dist-info}/LICENSE: File without changes
{x_transformers-1.42.8.dist-info → x_transformers-1.42.10.dist-info}/WHEEL: File without changes
{x_transformers-1.42.8.dist-info → x_transformers-1.42.10.dist-info}/top_level.txt: File without changes