x-transformers 1.40.7__py3-none-any.whl → 1.40.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +43 -10
- {x_transformers-1.40.7.dist-info → x_transformers-1.40.9.dist-info}/METADATA +1 -1
- {x_transformers-1.40.7.dist-info → x_transformers-1.40.9.dist-info}/RECORD +6 -6
- {x_transformers-1.40.7.dist-info → x_transformers-1.40.9.dist-info}/LICENSE +0 -0
- {x_transformers-1.40.7.dist-info → x_transformers-1.40.9.dist-info}/WHEEL +0 -0
- {x_transformers-1.40.7.dist-info → x_transformers-1.40.9.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -944,6 +944,8 @@ class Attention(Module):
         cope_talking_heads = False,
         softclamp_logits = False,
         logit_softclamp_value = 50.,
+        neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
+        neutreno_alpha = 0.4,
         onnxable = False
     ):
         super().__init__()
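
The two new constructor arguments toggle the Neutreno value residual from the referenced paper. A minimal usage sketch, assuming the Attention module from this file; the shapes and hyperparameters below are illustrative only:

    import torch
    from x_transformers.x_transformers import Attention

    # a single attention block with the Neutreno value residual enabled
    attn = Attention(dim = 512, heads = 8, neutreno_value_residual = True, neutreno_alpha = 0.4)

    x = torch.randn(1, 128, 512)

    # the first call has no residual yet; its values are then handed to later calls / layers
    out, inter = attn(x, return_intermediates = True)
    out2, _ = attn(x, value_residual = inter.values, return_intermediates = True)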
@@ -982,6 +984,11 @@ class Attention(Module):

         self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None

+        # the value residual used by Nguyen et al. in https://arxiv.org/abs/2312.00751 for countering oversmoothing
+
+        self.neutreno_value_residual = neutreno_value_residual
+        self.neutreno_alpha = neutreno_alpha
+
         # add GLU gating for aggregated values, from alphafold2

         self.to_v_gate = None
@@ -1244,11 +1251,15 @@ class Attention(Module):
             attn_bias = rel_pos(i, j)
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values

-        # previous values passed in
-        # https://arxiv.org/abs/2410.17897v1
+        # if previous values passed in for residual, either invoke resformer or neutreno

         if exists(value_residual):
-            v = v + value_residual
+            if self.neutreno_value_residual:
+                diff_values = (value_residual - v) * self.neutreno_alpha
+                diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
+            else:
+                # https://arxiv.org/abs/2410.17897v1
+                v = v + value_residual

         # attention is all we need

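
The Neutreno branch only builds a scaled difference between the first layer's values and the current layer's values; the next hunk adds that difference to the attention output. A standalone sketch of the same arithmetic, assuming grouped-query shapes where h query heads share kv_h key/value heads (all numbers illustrative):

    import torch
    from einops import repeat

    b, n, d = 2, 16, 64      # batch, sequence length, head dimension
    h, kv_h = 8, 2           # query heads, key/value heads
    alpha = 0.4

    value_residual = torch.randn(b, kv_h, n, d)   # values from the first attention layer
    v = torch.randn(b, kv_h, n, d)                # values of the current layer

    # scaled value difference, broadcast from kv heads to all query heads
    diff_values = (value_residual - v) * alpha
    diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)

    out = torch.randn(b, h, n, d)                 # stand-in for the attention output
    out = out + diff_values                       # countering oversmoothing, per the paper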
@@ -1259,10 +1270,13 @@ class Attention(Module):
             prev_attn = prev_attn
         )

-        # store the values for resformer
+        # store the values for resformer or Neutreno

         intermediates.values = v

+        if exists(value_residual) and self.neutreno_value_residual:
+            out = out + diff_values
+
         # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients

         if exists(r):
@@ -1365,7 +1379,7 @@ class AttentionLayers(Module):
         layerscale_init_value = 0.,
         unet_skips = False,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
-        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
         **kwargs
     ):
         super().__init__()
@@ -1378,6 +1392,7 @@ class AttentionLayers(Module):
         assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'

         dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
+        add_value_residual |= attn_kwargs.get('neutreno_value_residual', False)

         self.dim = dim
         self.causal = causal
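
This new line means enabling Neutreno on the attention blocks implicitly turns on value-residual passing in the layer stack. Since AttentionLayers forwards attn_-prefixed keyword arguments to its attention blocks, the flag can be set from the usual wrapper; a sketch, with all other hyperparameters illustrative:

    from x_transformers import TransformerWrapper, Decoder

    # attn_neutreno_value_residual is routed to Attention via the attn_ prefix;
    # add_value_residual then turns itself on through the `|=` line above
    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            attn_neutreno_value_residual = True,
            attn_neutreno_alpha = 0.4
        )
    )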
@@ -1803,7 +1818,10 @@ class AttentionLayers(Module):

         skip_hiddens = []

+        # for value residuals
+
         first_self_attn_inter = None
+        first_cross_attn_inter = None

         # go through the attention and feedforward layers

@@ -1856,20 +1874,35 @@ class AttentionLayers(Module):

             block = partial(block, **block_forward_kwargs)

-
-
-
+            # handle maybe value residuals
+
+            maybe_self_attn_value_residual = None
+            maybe_cross_attn_value_residual = None
+
+            if self.add_value_residual:
+                if exists(first_self_attn_inter):
+                    maybe_self_attn_value_residual = first_self_attn_inter.values
+
+                if exists(first_cross_attn_inter):
+                    maybe_cross_attn_value_residual = first_cross_attn_inter.values
+
+            # forward depending on layer type

             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual =
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
-                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
+                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x)

+            # store first self or cross attention intermediate for value residual
+
             if not exists(first_self_attn_inter) and layer_type == 'a':
                 first_self_attn_inter = inter

+            if not exists(first_cross_attn_inter) and layer_type == 'c':
+                first_cross_attn_inter = inter
+
             if self.resi_dual:
                 outer_residual = outer_residual + out * self.resi_dual_scale

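
The loop above captures the intermediates of the first self-attention layer (and, where cross attention is present, the first cross-attention layer) and hands their values to every later layer of the same type. A simplified, self-contained sketch of that plumbing; the run_self_attn_layers name and its layer interface are illustrative, not the library's API:

    # sketch of the value-residual plumbing, assuming each attn layer accepts
    # value_residual and return_intermediates like the blocks in the hunk above
    def run_self_attn_layers(attn_layers, x, add_value_residual = True):
        first_self_attn_inter = None

        for attn in attn_layers:
            maybe_value_residual = None

            if add_value_residual and first_self_attn_inter is not None:
                maybe_value_residual = first_self_attn_inter.values

            out, inter = attn(x, value_residual = maybe_value_residual, return_intermediates = True)
            x = x + out

            # only the first layer's intermediates (and thus its values) are kept
            if first_self_attn_inter is None:
                first_self_attn_inter = inter

        return x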
{x_transformers-1.40.7.dist-info → x_transformers-1.40.9.dist-info}/RECORD
CHANGED
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=NoWBoiz1t8_QytYo1T2YBFk-7H9s38k2t-EksxqUkMU,88072
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
+x_transformers-1.40.9.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.9.dist-info/METADATA,sha256=xSxqFkhGfr5dU2xI0xo3UzlPMSuaaR4Rd2TrDpEyxcE,661
+x_transformers-1.40.9.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.9.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.9.dist-info/RECORD,,
{x_transformers-1.40.7.dist-info → x_transformers-1.40.9.dist-info}/LICENSE
File without changes
{x_transformers-1.40.7.dist-info → x_transformers-1.40.9.dist-info}/WHEEL
File without changes
{x_transformers-1.40.7.dist-info → x_transformers-1.40.9.dist-info}/top_level.txt
File without changes