x-transformers 2.0.1__py3-none-any.whl → 2.0.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1204,6 +1204,7 @@ class Attention(Module):
1204
1204
  hybrid_module: Module | None = None,
1205
1205
  hybrid_mask_kwarg: str | None = None,
1206
1206
  hybrid_fold_axial_dim: int | None = None,
1207
+ hybrid_learned_mix = False,
1207
1208
  one_kv_head = False,
1208
1209
  kv_heads = None,
1209
1210
  value_dim_head = None,
@@ -1446,7 +1447,7 @@ class Attention(Module):
1446
1447
 
1447
1448
  if exists(hybrid_module) and exists(hybrid_fold_axial_dim):
1448
1449
  hybrid_module = FoldAxially(axial_dim = hybrid_fold_axial_dim, fn = hybrid_module)
1449
- hybrid_mix = LinearNoBias(dim, heads)
1450
+ hybrid_mix = LinearNoBias(dim, heads) if hybrid_learned_mix else None
1450
1451
 
1451
1452
  hybrid_norms = ModuleList([
1452
1453
  MultiheadRMSNorm(dim_head, heads = heads),
@@ -1779,7 +1780,12 @@ class Attention(Module):
1779
1780
  out = out_norm(out)
1780
1781
  hybrid_out = hybrid_out_norm(hybrid_out)
1781
1782
 
1782
- out = 0.5 * (out + hybrid_out)
1783
+ if exists(self.hybrid_mix):
1784
+ mix = self.hybrid_mix(x)
1785
+ mix = rearrange(mix, 'b n h -> b h n 1')
1786
+ out = out.lerp(hybrid_out, mix.sigmoid())
1787
+ else:
1788
+ out = 0.5 * (out + hybrid_out)
1783
1789
 
1784
1790
  # merge heads
1785
1791
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.0.1
3
+ Version: 2.0.2
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -40,6 +40,7 @@ Requires-Dist: loguru
40
40
  Requires-Dist: packaging>=21.0
41
41
  Requires-Dist: torch>=2.0
42
42
  Provides-Extra: examples
43
+ Requires-Dist: lion-pytorch; extra == 'examples'
43
44
  Requires-Dist: torchvision; extra == 'examples'
44
45
  Requires-Dist: tqdm; extra == 'examples'
45
46
  Provides-Extra: test
@@ -6,10 +6,10 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
6
6
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
7
7
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
8
8
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
9
- x_transformers/x_transformers.py,sha256=e-iIpcUJzDKU-_s4CVP9f7HoM2c11RH35-TH4G8pAYg,107251
9
+ x_transformers/x_transformers.py,sha256=1s8KCSfHXMN9TKLFdS-RzzCskBDkh4CuBk2_XRb6IXk,107537
10
10
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
11
11
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
12
- x_transformers-2.0.1.dist-info/METADATA,sha256=wygqpzU7ne5w7wj1GhPUEEuVROu0qIP0eDpjfaxlBe0,86457
13
- x_transformers-2.0.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
14
- x_transformers-2.0.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
15
- x_transformers-2.0.1.dist-info/RECORD,,
12
+ x_transformers-2.0.2.dist-info/METADATA,sha256=tNdI3H2S4HnnGK1hPY3l94FoXH3SB9vGAb55pcah6Yw,86506
13
+ x_transformers-2.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
14
+ x_transformers-2.0.2.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
15
+ x_transformers-2.0.2.dist-info/RECORD,,