x-transformers 1.40.2__py3-none-any.whl → 1.40.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +18 -3
- {x_transformers-1.40.2.dist-info → x_transformers-1.40.4.dist-info}/METADATA +1 -1
- {x_transformers-1.40.2.dist-info → x_transformers-1.40.4.dist-info}/RECORD +6 -6
- {x_transformers-1.40.2.dist-info → x_transformers-1.40.4.dist-info}/LICENSE +0 -0
- {x_transformers-1.40.2.dist-info → x_transformers-1.40.4.dist-info}/WHEEL +0 -0
- {x_transformers-1.40.2.dist-info → x_transformers-1.40.4.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -442,10 +442,10 @@ class DynamicPositionBias(Module):
|
|
442
442
|
return bias
|
443
443
|
|
444
444
|
class AlibiPositionalBias(Module):
|
445
|
-
def __init__(self, heads, total_heads, **kwargs):
|
445
|
+
def __init__(self, heads, total_heads = None, **kwargs):
|
446
446
|
super().__init__()
|
447
447
|
self.heads = heads
|
448
|
-
self.total_heads = total_heads
|
448
|
+
self.total_heads = default(total_heads, heads)
|
449
449
|
|
450
450
|
slopes = Tensor(self._get_slopes(heads))
|
451
451
|
slopes = rearrange(slopes, 'h -> h 1 1')
|
@@ -1353,6 +1353,7 @@ class AttentionLayers(Module):
|
|
1353
1353
|
use_layerscale = False,
|
1354
1354
|
layerscale_init_value = 0.,
|
1355
1355
|
unet_skips = False,
|
1356
|
+
reinject_input = True, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
|
1356
1357
|
**kwargs
|
1357
1358
|
):
|
1358
1359
|
super().__init__()
|
@@ -1582,6 +1583,11 @@ class AttentionLayers(Module):
|
|
1582
1583
|
|
1583
1584
|
self.skip_combines = ModuleList([])
|
1584
1585
|
|
1586
|
+
# whether there is reinjection of input at every layer
|
1587
|
+
|
1588
|
+
self.reinject_input = reinject_input
|
1589
|
+
self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
|
1590
|
+
|
1585
1591
|
# iterate and construct layers
|
1586
1592
|
|
1587
1593
|
for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
|
@@ -1665,6 +1671,7 @@ class AttentionLayers(Module):
|
|
1665
1671
|
cache_age = 1,
|
1666
1672
|
return_hiddens = False,
|
1667
1673
|
rotary_pos_emb = None,
|
1674
|
+
attn_bias = None,
|
1668
1675
|
condition = None,
|
1669
1676
|
layers_execute_order: tuple[int, ...] | None = None
|
1670
1677
|
):
|
@@ -1765,6 +1772,11 @@ class AttentionLayers(Module):
|
|
1765
1772
|
|
1766
1773
|
layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
|
1767
1774
|
|
1775
|
+
# derived input for reinjection if needed
|
1776
|
+
|
1777
|
+
if self.reinject_input:
|
1778
|
+
inp_inject = self.reinject_input_proj(x)
|
1779
|
+
|
1768
1780
|
# store all hiddens for skips
|
1769
1781
|
|
1770
1782
|
skip_hiddens = []
|
@@ -1809,6 +1821,9 @@ class AttentionLayers(Module):
|
|
1809
1821
|
post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
|
1810
1822
|
post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)
|
1811
1823
|
|
1824
|
+
if self.reinject_input:
|
1825
|
+
x = x + inp_inject
|
1826
|
+
|
1812
1827
|
if exists(pre_norm):
|
1813
1828
|
x = pre_norm(x)
|
1814
1829
|
|
@@ -1818,7 +1833,7 @@ class AttentionLayers(Module):
|
|
1818
1833
|
block = partial(block, **block_forward_kwargs)
|
1819
1834
|
|
1820
1835
|
if layer_type == 'a':
|
1821
|
-
out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, return_intermediates = True)
|
1836
|
+
out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, return_intermediates = True)
|
1822
1837
|
elif layer_type == 'c':
|
1823
1838
|
out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
|
1824
1839
|
elif layer_type == 'f':
|
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
|
|
5
5
|
x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
6
6
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
7
7
|
x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
|
8
|
-
x_transformers/x_transformers.py,sha256=
|
8
|
+
x_transformers/x_transformers.py,sha256=g7cGYUcq344QZejqdB_PPHHz1O_zqmdplxgQA07hqks,85298
|
9
9
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
10
10
|
x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
|
11
|
-
x_transformers-1.40.
|
12
|
-
x_transformers-1.40.
|
13
|
-
x_transformers-1.40.
|
14
|
-
x_transformers-1.40.
|
15
|
-
x_transformers-1.40.
|
11
|
+
x_transformers-1.40.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
12
|
+
x_transformers-1.40.4.dist-info/METADATA,sha256=K7tcPd4ZY7qO773XErYQbik-XDEfLbFQ0lFYSfjAFvY,661
|
13
|
+
x_transformers-1.40.4.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
|
14
|
+
x_transformers-1.40.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
15
|
+
x_transformers-1.40.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|