x-transformers 1.31.11__py3-none-any.whl → 1.31.14__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package and is provided for informational purposes only; it reflects the packages exactly as they appear in their public registry.
x_transformers/x_transformers.py

@@ -597,6 +597,9 @@ class AdaptiveLayerNorm(Module):
         nn.init.zeros_(self.to_gamma.weight)
 
     def forward(self, x, *, condition):
+        if condition.ndim == 2:
+            condition = rearrange(condition, 'b d -> b 1 d')
+
         normed = self.ln(x)
         gamma = self.to_gamma(condition)
         return normed * (gamma + 1.)
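This change (repeated verbatim in AdaptiveRMSNorm and AdaptiveLayerScale below) lets callers pass a per-sample condition of shape (batch, dim) instead of requiring (batch, seq, dim): the rearrange inserts a singleton sequence axis so the projected gamma broadcasts across every position. A minimal standalone sketch of the broadcasting, with illustrative sizes (the zero-initialized projection mirrors the class above, but this is not the library's code):

import torch
import torch.nn.functional as F
from einops import rearrange

batch, seq, dim, dim_condition = 2, 16, 64, 32
x = torch.randn(batch, seq, dim)
condition = torch.randn(batch, dim_condition)      # one condition vector per sample

if condition.ndim == 2:
    # (b, d) -> (b, 1, d): the singleton axis broadcasts over all seq positions
    condition = rearrange(condition, 'b d -> b 1 d')

to_gamma = torch.nn.Linear(dim_condition, dim, bias = False)
torch.nn.init.zeros_(to_gamma.weight)              # gamma starts at 0, so the modulation starts as identity

normed = F.layer_norm(x, (dim,))                   # stand-in for self.ln(x)
out = normed * (to_gamma(condition) + 1.)          # (b, s, d) * (b, 1, d) broadcasts cleanly
assert out.shape == (batch, seq, dim)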
@@ -649,6 +652,9 @@ class AdaptiveRMSNorm(Module):
         nn.init.zeros_(self.to_gamma.weight)
 
     def forward(self, x, *, condition):
+        if condition.ndim == 2:
+            condition = rearrange(condition, 'b d -> b 1 d')
+
         normed = F.normalize(x, dim = -1)
         gamma = self.to_gamma(condition)
         return normed * self.scale * (gamma + 1.)
@@ -775,6 +781,9 @@ class AdaptiveLayerScale(Module):
         nn.init.constant_(self.to_gamma.bias, init_bias_value)
 
     def forward(self, x, *, condition, **kwargs):
+        if condition.ndim == 2:
+            condition = rearrange(condition, 'b d -> b 1 d')
+
         out = self.fn(x, **kwargs)
         gamma = self.to_gamma(condition).sigmoid()
 
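AdaptiveLayerScale differs from the two norms in that it wraps an inner block and gates its output with a sigmoid, with the projection bias initialized to a constant negative init_bias_value so the gate starts nearly closed (in the spirit of adaLN-Zero). A rough self-contained sketch of that pattern; the init_bias_value of -2 is an assumed illustration, since the actual default is not shown in this diff:

import torch
from torch import nn

class GatedBlockSketch(nn.Module):
    # hypothetical stand-in for AdaptiveLayerScale, for illustration only
    def __init__(self, fn, dim, dim_condition, init_bias_value = -2.):
        super().__init__()
        self.fn = fn
        self.to_gamma = nn.Linear(dim_condition, dim)
        nn.init.zeros_(self.to_gamma.weight)
        nn.init.constant_(self.to_gamma.bias, init_bias_value)  # sigmoid(-2) ~ 0.12: gate starts mostly closed

    def forward(self, x, *, condition, **kwargs):
        if condition.ndim == 2:
            condition = condition.unsqueeze(1)      # same 2D-condition guard as in the diff
        out = self.fn(x, **kwargs)
        gamma = self.to_gamma(condition).sigmoid()  # per-feature gate in (0, 1)
        return out * gamma

block = GatedBlockSketch(nn.Linear(64, 64), dim = 64, dim_condition = 32)
y = block(torch.randn(2, 16, 64), condition = torch.randn(2, 32))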
@@ -889,7 +898,7 @@ class Attention(Module):
         cope_soft_onehot_pos = False,
         cope_talking_heads = False,
         softclamp_logits = False,
-        logit_softclamp_value = 30.,
+        logit_softclamp_value = 50.,
         onnxable = False
     ):
         super().__init__()
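The remaining code change bumps the default logit_softclamp_value from 30 to 50, which only takes effect when softclamp_logits = True. Soft-clamping squashes attention logits through a scaled tanh so they stay within (-value, value) without the dead gradients of a hard clamp; 50 matches the attention-logit soft-capping popularized by Gemma 2. A hedged sketch of the usual formulation (not copied from the library):

import torch

def softclamp(t, value = 50.):
    # smoothly bounds t to (-value, value); gradients shrink but never hit zero,
    # unlike t.clamp(-value, value), whose gradient is zero outside the range
    return (t / value).tanh() * value

sim = torch.randn(2, 8, 16, 16) * 100.    # (batch, heads, q_len, k_len) with exaggerated magnitudes
clamped = softclamp(sim)
assert clamped.abs().max() < 50.
attn = clamped.softmax(dim = -1)          # clamping is applied before the softmax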
x_transformers-1.31.14.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.31.11
+Version: 1.31.14
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.31.14.dist-info/RECORD

@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=qGZ67jBeynItbfgnKd5g2VNxUFSCpx9fy5A8zN6wMeg,76030
+x_transformers/x_transformers.py,sha256=1QG7zUe89h1R5VDMoKEAkvdRRDkzQ7h6npkqblxxR6g,76312
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.31.11.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.31.11.dist-info/METADATA,sha256=GPkfjCnpqy9vpEIuqxtqyfKUOKEO9Pf__rp77THGqok,662
-x_transformers-1.31.11.dist-info/WHEEL,sha256=Z4pYXqR_rTB7OWNDYFOm1qRk0RX6GFP2o8LgvP453Hk,91
-x_transformers-1.31.11.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.31.11.dist-info/RECORD,,
+x_transformers-1.31.14.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.31.14.dist-info/METADATA,sha256=Qj5yRxhBmF87HtYeWuFTiiYVZf-eDdXabhW_P5McQ7w,662
+x_transformers-1.31.14.dist-info/WHEEL,sha256=-oYQCr74JF3a37z2nRlQays_SX2MqOANoqVjBBAP2yE,91
+x_transformers-1.31.14.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.31.14.dist-info/RECORD,,
x_transformers-1.31.14.dist-info/WHEEL

@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (70.3.0)
+Generator: setuptools (71.0.3)
 Root-Is-Purelib: true
 Tag: py3-none-any
 