x-transformers: x_transformers-2.4.11-py3-none-any.whl → x_transformers-2.4.14-py3-none-any.whl

This diff shows the differences between two publicly available versions of a package, each of which has been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their respective public registries.
@@ -188,7 +188,7 @@ class AutoregressiveWrapper(Module):
188
188
  temperature = 1.,
189
189
  stochastic = False,
190
190
  prompt_lens: Tensor | None = None,
191
- filter_logits_fn: str | Callable = top_k,
191
+ filter_logits_fn: str | Callable = identity,
192
192
  restrict_to_max_seq_len = True,
193
193
  filter_kwargs: dict = dict(),
194
194
  cache_kv = True,
@@ -1304,6 +1304,7 @@ class Attention(Module):
1304
1304
  qk_norm_groups = 1,
1305
1305
  qk_norm_scale = 10,
1306
1306
  qk_norm_dim_scale = False,
1307
+ value_rmsnorm = False, # used in alphagenome and bytedance's GR3 for further stability
1307
1308
  l2_distance = False,
1308
1309
  sigmoid = False,
1309
1310
  selective = False,
@@ -1458,6 +1459,10 @@ class Attention(Module):
1458
1459
  assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
1459
1460
  assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'
1460
1461
 
1462
+ # value rms norm
1463
+
1464
+ self.value_rmsnorm = MultiheadRMSNorm(dim_head, heads = heads) if value_rmsnorm else None
1465
+
1461
1466
  # contextual positional encoding
1462
1467
  # https://arxiv.org/html/2405.18719v2
1463
1468
 
@@ -1697,6 +1702,10 @@ class Attention(Module):
1697
1702
  q = q * self.qk_norm_q_scale
1698
1703
  k = k * self.qk_norm_k_scale
1699
1704
 
1705
+ # maybe value rmsnorm
1706
+
1707
+ v = maybe(self.value_rmsnorm)(v)
1708
+
1700
1709
  # take care of caching
1701
1710
 
1702
1711
  if not is_multi_latent_attn:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.4.11
3
+ Version: 2.4.14
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -1,6 +1,6 @@
1
1
  x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,985
2
2
  x_transformers/attend.py,sha256=Ax34Rw56xXAWfFPqtZ_B8iKEW2EfQdbVoc9uFjfeNjA,17404
3
- x_transformers/autoregressive_wrapper.py,sha256=n8ueNBMvIjO4B1J7VvSyDzJvqUi9YmCrri1p44n-FTY,17831
3
+ x_transformers/autoregressive_wrapper.py,sha256=y798kS9_VvPOY_5Ilits_64aXNqYvGuilsky1y07ryE,17834
4
4
  x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
5
5
  x_transformers/continuous.py,sha256=hpb1sSbt3k2LNzzjrjSd8F5xOIbKj7IluV9MBEAFLkw,13031
6
6
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
@@ -9,10 +9,10 @@ x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg
9
9
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
10
10
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
11
11
  x_transformers/up_wrapper.py,sha256=YC2LN14_7Xx9Wtiek2rtEJ_qHqdfSmKlh3d7Cgxwd80,7073
12
- x_transformers/x_transformers.py,sha256=9Fi0HvzpeIJqM6HlAd2M6JqsfjhTN1zEH9iFIimyjS4,117608
12
+ x_transformers/x_transformers.py,sha256=bNp6hWuuqn7x5yKFfYocvu3X1YCjpfwrWMh-kAanS48,117906
13
13
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
14
14
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
15
- x_transformers-2.4.11.dist-info/METADATA,sha256=N0EjJyBQ_2EjiQRJK-Rlvt7lzCn8XqFWXFiUyqUDwU8,90224
16
- x_transformers-2.4.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
- x_transformers-2.4.11.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
18
- x_transformers-2.4.11.dist-info/RECORD,,
15
+ x_transformers-2.4.14.dist-info/METADATA,sha256=KScRZIcmRXCv8NnhzQ3Uzo9uHE2oI51chzj78Wh_OVo,90224
16
+ x_transformers-2.4.14.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
17
+ x_transformers-2.4.14.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
18
+ x_transformers-2.4.14.dist-info/RECORD,,