x-transformers 1.31.10__py3-none-any.whl → 1.31.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,6 +7,7 @@ from x_transformers.x_transformers import (
7
7
  Attention,
8
8
  FeedForward,
9
9
  RMSNorm,
10
+ AdaptiveRMSNorm,
10
11
  TransformerWrapper,
11
12
  ViTransformerWrapper
12
13
  )
@@ -597,6 +597,9 @@ class AdaptiveLayerNorm(Module):
597
597
  nn.init.zeros_(self.to_gamma.weight)
598
598
 
599
599
  def forward(self, x, *, condition):
600
+ if condition.ndim == 2:
601
+ condition = rearrange(condition, 'b d -> b 1 d')
602
+
600
603
  normed = self.ln(x)
601
604
  gamma = self.to_gamma(condition)
602
605
  return normed * (gamma + 1.)
@@ -649,6 +652,9 @@ class AdaptiveRMSNorm(Module):
649
652
  nn.init.zeros_(self.to_gamma.weight)
650
653
 
651
654
  def forward(self, x, *, condition):
655
+ if condition.ndim == 2:
656
+ condition = rearrange(condition, 'b d -> b 1 d')
657
+
652
658
  normed = F.normalize(x, dim = -1)
653
659
  gamma = self.to_gamma(condition)
654
660
  return normed * self.scale * (gamma + 1.)
@@ -775,6 +781,9 @@ class AdaptiveLayerScale(Module):
775
781
  nn.init.constant_(self.to_gamma.bias, init_bias_value)
776
782
 
777
783
  def forward(self, x, *, condition, **kwargs):
784
+ if condition.ndim == 2:
785
+ condition = rearrange(condition, 'b d -> b 1 d')
786
+
778
787
  out = self.fn(x, **kwargs)
779
788
  gamma = self.to_gamma(condition).sigmoid()
780
789
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.31.10
3
+ Version: 1.31.12
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -1,14 +1,14 @@
1
- x_transformers/__init__.py,sha256=8LQl-dNL6vj8VHRx5LMSOlRDTXQvYOuM21PDXz8WdiI,703
1
+ x_transformers/__init__.py,sha256=5ms39Df8osTUHQ-XTCgP4vSUA4UiNpim9VXJtrLrIvQ,724
2
2
  x_transformers/attend.py,sha256=oAS0vSy5qH7iTCXzHKfM4k7m_fvuZIR49PStZO8OFJo,11089
3
3
  x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3nPqO6xp2FFCc,9619
4
4
  x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
5
5
  x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
6
6
  x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
7
- x_transformers/x_transformers.py,sha256=qGZ67jBeynItbfgnKd5g2VNxUFSCpx9fy5A8zN6wMeg,76030
7
+ x_transformers/x_transformers.py,sha256=re8z4-kI3kukKSmUBZZJFjT0VvGvP6otksWFD3zw1kc,76312
8
8
  x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
9
9
  x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
10
- x_transformers-1.31.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
11
- x_transformers-1.31.10.dist-info/METADATA,sha256=U15CT3ilR-FAFQHrT5Gc92JWSZxgXIlGO1SLQFBdTAY,662
12
- x_transformers-1.31.10.dist-info/WHEEL,sha256=y4mX-SOX4fYIkonsAGA5N0Oy-8_gI4FXw5HNI1xqvWg,91
13
- x_transformers-1.31.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
14
- x_transformers-1.31.10.dist-info/RECORD,,
10
+ x_transformers-1.31.12.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
11
+ x_transformers-1.31.12.dist-info/METADATA,sha256=gqyQWsQrsE10Qg0ZU-gdnAa-FSpLAyhKgmiDaS-N6IQ,662
12
+ x_transformers-1.31.12.dist-info/WHEEL,sha256=Z4pYXqR_rTB7OWNDYFOm1qRk0RX6GFP2o8LgvP453Hk,91
13
+ x_transformers-1.31.12.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
14
+ x_transformers-1.31.12.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (70.2.0)
2
+ Generator: setuptools (70.3.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5