x_transformers-2.0.1-py3-none-any.whl → x_transformers-2.0.2-py3-none-any.whl

This diff shows the changes between publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
x_transformers/x_transformers.py

@@ -1204,6 +1204,7 @@ class Attention(Module):
         hybrid_module: Module | None = None,
         hybrid_mask_kwarg: str | None = None,
         hybrid_fold_axial_dim: int | None = None,
+        hybrid_learned_mix = False,
         one_kv_head = False,
         kv_heads = None,
         value_dim_head = None,
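The new hybrid_learned_mix flag is reached through the attn_-prefixed kwargs that Decoder forwards to Attention. A minimal usage sketch, assuming the GRU-based hybrid setup shown in the project README; the GRU output size of dim_head * heads and all hyperparameters here are illustrative:

import torch
from torch.nn import GRU
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        attn_dim_head = 64,
        attn_hybrid_module = GRU(512, 64 * 8, batch_first = True),  # hybrid branch run alongside attention
        attn_hybrid_learned_mix = True  # new in 2.0.2; defaults to False
    )
)

tokens = torch.randint(0, 256, (1, 1024))
logits = model(tokens)  # (1, 1024, 256)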
@@ -1446,7 +1447,7 @@ class Attention(Module):
 
         if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
             hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
-        hybrid_mix = LinearNoBias(dim, heads)
+        hybrid_mix = LinearNoBias(dim, heads) if hybrid_learned_mix else None
 
         hybrid_norms = ModuleList([
             MultiheadRMSNorm(dim_head, heads = heads),
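Note that the mixing projection is now only instantiated when hybrid_learned_mix is requested; with the default of False, hybrid_mix is None, no extra parameters are added, and the forward pass below falls through to the previous fixed 0.5 average.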
@@ -1779,7 +1780,12 @@ class Attention(Module):
             out = out_norm(out)
             hybrid_out = hybrid_out_norm(hybrid_out)
 
-            out = 0.5 * (out + hybrid_out)
+            if exists(self.hybrid_mix):
+                mix = self.hybrid_mix(x)
+                mix = rearrange(mix, 'b n h -> b h n 1')
+                out = out.lerp(hybrid_out, mix.sigmoid())
+            else:
+                out = 0.5 * (out + hybrid_out)
 
         # merge heads
 
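This replacement turns the fixed average into a per-head, per-position gate: the layer input x is projected to one logit per head, passed through a sigmoid, and used as the lerp weight between the attention branch and the hybrid branch. A self-contained sketch of just that math, with all tensor names and shapes illustrative:

import torch
from torch import nn
from einops import rearrange

b, h, n, d, dim = 2, 8, 16, 64, 512    # batch, heads, seq len, head dim, model dim

x          = torch.randn(b, n, dim)    # layer input, used to predict the gate
out        = torch.randn(b, h, n, d)   # attention branch output
hybrid_out = torch.randn(b, h, n, d)   # hybrid module branch output

hybrid_mix = nn.Linear(dim, h, bias = False)   # stand-in for LinearNoBias(dim, heads)

mix = hybrid_mix(x)                            # (b, n, h) gate logits
mix = rearrange(mix, 'b n h -> b h n 1')       # broadcast over the head dimension
gated = out.lerp(hybrid_out, mix.sigmoid())    # out + sigmoid(mix) * (hybrid_out - out)

# a zero logit gives sigmoid(0) = 0.5, recovering the old fixed average
assert torch.allclose(gated, out + mix.sigmoid() * (hybrid_out - out), atol = 1e-6)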
x_transformers-2.0.1.dist-info/METADATA → x_transformers-2.0.2.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: x-transformers
-Version: 2.0.1
+Version: 2.0.2
 Summary: X-Transformers
 Project-URL: Homepage, https://pypi.org/project/x-transformers/
 Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -40,6 +40,7 @@ Requires-Dist: loguru
 Requires-Dist: packaging>=21.0
 Requires-Dist: torch>=2.0
 Provides-Extra: examples
+Requires-Dist: lion-pytorch; extra == 'examples'
 Requires-Dist: torchvision; extra == 'examples'
 Requires-Dist: tqdm; extra == 'examples'
 Provides-Extra: test
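With this added requirement, installing the examples extra (pip install 'x-transformers[examples]') now pulls in lion-pytorch alongside torchvision and tqdm.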
x_transformers-2.0.1.dist-info/RECORD → x_transformers-2.0.2.dist-info/RECORD

@@ -6,10 +6,10 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=e-iIpcUJzDKU-_s4CVP9f7HoM2c11RH35-TH4G8pAYg,107251
+x_transformers/x_transformers.py,sha256=1s8KCSfHXMN9TKLFdS-RzzCskBDkh4CuBk2_XRb6IXk,107537
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.0.1.dist-info/METADATA,sha256=wygqpzU7ne5w7wj1GhPUEEuVROu0qIP0eDpjfaxlBe0,86457
-x_transformers-2.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-x_transformers-2.0.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-2.0.1.dist-info/RECORD,,
+x_transformers-2.0.2.dist-info/METADATA,sha256=tNdI3H2S4HnnGK1hPY3l94FoXH3SB9vGAb55pcah6Yw,86506
+x_transformers-2.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.0.2.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.0.2.dist-info/RECORD,,