x-transformers 1.40.8__py3-none-any.whl → 1.40.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -944,6 +944,8 @@ class Attention(Module):
         cope_talking_heads = False,
         softclamp_logits = False,
         logit_softclamp_value = 50.,
+        neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
+        neutreno_alpha = 0.4,
         onnxable = False
     ):
         super().__init__()
@@ -982,6 +984,11 @@ class Attention(Module):
 
         self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None
 
+        # the value residual used by Nguyen et al. in https://arxiv.org/abs/2312.00751 for countering oversmoothing
+
+        self.neutreno_value_residual = neutreno_value_residual
+        self.neutreno_alpha = neutreno_alpha
+
         # add GLU gating for aggregated values, from alphafold2
 
         self.to_v_gate = None
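The two hunks above add the new constructor arguments and persist them on the Attention module. A minimal construction sketch with the new flags (the argument names come from the diff; dim and heads values are illustrative assumptions):

from x_transformers.x_transformers import Attention

attn = Attention(
    dim = 512,
    heads = 8,
    neutreno_value_residual = True,  # enable the Nguyen et al. value residual
    neutreno_alpha = 0.4             # weight applied to (value_residual - v)
)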
@@ -1244,11 +1251,15 @@ class Attention(Module):
             attn_bias = rel_pos(i, j)
             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values
 
-        # previous values passed in
-        # https://arxiv.org/abs/2410.17897v1
+        # if previous values passed in for residual, either invoke resformer or neutreno
 
         if exists(value_residual):
-            v = v + value_residual
+            if self.neutreno_value_residual:
+                diff_values = (value_residual - v) * self.neutreno_alpha
+                diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
+            else:
+                # https://arxiv.org/abs/2410.17897v1
+                v = v + value_residual
 
         # attention is all we need
 
@@ -1259,10 +1270,13 @@ class Attention(Module):
             prev_attn = prev_attn
         )
 
-        # store the values for resformer from Zhou et al. https://arxiv.org/abs/2410.17897v1
+        # store the values for resformer or Neutreno
 
         intermediates.values = v
 
+        if exists(value_residual) and self.neutreno_value_residual:
+            out = out + diff_values
+
         # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients
 
         if exists(r):
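Taken together, the two forward-pass hunks above implement the NeuTRENO-style value residual: instead of adding the earlier layer's values directly into v (the resformer path), the layer computes an alpha-weighted difference between the residual values and its own values, broadcasts it from the grouped key/value heads to the query heads, and adds it to the attention output. A minimal standalone sketch of that computation, assuming illustrative tensor shapes and a hypothetical helper name (not the library's internals):

# NeuTRENO-style value residual, Nguyen et al. https://arxiv.org/abs/2312.00751
import torch
from einops import repeat

def neutreno_correction(value_residual, v, out, alpha = 0.4):
    # value_residual: values from an earlier layer, shape (b, kv_heads, n, d)
    # v:              this layer's values,          shape (b, kv_heads, n, d)
    # out:            per-head attention output,    shape (b, heads, n, d)
    heads, kv_heads = out.shape[1], v.shape[1]
    diff_values = (value_residual - v) * alpha
    # expand grouped key/value heads to match the number of query heads
    diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = heads // kv_heads)
    return out + diff_values

# toy shapes: 8 query heads sharing 2 key/value heads
out = neutreno_correction(
    value_residual = torch.randn(1, 2, 16, 64),
    v = torch.randn(1, 2, 16, 64),
    out = torch.randn(1, 8, 16, 64)
)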
@@ -1365,7 +1379,7 @@ class AttentionLayers(Module):
         layerscale_init_value = 0.,
         unet_skips = False,
         reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
-        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 | TODO: also add NeuTRENO from Nguyen et al. https://arxiv.org/abs/2312.00751
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
         **kwargs
     ):
         super().__init__()
@@ -1378,6 +1392,7 @@ class AttentionLayers(Module):
         assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'
 
         dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
+        add_value_residual |= attn_kwargs.get('neutreno_value_residual', False)
 
         self.dim = dim
         self.causal = causal
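This last source change makes AttentionLayers turn add_value_residual on whenever the attention kwargs request the NeuTRENO variant, so each Attention block actually receives the earlier layer's values. A hypothetical usage sketch, assuming the library's usual attn_-prefixed kwarg routing from the attention layers down to Attention (the exact kwarg spelling is an assumption, not confirmed by this diff):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_neutreno_value_residual = True,  # routed to Attention; also flips add_value_residual on
        attn_neutreno_alpha = 0.4
    )
)

logits = model(torch.randint(0, 256, (1, 1024)))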
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.8
+Version: 1.40.9
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=UyJ3-XoanCHOUZLoyp29jt-yNWbgc7pFqpfjOGFFEPY,87367
+x_transformers/x_transformers.py,sha256=NoWBoiz1t8_QytYo1T2YBFk-7H9s38k2t-EksxqUkMU,88072
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.40.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.40.8.dist-info/METADATA,sha256=epku2YLENdljdFU7ge7hhUDXMujquHyb0jhoDSv9yFk,661
-x_transformers-1.40.8.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-x_transformers-1.40.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.40.8.dist-info/RECORD,,
+x_transformers-1.40.9.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.9.dist-info/METADATA,sha256=xSxqFkhGfr5dU2xI0xo3UzlPMSuaaR4Rd2TrDpEyxcE,661
+x_transformers-1.40.9.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.9.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.9.dist-info/RECORD,,