x-transformers 1.42.18__py3-none-any.whl → 1.42.20__py3-none-any.whl

x_transformers/attend.py CHANGED
@@ -370,7 +370,7 @@ class Attend(Module):
  # convert from bool to float

  if exists(attn_bias):
-     attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)
+     attn_bias = attn_bias.expand(batch, heads, -1, -1)

  # if mask given, the mask would already contain the causal mask from above logic
  # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
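Assuming attn_bias still arrives here as (h, i, j), the rewrite is shape-for-shape identical: torch.Tensor.expand prepends new leading dimensions on its own, so the explicit rearrange that added the batch axis was redundant. A minimal standalone check (not package code):

    import torch
    from einops import rearrange

    batch, heads, i, j = 4, 8, 16, 16
    attn_bias = torch.randn(heads, i, j)

    # old path: explicitly add a leading singleton axis, then broadcast
    old = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)

    # new path: expand prepends the new batch dimension by itself
    new = attn_bias.expand(batch, heads, -1, -1)

    assert torch.equal(old, new)  # same broadcasted view, no copy either way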
x_transformers/x_transformers.py CHANGED
@@ -1075,6 +1075,7 @@ class Attention(Module):
  neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
  neutreno_alpha = 0.4,
  learned_value_residual_mix = False,
+ laser = False, # https://arxiv.org/abs/2411.03493v1
  onnxable = False,
  attend_sdp_kwargs: dict = dict(
      enable_flash = True,
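x-transformers conventionally routes attn_-prefixed kwargs on the attention-layers classes down to Attention; assuming that convention holds for the new flag (the wrapper plumbing below is not shown in this diff), enabling it might look like:

    import torch
    from x_transformers import TransformerWrapper, Decoder

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 512,
            depth = 6,
            heads = 8,
            attn_laser = True  # forwarded to Attention(laser = True), by assumption
        )
    )

    logits = model(torch.randint(0, 256, (1, 1024)))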
@@ -1114,6 +1115,11 @@ class Attention(Module):
  assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
  self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None

+ # enhancing gradients to attention through exponentiated values
+ # todo - compare it to `attn = attn * large_value + attn.detach() * (1. - large_value)`
+
+ self.laser = laser
+
  # relations projection from tp-attention

  self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None
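The todo above refers to a detach-mix trick: mixing a tensor with its detached copy leaves the forward value unchanged but rescales the gradient flowing back through it. A standalone sketch of that identity (illustrative only, not package code):

    import torch

    x = torch.tensor([0.5, 1.5], requires_grad = True)
    c = 10.0  # the "large_value" in the comment above

    # forward value equals x, but only the first term carries gradient
    y = x * c + x.detach() * (1. - c)
    assert torch.allclose(y, x)

    y.sum().backward()
    assert torch.allclose(x.grad, torch.full_like(x, c))  # gradient scaled by c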
@@ -1439,6 +1445,11 @@ class Attention(Module):

  attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0))

+ if self.laser:
+     values_max = v.amax(dim = -2, keepdim = True).detach() # numerical stability
+     v = v - values_max
+     v = v.exp()
+
  # attention is all we need

  out, intermediates = self.attend(
@@ -1448,6 +1459,11 @@ class Attention(Module):
      prev_attn = prev_attn
  )

+ # laser
+
+ if self.laser:
+     out = out.log() + values_max
+
  # store the values for resformer or Neutreno

  intermediates.values = orig_values
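Taken together, the two laser blocks compute attention over exponentiated values and return the result in log space, using the usual max-shift for stability: log(attn @ exp(v)) == log(attn @ exp(v - m)) + m, where m is the per-feature max of v over the sequence axis. A standalone sketch of that identity (not package code):

    import torch

    b, h, n, d = 1, 2, 8, 4
    attn = torch.softmax(torch.randn(b, h, n, n), dim = -1)
    v = torch.randn(b, h, n, d) * 10  # large enough to make a naive exp risky

    values_max = v.amax(dim = -2, keepdim = True)

    # shifted path, as in the diff: exp is kept in range, shift restored after log
    stable = (attn @ (v - values_max).exp()).log() + values_max

    # unshifted path: overflows once v exceeds ~88 in float32
    naive = (attn @ v.exp()).log()

    assert torch.allclose(stable, naive, atol = 1e-4)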
x_transformers-1.42.20.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: x-transformers
- Version: 1.42.18
+ Version: 1.42.20
  Summary: X-Transformers - Pytorch
  Home-page: https://github.com/lucidrains/x-transformers
  Author: Phil Wang
x_transformers-1.42.20.dist-info/RECORD CHANGED
@@ -1,16 +1,16 @@
  x_transformers/__init__.py,sha256=l0dom8ZYkRzFvnDdgzDboXqrI1tKav3beVE7TN2nHko,844
- x_transformers/attend.py,sha256=SdWlV8Vp5DtpsOzAd0LRhm4VGrJf0lJCGiV2_j_CtoA,17284
+ x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
  x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
  x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
- x_transformers/x_transformers.py,sha256=pFVTmAoAbrir7YjTwzC3X2buRSm7PFnWqYyTYePA8Es,95486
+ x_transformers/x_transformers.py,sha256=pDYtIGhoo-lFn_ULJETnQz1Z0QYuDsD4ReTlPy__jwo,95993
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
- x_transformers-1.42.18.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
- x_transformers-1.42.18.dist-info/METADATA,sha256=v9YlgCULHqvWhTC3bViadNngzfiyYkzrQa6XRZ0uDa4,739
- x_transformers-1.42.18.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
- x_transformers-1.42.18.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
- x_transformers-1.42.18.dist-info/RECORD,,
+ x_transformers-1.42.20.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+ x_transformers-1.42.20.dist-info/METADATA,sha256=J0yBEg7oUfbkJaC3WxfB9Oq4XbGxXA5VjUGd9AHELGk,739
+ x_transformers-1.42.20.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+ x_transformers-1.42.20.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+ x_transformers-1.42.20.dist-info/RECORD,,
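The hashes above follow the wheel RECORD convention (PEP 376/427): urlsafe base64 of the sha256 digest with trailing '=' padding stripped. A short sketch for checking an entry against an unpacked wheel locally:

    import base64, hashlib

    def record_hash(path):
        # urlsafe base64 of the sha256 digest, padding stripped (PEP 427)
        with open(path, 'rb') as f:
            digest = hashlib.sha256(f.read()).digest()
        return base64.urlsafe_b64encode(digest).rstrip(b'=').decode()

    # for the 1.42.20 wheel, record_hash('x_transformers/attend.py') should
    # match the RECORD entry above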