x-transformers 1.40.7__py3-none-any.whl → 1.40.9__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
--- x_transformers/x_transformers.py
+++ x_transformers/x_transformers.py
@@ -944,6 +944,8 @@ class Attention(Module):
         cope_talking_heads = False,
         softclamp_logits = False,
         logit_softclamp_value = 50.,
+        neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
+        neutreno_alpha = 0.4,
         onnxable = False
     ):
         super().__init__()
@@ -982,6 +984,11 @@ class Attention(Module):
 
         self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None
 
+        # the value residual used by Nguyen et al. in https://arxiv.org/abs/2312.00751 for countering oversmoothing
+
+        self.neutreno_value_residual = neutreno_value_residual
+        self.neutreno_alpha = neutreno_alpha
+
         # add GLU gating for aggregated values, from alphafold2
 
         self.to_v_gate = None
@@ -1244,11 +1251,15 @@ class Attention(Module):
             attn_bias = rel_pos(i, j)
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values
 
-        # previous values passed in
-        # https://arxiv.org/abs/2410.17897v1
+        # if previous values passed in for residual, either invoke resformer or neutreno
 
         if exists(value_residual):
-            v = v + value_residual
+            if self.neutreno_value_residual:
+                diff_values = (value_residual - v) * self.neutreno_alpha
+                diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
+            else:
+                # https://arxiv.org/abs/2410.17897v1
+                v = v + value_residual
 
         # attention is all we need
 
@@ -1259,10 +1270,13 @@ class Attention(Module):
             prev_attn = prev_attn
         )
 
-        # store the values for resformer from Zhou et al. https://arxiv.org/abs/2410.17897v1
+        # store the values for resformer or Neutreno
 
         intermediates.values = v
 
+        if exists(value_residual) and self.neutreno_value_residual:
+            out = out + diff_values
+
         # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
 
         if exists(r):
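
In short, the two value-residual variants handled above diverge as follows: resformer (Zhou et al., https://arxiv.org/abs/2410.17897v1) adds the first layer's values to the current layer's values before attention, while NeuTRENO (Nguyen et al., https://arxiv.org/abs/2312.00751) leaves the values untouched and instead adds a scaled difference between the first layer's values and the current values to the attention output, countering oversmoothing. A minimal sketch of the two update rules, assuming equal query and key/value head counts (the helper names and toy shapes below are illustrative, not part of the library; the library additionally repeats the NeuTRENO correction across query heads when grouped key/value heads are used):

import torch

def resformer_value_residual(v, v_first):
    # resformer: bias the current layer's values toward the first layer's values
    # before they are aggregated by attention
    return v + v_first

def neutreno_correction(out, v, v_first, alpha = 0.4):
    # NeuTRENO: after attention, add alpha * (v_first - v) to the aggregated output
    return out + alpha * (v_first - v)

# toy shapes: batch, heads, sequence length, head dimension
b, h, n, d = 2, 8, 16, 64
v_first = torch.randn(b, h, n, d)   # values cached from the first attention layer
v       = torch.randn(b, h, n, d)   # values of the current attention layer
out     = torch.randn(b, h, n, d)   # attention output of the current layer

v_resformer  = resformer_value_residual(v, v_first)    # resformer variant
out_neutreno = neutreno_correction(out, v, v_first)    # NeuTRENO variant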
@@ -1365,7 +1379,7 @@ class AttentionLayers(Module):
         layerscale_init_value = 0.,
         unet_skips = False,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
-        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 | TODO: also add NeuTRENO from Nguyen et al. https://arxiv.org/abs/2312.00751
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
         **kwargs
     ):
         super().__init__()
@@ -1378,6 +1392,7 @@ class AttentionLayers(Module):
         assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'
 
         dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
+        add_value_residual |= attn_kwargs.get('neutreno_value_residual', False)
 
         self.dim = dim
         self.causal = causal
@@ -1803,7 +1818,10 @@ class AttentionLayers(Module):
 
         skip_hiddens = []
 
+        # for value residuals
+
         first_self_attn_inter = None
+        first_cross_attn_inter = None
 
         # go through the attention and feedforward layers
 
@@ -1856,20 +1874,35 @@ class AttentionLayers(Module):
 
             block = partial(block, **block_forward_kwargs)
 
-            maybe_value_residual = None
-            if self.add_value_residual and exists(first_self_attn_inter):
-                maybe_value_residual = first_self_attn_inter.values
+            # handle maybe value residuals
+
+            maybe_self_attn_value_residual = None
+            maybe_cross_attn_value_residual = None
+
+            if self.add_value_residual:
+                if exists(first_self_attn_inter):
+                    maybe_self_attn_value_residual = first_self_attn_inter.values
+
+                if exists(first_cross_attn_inter):
+                    maybe_cross_attn_value_residual = first_cross_attn_inter.values
+
+            # forward depending on layer type
 
             if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_value_residual, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_self_attn_value_residual, return_intermediates = True)
             elif layer_type == 'c':
-                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
+                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), value_residual = maybe_cross_attn_value_residual, return_intermediates = True)
             elif layer_type == 'f':
                 out = block(x)
 
+            # store first self or cross attention intermediate for value residual
+
             if not exists(first_self_attn_inter) and layer_type == 'a':
                 first_self_attn_inter = inter
 
+            if not exists(first_cross_attn_inter) and layer_type == 'c':
+                first_cross_attn_inter = inter
+
             if self.resi_dual:
                 outer_residual = outer_residual + out * self.resi_dual_scale
 
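For context, the new option would be enabled through the usual attn_-prefixed keyword routing from AttentionLayers into Attention, which the hunk at line 1395 relies on. The snippet below is a hedged usage sketch under that assumption; the token count, sequence length, and model sizes are placeholders, not an official example from this release:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,      # placeholder vocabulary size
    max_seq_len = 1024,    # placeholder sequence length
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_neutreno_value_residual = True,  # new flag from this diff; also switches add_value_residual on
        attn_neutreno_alpha = 0.4             # scale on (v_first - v), default per this diff
    )
)

tokens = torch.randint(0, 256, (1, 1024))
logits = model(tokens)  # (1, 1024, 256)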
--- x_transformers-1.40.7.dist-info/METADATA
+++ x_transformers-1.40.9.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.7
+Version: 1.40.9
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- x_transformers-1.40.7.dist-info/RECORD
+++ x_transformers-1.40.9.dist-info/RECORD
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=azvG00DKeg4s0tH28QeDhU6X2GaeFjq3-0RMxgVt408,86715
+x_transformers/x_transformers.py,sha256=NoWBoiz1t8_QytYo1T2YBFk-7H9s38k2t-EksxqUkMU,88072
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.40.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.40.7.dist-info/METADATA,sha256=mhcww_WPNU2q-piUEvoDEXwForJdgk9s_K6iCQVFDqo,661
-x_transformers-1.40.7.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-x_transformers-1.40.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.40.7.dist-info/RECORD,,
+x_transformers-1.40.9.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.9.dist-info/METADATA,sha256=xSxqFkhGfr5dU2xI0xo3UzlPMSuaaR4Rd2TrDpEyxcE,661
+x_transformers-1.40.9.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.9.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.9.dist-info/RECORD,,