x-transformers 1.42.18__py3-none-any.whl → 1.42.20__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- x_transformers/attend.py +1 -1
- x_transformers/x_transformers.py +16 -0
- {x_transformers-1.42.18.dist-info → x_transformers-1.42.20.dist-info}/METADATA +1 -1
- {x_transformers-1.42.18.dist-info → x_transformers-1.42.20.dist-info}/RECORD +7 -7
- {x_transformers-1.42.18.dist-info → x_transformers-1.42.20.dist-info}/LICENSE +0 -0
- {x_transformers-1.42.18.dist-info → x_transformers-1.42.20.dist-info}/WHEEL +0 -0
- {x_transformers-1.42.18.dist-info → x_transformers-1.42.20.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -370,7 +370,7 @@ class Attend(Module):
|
|
370
370
|
# convert from bool to float
|
371
371
|
|
372
372
|
if exists(attn_bias):
|
373
|
-
attn_bias =
|
373
|
+
attn_bias = attn_bias.expand(batch, heads, -1, -1)
|
374
374
|
|
375
375
|
# if mask given, the mask would already contain the causal mask from above logic
|
376
376
|
# otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
|
x_transformers/x_transformers.py
CHANGED
@@ -1075,6 +1075,7 @@ class Attention(Module):
|
|
1075
1075
|
neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
|
1076
1076
|
neutreno_alpha = 0.4,
|
1077
1077
|
learned_value_residual_mix = False,
|
1078
|
+
laser = False, # https://arxiv.org/abs/2411.03493v1
|
1078
1079
|
onnxable = False,
|
1079
1080
|
attend_sdp_kwargs: dict = dict(
|
1080
1081
|
enable_flash = True,
|
@@ -1114,6 +1115,11 @@ class Attention(Module):
|
|
1114
1115
|
assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
|
1115
1116
|
self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None
|
1116
1117
|
|
1118
|
+
# enhancing gradients to attention through exponentiated values
|
1119
|
+
# todo - compare it to `attn = attn * large_value + attn.detach() * (1. - large_value)`
|
1120
|
+
|
1121
|
+
self.laser = laser
|
1122
|
+
|
1117
1123
|
# relations projection from tp-attention
|
1118
1124
|
|
1119
1125
|
self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None
|
@@ -1439,6 +1445,11 @@ class Attention(Module):
|
|
1439
1445
|
|
1440
1446
|
attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0))
|
1441
1447
|
|
1448
|
+
if self.laser:
|
1449
|
+
values_max = v.amax(dim = -2, keepdim = True).detach() # numerical stability
|
1450
|
+
v = v - values_max
|
1451
|
+
v = v.exp()
|
1452
|
+
|
1442
1453
|
# attention is all we need
|
1443
1454
|
|
1444
1455
|
out, intermediates = self.attend(
|
@@ -1448,6 +1459,11 @@ class Attention(Module):
|
|
1448
1459
|
prev_attn = prev_attn
|
1449
1460
|
)
|
1450
1461
|
|
1462
|
+
# laser
|
1463
|
+
|
1464
|
+
if self.laser:
|
1465
|
+
out = out.log() + values_max
|
1466
|
+
|
1451
1467
|
# store the values for resformer or Neutreno
|
1452
1468
|
|
1453
1469
|
intermediates.values = orig_values
|
@@ -1,16 +1,16 @@
|
|
1
1
|
x_transformers/__init__.py,sha256=l0dom8ZYkRzFvnDdgzDboXqrI1tKav3beVE7TN2nHko,844
|
2
|
-
x_transformers/attend.py,sha256
|
2
|
+
x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
|
3
3
|
x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
|
4
4
|
x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
|
5
5
|
x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
6
6
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
7
7
|
x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
|
8
8
|
x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
|
9
|
-
x_transformers/x_transformers.py,sha256=
|
9
|
+
x_transformers/x_transformers.py,sha256=pDYtIGhoo-lFn_ULJETnQz1Z0QYuDsD4ReTlPy__jwo,95993
|
10
10
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
11
11
|
x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
|
12
|
-
x_transformers-1.42.
|
13
|
-
x_transformers-1.42.
|
14
|
-
x_transformers-1.42.
|
15
|
-
x_transformers-1.42.
|
16
|
-
x_transformers-1.42.
|
12
|
+
x_transformers-1.42.20.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
13
|
+
x_transformers-1.42.20.dist-info/METADATA,sha256=J0yBEg7oUfbkJaC3WxfB9Oq4XbGxXA5VjUGd9AHELGk,739
|
14
|
+
x_transformers-1.42.20.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
15
|
+
x_transformers-1.42.20.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
16
|
+
x_transformers-1.42.20.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|