x-transformers 1.40.2__py3-none-any.whl → 1.40.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +4 -3
- {x_transformers-1.40.2.dist-info → x_transformers-1.40.3.dist-info}/METADATA +1 -1
- {x_transformers-1.40.2.dist-info → x_transformers-1.40.3.dist-info}/RECORD +6 -6
- {x_transformers-1.40.2.dist-info → x_transformers-1.40.3.dist-info}/LICENSE +0 -0
- {x_transformers-1.40.2.dist-info → x_transformers-1.40.3.dist-info}/WHEEL +0 -0
- {x_transformers-1.40.2.dist-info → x_transformers-1.40.3.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -442,10 +442,10 @@ class DynamicPositionBias(Module):
|
|
442
442
|
return bias
|
443
443
|
|
444
444
|
class AlibiPositionalBias(Module):
|
445
|
-
def __init__(self, heads, total_heads, **kwargs):
|
445
|
+
def __init__(self, heads, total_heads = None, **kwargs):
|
446
446
|
super().__init__()
|
447
447
|
self.heads = heads
|
448
|
-
self.total_heads = total_heads
|
448
|
+
self.total_heads = default(total_heads, heads)
|
449
449
|
|
450
450
|
slopes = Tensor(self._get_slopes(heads))
|
451
451
|
slopes = rearrange(slopes, 'h -> h 1 1')
|
@@ -1665,6 +1665,7 @@ class AttentionLayers(Module):
|
|
1665
1665
|
cache_age = 1,
|
1666
1666
|
return_hiddens = False,
|
1667
1667
|
rotary_pos_emb = None,
|
1668
|
+
attn_bias = None,
|
1668
1669
|
condition = None,
|
1669
1670
|
layers_execute_order: tuple[int, ...] | None = None
|
1670
1671
|
):
|
@@ -1818,7 +1819,7 @@ class AttentionLayers(Module):
|
|
1818
1819
|
block = partial(block, **block_forward_kwargs)
|
1819
1820
|
|
1820
1821
|
if layer_type == 'a':
|
1821
|
-
out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, return_intermediates = True)
|
1822
|
+
out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, return_intermediates = True)
|
1822
1823
|
elif layer_type == 'c':
|
1823
1824
|
out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
|
1824
1825
|
elif layer_type == 'f':
|
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
|
|
5
5
|
x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
6
6
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
7
7
|
x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
|
8
|
-
x_transformers/x_transformers.py,sha256=
|
8
|
+
x_transformers/x_transformers.py,sha256=17wNbMh5Mt6cyrAvkBToIozmFt-p9ZBhQCqqlnDyHPI,84676
|
9
9
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
10
10
|
x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
|
11
|
-
x_transformers-1.40.
|
12
|
-
x_transformers-1.40.
|
13
|
-
x_transformers-1.40.
|
14
|
-
x_transformers-1.40.
|
15
|
-
x_transformers-1.40.
|
11
|
+
x_transformers-1.40.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
12
|
+
x_transformers-1.40.3.dist-info/METADATA,sha256=2Qb9WF8pa8bmqMI4VdO9f9bBEmKMcUatlbcU2SDJGwM,661
|
13
|
+
x_transformers-1.40.3.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
|
14
|
+
x_transformers-1.40.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
15
|
+
x_transformers-1.40.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|