x-transformers 1.42.17 → 1.42.19 (py3-none-any.whl)

x_transformers/x_transformers.py

@@ -20,6 +20,8 @@ import einx
 from einops.layers.torch import Rearrange
 from einops import rearrange, repeat, reduce, pack, unpack
 
+from loguru import logger
+
 from x_transformers.attend import Attend, Intermediates
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
 
@@ -1073,6 +1075,7 @@ class Attention(Module):
         neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
         neutreno_alpha = 0.4,
         learned_value_residual_mix = False,
+        laser = False, # https://arxiv.org/abs/2411.03493v1
         onnxable = False,
         attend_sdp_kwargs: dict = dict(
             enable_flash = True,
@@ -1112,6 +1115,11 @@ class Attention(Module):
         assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values'
         self.to_v = LinearNoBias(dim_kv, v_dim) if not shared_kv else None
 
+        # enhancing gradients to attention through exponentiated values
+        # todo - compare it to `attn = attn * large_value + attn.detach() * (1. - large_value)`
+
+        self.laser = laser
+
         # relations projection from tp-attention
 
         self.to_r = LinearNoBias(dim, v_dim) if tensor_product else None
@@ -1437,6 +1445,11 @@ class Attention(Module):
 
                 attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0))
 
+        if self.laser:
+            values_max = v.amax(dim = -2, keepdim = True).detach() # numerical stability
+            v = v - values_max
+            v = v.exp()
+
         # attention is all we need
 
         out, intermediates = self.attend(
@@ -1446,6 +1459,11 @@ class Attention(Module):
             prev_attn = prev_attn
         )
 
+        # laser
+
+        if self.laser:
+            out = out.log() + values_max
+
         # store the values for resformer or Neutreno
 
         intermediates.values = orig_values
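
Taken together, the two hunks above implement the LASER value transform: attention is computed over exponentiated values and the result is mapped back with a log, using a detached per-sequence max of the values for numerical stability. Below is a minimal standalone sketch of that pattern for plain single-head scaled-dot-product attention; the function name laser_attention and the simplified shapes are illustrative only and are not part of the library.

    import torch
    import torch.nn.functional as F

    def laser_attention(q, k, v):
        # illustrative sketch, not library code: single-head attention with the LASER value transform

        # subtract the (detached) max over the sequence dimension of the values for numerical stability
        values_max = v.amax(dim = -2, keepdim = True).detach()

        # standard softmax attention, but applied to exponentiated, max-shifted values
        scale = q.shape[-1] ** -0.5
        attn = (q @ k.transpose(-2, -1) * scale).softmax(dim = -1)
        out = attn @ (v - values_max).exp()

        # return to log space and add the max back, recovering log(attn @ exp(v))
        return out.log() + values_max

    # example usage with tensors of shape (batch, seq, dim_head)
    # out = laser_attention(torch.randn(2, 8, 64), torch.randn(2, 8, 64), torch.randn(2, 8, 64))

Since the softmax weights are non-negative and sum to one, the output is a log of a convex combination of exponentiated values, which is why the log at the end is always well defined.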
@@ -1580,7 +1598,12 @@ class AttentionLayers(Module):
 
         self.disable_abs_pos_emb = default(disable_abs_pos_emb, (rel_pos_bias or rotary_pos_emb))
 
-        rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32)
+        rotary_emb_dim = default(rotary_emb_dim, dim_head // 2)
+
+        assert rotary_emb_dim <= dim_head, f'rotary emb dim {rotary_emb_dim} must be less than or equal to attention head dimension {dim_head}'
+
+        if rotary_emb_dim < 32:
+            logger.warning('when training language model, rotary embedding dimension should be at least 32')
 
         assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention'
         self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None
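
The practical effect of this hunk: previously a small head dimension was silently clamped so the rotary dimension never went below 32; now the default is simply half the head dimension, with a hard check that it does not exceed dim_head and a logged warning when it falls below 32. A small plain-Python sketch of the before/after behavior, assuming dim_head = 32 (values derived from the diff, not from running the library; print stands in for the logger call):

    dim_head = 32

    old_rotary_emb_dim = max(dim_head // 2, 32)  # 1.42.17: 32, silently clamped up
    new_rotary_emb_dim = dim_head // 2           # 1.42.19: 16, kept as requested

    assert new_rotary_emb_dim <= dim_head        # new hard requirement
    if new_rotary_emb_dim < 32:
        print('warning: rotary embedding dimension should be at least 32 when training a language model')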
x_transformers-1.42.19.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.42.17
+Version: 1.42.19
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -14,7 +14,8 @@ Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python :: 3.6
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: torch>=2.0
 Requires-Dist: einx>=0.3.0
 Requires-Dist: einops>=0.8.0
+Requires-Dist: loguru
 Requires-Dist: packaging>=21.0
+Requires-Dist: torch>=2.0
x_transformers-1.42.19.dist-info/RECORD

@@ -6,11 +6,11 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=Wvkw4j_78413LdCnCt_wHgcVFiCbzrC8u4TH2iXhkNU,95181
+x_transformers/x_transformers.py,sha256=pDYtIGhoo-lFn_ULJETnQz1Z0QYuDsD4ReTlPy__jwo,95993
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.42.17.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.42.17.dist-info/METADATA,sha256=T1MDXNdxqdPkqFpGuJVb7vBhniGCbHefm5C-lhb3LJk,717
-x_transformers-1.42.17.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
-x_transformers-1.42.17.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.42.17.dist-info/RECORD,,
+x_transformers-1.42.19.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.42.19.dist-info/METADATA,sha256=pJgi1Jp7FvM1o_x3a7uOaSJ8x0pNgIQnAp4lSI3K__o,739
+x_transformers-1.42.19.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+x_transformers-1.42.19.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.42.19.dist-info/RECORD,,