x-transformers 1.40.5__py3-none-any.whl → 1.40.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +1 -0
- x_transformers/x_transformers.py +26 -1
- x_transformers/xval.py +2 -1
- {x_transformers-1.40.5.dist-info → x_transformers-1.40.7.dist-info}/METADATA +1 -1
- {x_transformers-1.40.5.dist-info → x_transformers-1.40.7.dist-info}/RECORD +8 -8
- {x_transformers-1.40.5.dist-info → x_transformers-1.40.7.dist-info}/LICENSE +0 -0
- {x_transformers-1.40.5.dist-info → x_transformers-1.40.7.dist-info}/WHEEL +0 -0
- {x_transformers-1.40.5.dist-info → x_transformers-1.40.7.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -22,6 +22,7 @@ class Intermediates:
     qk_similarities: Tensor | None = None
     pre_softmax_attn: Tensor | None = None
     post_softmax_attn: Tensor | None = None
+    values: Tensor | None = None
     cached_kv: Tuple[Tensor, Tensor] | None = None
     layer_type: str | None = None
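The new values slot on Intermediates is populated by the Attention.forward changes shown in the next section, and lets a caller pull out the value tensor a block attended over and hand it back to a later call. Below is a minimal round-trip sketch, not code from the package, assuming Attention is used standalone with default settings:

    # illustrative only, not code from the package: exercise the new
    # Intermediates.values field by reusing one block's values in a later call
    import torch
    from x_transformers.x_transformers import Attention

    attn = Attention(dim = 512, heads = 8)   # assumed constructor arguments
    x = torch.randn(2, 128, 512)

    out, intermediates = attn(x, return_intermediates = True)

    # intermediates.values holds the projected value tensor this block attended over
    out_again, _ = attn(x, value_residual = intermediates.values, return_intermediates = True)

The field defaults to None, so existing consumers of Intermediates are unaffected.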
x_transformers/x_transformers.py
CHANGED
@@ -1114,6 +1114,7 @@ class Attention(Module):
         mem_mask = None,
         return_intermediates = False,
         cache: Intermediates | None = None,
+        value_residual = None
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context)
 
@@ -1243,6 +1244,12 @@
             attn_bias = rel_pos(i, j)
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values
 
+        # previous values passed in
+        # https://arxiv.org/abs/2410.17897v1
+
+        if exists(value_residual):
+            v = v + value_residual
+
         # attention is all we need
 
         out, intermediates = self.attend(
@@ -1252,6 +1259,10 @@
             prev_attn = prev_attn
         )
 
+        # store the values for resformer from Zhou et al. https://arxiv.org/abs/2410.17897v1
+
+        intermediates.values = v
+
         # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
 
         if exists(r):
@@ -1354,6 +1365,7 @@ class AttentionLayers(Module):
         layerscale_init_value = 0.,
         unet_skips = False,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 | TODO: also add NeuTRENO from Nguyen et al. https://arxiv.org/abs/2312.00751
         **kwargs
     ):
         super().__init__()
@@ -1588,6 +1600,10 @@
         self.reinject_input = reinject_input
         self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None
 
+        # add the value from the first self attention block to all latter projected self attention values as a residual
+
+        self.add_value_residual = add_value_residual
+
         # iterate and construct layers
 
         for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
@@ -1787,6 +1803,8 @@
 
         skip_hiddens = []
 
+        first_self_attn_inter = None
+
         # go through the attention and feedforward layers
 
         for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
@@ -1838,13 +1856,20 @@
 
             block = partial(block, **block_forward_kwargs)
 
+            maybe_value_residual = None
+            if self.add_value_residual and exists(first_self_attn_inter):
+                maybe_value_residual = first_self_attn_inter.values
+
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_value_residual, return_intermediates = True)
             elif layer_type == 'c':
                 out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x)
 
+            if not exists(first_self_attn_inter) and layer_type == 'a':
+                first_self_attn_inter = inter
+
             if self.resi_dual:
                 outer_residual = outer_residual + out * self.resi_dual_scale
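At the user level, the new flag rides on AttentionLayers' keyword arguments, so it can be switched on from Decoder (or Encoder) like any other option. A minimal usage sketch, not taken from the package's docs, assuming the usual TransformerWrapper setup:

    # illustrative only: enable the resformer-style value residual end to end;
    # add_value_residual is the new AttentionLayers flag added in this release
    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 12,
            heads = 8,
            add_value_residual = True   # add first self-attn block's values to every later block's values
        )
    )

    tokens = torch.randint(0, 20000, (1, 1024))
    logits = model(tokens)   # (1, 1024, 20000)

With the flag left at its default of False the forward path is unchanged: maybe_value_residual stays None and Attention skips the addition.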
x_transformers/xval.py
CHANGED

{x_transformers-1.40.5.dist-info → x_transformers-1.40.7.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=azvG00DKeg4s0tH28QeDhU6X2GaeFjq3-0RMxgVt408,86715
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
-x_transformers/xval.py,sha256=
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
+x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
+x_transformers-1.40.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.7.dist-info/METADATA,sha256=mhcww_WPNU2q-piUEvoDEXwForJdgk9s_K6iCQVFDqo,661
+x_transformers-1.40.7.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.7.dist-info/RECORD,,
{x_transformers-1.40.5.dist-info → x_transformers-1.40.7.dist-info}/LICENSE
File without changes

{x_transformers-1.40.5.dist-info → x_transformers-1.40.7.dist-info}/WHEEL
File without changes

{x_transformers-1.40.5.dist-info → x_transformers-1.40.7.dist-info}/top_level.txt
File without changes