x-transformers 1.42.18__py3-none-any.whl → 1.42.20__py3-none-any.whl
- x_transformers/attend.py +1 -1
- x_transformers/x_transformers.py +16 -0
- {x_transformers-1.42.18.dist-info → x_transformers-1.42.20.dist-info}/METADATA +1 -1
- {x_transformers-1.42.18.dist-info → x_transformers-1.42.20.dist-info}/RECORD +7 -7
- {x_transformers-1.42.18.dist-info → x_transformers-1.42.20.dist-info}/LICENSE +0 -0
- {x_transformers-1.42.18.dist-info → x_transformers-1.42.20.dist-info}/WHEEL +0 -0
- {x_transformers-1.42.18.dist-info → x_transformers-1.42.20.dist-info}/top_level.txt +0 -0
x_transformers/attend.py
CHANGED
@@ -370,7 +370,7 @@ class Attend(Module):
         # convert from bool to float

         if exists(attn_bias):
-            attn_bias =
+            attn_bias = attn_bias.expand(batch, heads, -1, -1)

         # if mask given, the mask would already contain the causal mask from above logic
         # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number
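The attend.py fix broadcasts the attention bias across the batch and head dimensions before it is combined with the mask. A minimal sketch of what that expand does, with hypothetical shapes chosen only for illustration (not taken from the library):

import torch

# hypothetical shapes for illustration only
batch, heads, seq = 2, 8, 16

# a per-head relative position bias, e.g. shaped (heads, seq, seq)
attn_bias = torch.randn(heads, seq, seq)

# expand prepends the batch dimension and broadcasts, returning a view (no copy)
attn_bias = attn_bias.expand(batch, heads, -1, -1)
assert attn_bias.shape == (batch, heads, seq, seq)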
x_transformers/x_transformers.py
CHANGED
@@ -1075,6 +1075,7 @@ class Attention(Module):
         neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
         neutreno_alpha = 0.4,
         learned_value_residual_mix = False,
+        laser = False, # https://arxiv.org/abs/2411.03493v1
         onnxable = False,
         attend_sdp_kwargs: dict = dict(
             enable_flash = True,
@@ -1114,6 +1115,11 @@ class Attention(Module):
         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
         self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None

+        # enhancing gradients to attention through exponentiated values
+        # todo - compare it to `attn = attn * large_value + attn.detach() * (1. - large_value)`
+
+        self.laser = laser
+
         # relations projection from tp-attention

         self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None
@@ -1439,6 +1445,11 @@ class Attention(Module):

             attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0))

+        if self.laser:
+            values_max = v.amax(dim = -2, keepdim = True).detach() # numerical stability
+            v = v - values_max
+            v = v.exp()
+
         # attention is all we need

         out, intermediates = self.attend(
@@ -1448,6 +1459,11 @@ class Attention(Module):
             prev_attn = prev_attn
         )

+        # laser
+
+        if self.laser:
+            out = out.log() + values_max
+
         # store the values for resformer or Neutreno

         intermediates.values = orig_values
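The new laser flag applies the value-exponentiation trick from the LASER paper (https://arxiv.org/abs/2411.03493): values are exponentiated before attention, shifted by their max for numerical stability, and the log is taken afterwards, so the output is the log of an attention-weighted sum of exponentiated values. Below is a minimal standalone sketch of the same idea around PyTorch's scaled_dot_product_attention; the function name and shapes are illustrative, not the library's API:

import torch
import torch.nn.functional as F

def laser_attention(q, k, v):
    # shift by the per-head max over the sequence dimension for numerical stability
    values_max = v.amax(dim = -2, keepdim = True).detach()
    # exponentiate the shifted values before attention
    v_exp = (v - values_max).exp()
    # standard softmax attention over the exponentiated values
    out = F.scaled_dot_product_attention(q, k, v_exp)
    # undo the exponentiation: the weighted sum is positive, so log is safe; add the shift back
    return out.log() + values_max

# hypothetical (batch, heads, seq, dim_head) shapes
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
out = laser_attention(q, k, v)
assert out.shape == v.shape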
{x_transformers-1.42.18.dist-info → x_transformers-1.42.20.dist-info}/RECORD
CHANGED
@@ -1,16 +1,16 @@
 x_transformers/__init__.py,sha256=l0dom8ZYkRzFvnDdgzDboXqrI1tKav3beVE7TN2nHko,844
-x_transformers/attend.py,sha256
+x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/continuous.py,sha256=p0sCAiH1na236ygwgL1Yyhu36eZBf9cZvoW1JyP_fFE,7073
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=pDYtIGhoo-lFn_ULJETnQz1Z0QYuDsD4ReTlPy__jwo,95993
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
-x_transformers-1.42.
+x_transformers-1.42.20.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.20.dist-info/METADATA,sha256=J0yBEg7oUfbkJaC3WxfB9Oq4XbGxXA5VjUGd9AHELGk,739
+x_transformers-1.42.20.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+x_transformers-1.42.20.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.20.dist-info/RECORD,,
{x_transformers-1.42.18.dist-info → x_transformers-1.42.20.dist-info}/LICENSE
File without changes
{x_transformers-1.42.18.dist-info → x_transformers-1.42.20.dist-info}/WHEEL
File without changes
{x_transformers-1.42.18.dist-info → x_transformers-1.42.20.dist-info}/top_level.txt
File without changes