x-transformers 1.40.6__py3-none-any.whl → 1.40.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +1 -0
- x_transformers/x_transformers.py +45 -2
- {x_transformers-1.40.6.dist-info → x_transformers-1.40.8.dist-info}/METADATA +1 -1
- {x_transformers-1.40.6.dist-info → x_transformers-1.40.8.dist-info}/RECORD +7 -7
- {x_transformers-1.40.6.dist-info → x_transformers-1.40.8.dist-info}/LICENSE +0 -0
- {x_transformers-1.40.6.dist-info → x_transformers-1.40.8.dist-info}/WHEEL +0 -0
- {x_transformers-1.40.6.dist-info → x_transformers-1.40.8.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -22,6 +22,7 @@ class Intermediates:
     qk_similarities: Tensor | None = None
     pre_softmax_attn: Tensor | None = None
     post_softmax_attn: Tensor | None = None
+    values: Tensor | None = None
     cached_kv: Tuple[Tensor, Tensor] | None = None
     layer_type: str | None = None

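The new values slot on Intermediates is what lets an attention layer hand its projected value tensor back to the caller so that a later layer can add it back in. Below is a minimal, self-contained sketch of that value-residual idea; the toy projections, shapes, and the first_values bookkeeping are purely illustrative and are not the package's actual code.

import torch

# toy value projections standing in for each attention layer's value matrix
def value_proj(x, weight):
    return x @ weight

x = torch.randn(2, 16, 64)                        # (batch, seq, dim)
weights = [torch.randn(64, 64) * 0.02 for _ in range(4)]

first_values = None

for w in weights:
    v = value_proj(x, w)

    if first_values is None:
        first_values = v                          # analogous to what Intermediates.values now carries
    else:
        v = v + first_values                      # value residual from the first layer, added in later layers

    # ... attention would consume v here ...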
x_transformers/x_transformers.py
CHANGED
@@ -1114,6 +1114,7 @@ class Attention(Module):
         mem_mask = None,
         return_intermediates = False,
         cache: Intermediates | None = None,
+        value_residual = None
     ):
         b, n, h, kv_h, head_scale, num_mem_kv, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context)

@@ -1243,6 +1244,12 @@
             attn_bias = rel_pos(i, j)
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values

+        # previous values passed in
+        # https://arxiv.org/abs/2410.17897v1
+
+        if exists(value_residual):
+            v = v + value_residual
+
         # attention is all we need

         out, intermediates = self.attend(
@@ -1252,6 +1259,10 @@
             prev_attn = prev_attn
         )

+        # store the values for resformer from Zhou et al. https://arxiv.org/abs/2410.17897v1
+
+        intermediates.values = v
+
         # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients

         if exists(r):
@@ -1354,6 +1365,7 @@ class AttentionLayers(Module):
         layerscale_init_value = 0.,
         unet_skips = False,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 | TODO: also add NeuTRENO from Nguyen et al. https://arxiv.org/abs/2312.00751
         **kwargs
     ):
         super().__init__()
@@ -1588,6 +1600,10 @@
         self.reinject_input = reinject_input
         self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None

+        # add the value from the first self attention block to all latter projected self attention values as a residual
+
+        self.add_value_residual = add_value_residual
+
         # iterate and construct layers

         for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
@@ -1787,6 +1803,11 @@

         skip_hiddens = []

+        # for value residuals
+
+        first_self_attn_inter = None
+        first_cross_attn_inter = None
+
         # go through the attention and feedforward layers

         for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
@@ -1838,13 +1859,35 @@

             block = partial(block, **block_forward_kwargs)

+            # handle maybe value residuals
+
+            maybe_self_attn_value_residual = None
+            maybe_cross_attn_value_residual = None
+
+            if self.add_value_residual:
+                if exists(first_self_attn_inter):
+                    maybe_self_attn_value_residual = first_self_attn_inter.values
+
+                if exists(first_cross_attn_inter):
+                    maybe_cross_attn_value_residual = first_cross_attn_inter.values
+
+            # forward depending on layer type
+
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
-                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
+                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x)

+            # store first self or cross attention intermediate for value residual
+
+            if not exists(first_self_attn_inter) and layer_type == 'a':
+                first_self_attn_inter = inter
+
+            if not exists(first_cross_attn_inter) and layer_type == 'c':
+                first_cross_attn_inter = inter
+
             if self.resi_dual:
                 outer_residual = outer_residual + out * self.resi_dual_scale

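Because add_value_residual reaches AttentionLayers through its **kwargs, enabling the new behavior from the usual wrapper classes should amount to a single keyword argument. Below is a minimal usage sketch, assuming the package's existing TransformerWrapper / Decoder API; the model sizes and token counts are illustrative only.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        add_value_residual = True  # flag added in this diff: the first self-attention block's values are reused as a residual in later blocks
    )
)

x = torch.randint(0, 20000, (1, 256))
logits = model(x)  # (1, 256, 20000)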
{x_transformers-1.40.6.dist-info → x_transformers-1.40.8.dist-info}/RECORD
CHANGED
@@ -1,15 +1,15 @@
 x_transformers/__init__.py,sha256=-MkQrSc37cTVDX7AOykxunYnqVtFlQ7lb0Cse5dsGWU,793
-x_transformers/attend.py,sha256=
+x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
 x_transformers/autoregressive_wrapper.py,sha256=DOJJCMMDOqDYKWy_IaG5IyKsXD3AW6amzfUgdAADOLY,10500
 x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,6450
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=UyJ3-XoanCHOUZLoyp29jt-yNWbgc7pFqpfjOGFFEPY,87367
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
+x_transformers-1.40.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.8.dist-info/METADATA,sha256=epku2YLENdljdFU7ge7hhUDXMujquHyb0jhoDSv9yFk,661
+x_transformers-1.40.8.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.8.dist-info/RECORD,,
{x_transformers-1.40.6.dist-info → x_transformers-1.40.8.dist-info}/LICENSE
File without changes
{x_transformers-1.40.6.dist-info → x_transformers-1.40.8.dist-info}/WHEEL
File without changes
{x_transformers-1.40.6.dist-info → x_transformers-1.40.8.dist-info}/top_level.txt
File without changes