x-transformers 1.42.18__py3-none-any.whl → 1.42.20__py3-none-any.whl

x_transformers/attend.py CHANGED
@@ -370,7 +370,7 @@ class Attend(Module):
         # convert from bool to float

         if exists(attn_bias):
-            attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)
+            attn_bias = attn_bias.expand(batch, heads, -1, -1)

         # if mask given, the mask would already contain the causal mask from above logic
         # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
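The dropped rearrange is redundant: `Tensor.expand` can prepend new leading dimensions on its own, so the explicit `'h i j -> 1 h i j'` reshape is not needed before broadcasting the per-head bias over the batch. A minimal sketch of the equivalence (shapes here are illustrative, not taken from the library):

import torch

batch, heads, i, j = 2, 8, 16, 16

# a per-head attention bias (e.g. ALiBi applied to relative positions), shaped (heads, i, j)
attn_bias = torch.randn(heads, i, j)

# old behaviour: add an explicit leading singleton dim, then expand it over the batch
old = attn_bias.unsqueeze(0).expand(batch, heads, -1, -1)

# new behaviour: expand prepends the batch dim directly
new = attn_bias.expand(batch, heads, -1, -1)

assert old.shape == new.shape == (batch, heads, i, j)
assert torch.equal(old, new)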
x_transformers/x_transformers.py CHANGED
@@ -1075,6 +1075,7 @@ class Attention(Module):
         neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
         neutreno_alpha = 0.4,
         learned_value_residual_mix = False,
+        laser = False, # https://arxiv.org/abs/2411.03493v1
         onnxable = False,
         attend_sdp_kwargs: dict = dict(
             enable_flash = True,
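The new flag defaults to False, so existing behaviour is unchanged. Enabling it directly on the Attention module would look roughly like the sketch below; the dim / heads / dim_head arguments come from the class's existing signature rather than from this diff, and the values are illustrative:

import torch
from x_transformers.x_transformers import Attention

# hypothetical instantiation with the new option turned on
attn = Attention(
    dim = 512,
    heads = 8,
    dim_head = 64,
    laser = True,  # exponentiate values before attention, return to log space after
)

In the library's higher-level wrappers, per-attention kwargs are usually forwarded with an attn_ prefix (e.g. attn_laser = True on a Decoder), though that routing is not part of this diff.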
@@ -1114,6 +1115,11 @@ class Attention(Module):
         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
         self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None

+        # enhancing gradients to attention through exponentiated values
+        # todo - compare it to `attn = attn * large_value + attn.detach() * (1. - large_value)`
+
+        self.laser = laser
+
         # relations projection from tp-attention

         self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None
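The todo comment refers to a straight-through style rescaling that leaves the forward value of the attention matrix unchanged while multiplying the gradient flowing back into it by large_value. A small sketch of that identity (the names and the value 10.0 are illustrative):

import torch

large_value = 10.0
attn = torch.rand(2, 8, 16, 16, requires_grad = True)

# forward value is unchanged: attn * c + attn * (1 - c) == attn
scaled = attn * large_value + attn.detach() * (1. - large_value)
assert torch.allclose(scaled, attn)

# but the gradient reaching attn is scaled by large_value, since the detached term contributes none
scaled.sum().backward()
assert torch.allclose(attn.grad, torch.full_like(attn, large_value))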
@@ -1439,6 +1445,11 @@ class Attention(Module):

             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0))

+        if self.laser:
+            values_max = v.amax(dim = -2, keepdim = True).detach() # numerical stability
+            v = v - values_max
+            v = v.exp()
+
         # attention is all we need

         out, intermediates = self.attend(
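Subtracting the per-feature max over the sequence dimension before exponentiating is the usual log-sum-exp stabilization: the attention output is a convex combination of the exponentiated values, so the shift factors out of the sum and can be added back after the log without changing the result. A numerical sketch of that identity (the plain softmax attention and the shapes are illustrative, not the library's internals):

import torch

q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64) * 10.  # moderately large values; much larger and the naive exp overflows

attn = (q @ k.transpose(-2, -1) * 64 ** -0.5).softmax(dim = -1)

# naive: exponentiate the values, attend, take the log of the output
naive = (attn @ v.exp()).log()

# stabilized: shift by the max before exp, add the shift back after the log
values_max = v.amax(dim = -2, keepdim = True)
stable = (attn @ (v - values_max).exp()).log() + values_max

assert torch.allclose(naive, stable, atol = 1e-4)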
@@ -1448,6 +1459,11 @@ class Attention(Module):
             prev_attn = prev_attn
         )

+        # laser
+
+        if self.laser:
+            out = out.log() + values_max
+
         # store the values for resformer or Neutreno

         intermediates.values = orig_values
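Taken together, the two laser branches wrap the existing attention in an exp / log pair, so the output becomes the log of an attention-weighted sum of exponentiated values, matching the exponentiated-value attention described in the linked paper (https://arxiv.org/abs/2411.03493v1). A self-contained sketch of that composition, with a plain softmax attention standing in for the library's Attend module:

import torch
from torch import Tensor

def laser_attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    # standard softmax attention weights
    scale = q.shape[-1] ** -0.5
    attn = (q @ k.transpose(-2, -1) * scale).softmax(dim = -1)

    # exponentiate the values, shifted by their max for numerical stability
    values_max = v.amax(dim = -2, keepdim = True).detach()
    out = attn @ (v - values_max).exp()

    # return to log space and restore the shift
    return out.log() + values_max

q, k, v = (torch.randn(1, 8, 16, 64) for _ in range(3))
out = laser_attention(q, k, v)
assert out.shape == (1, 8, 16, 64)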
{x_transformers-1.42.18.dist-info → x_transformers-1.42.20.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.18
+Version: 1.42.20
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
{x_transformers-1.42.18.dist-info → x_transformers-1.42.20.dist-info}/RECORD RENAMED
@@ -1,16 +1,16 @@
 x_transformers/__init__.py,sha256=l0dom8ZYkRzFvnDdgzDboXqrI1tKav3beVE7TN2nHko,844
-x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
+x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=pFVTmAoAbrir7YjTwzC3X2buRSm7PFnWqYyTYePA8Es,95486
+x_transformers/x_transformers.py,sha256=pDYtIGhoo-lFn_ULJETnQz1Z0QYuDsD4ReTlPy__jwo,95993
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.18.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.42.18.dist-info/METADATA,sha256=v9YlgCULHqvWhTC3bViadNngzfiyYkzrQa6XRZ0uDa4,739
-x_transformers-1.42.18.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
-x_transformers-1.42.18.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.42.18.dist-info/RECORD,,
+x_transformers-1.42.20.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.20.dist-info/METADATA,sha256=J0yBEg7oUfbkJaC3WxfB9Oq4XbGxXA5VjUGd9AHELGk,739
+x_transformers-1.42.20.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+x_transformers-1.42.20.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.20.dist-info/RECORD,,