x-transformers 1.31.11__py3-none-any.whl → 1.31.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +10 -1
- {x_transformers-1.31.11.dist-info → x_transformers-1.31.14.dist-info}/METADATA +1 -1
- {x_transformers-1.31.11.dist-info → x_transformers-1.31.14.dist-info}/RECORD +6 -6
- {x_transformers-1.31.11.dist-info → x_transformers-1.31.14.dist-info}/WHEEL +1 -1
- {x_transformers-1.31.11.dist-info → x_transformers-1.31.14.dist-info}/LICENSE +0 -0
- {x_transformers-1.31.11.dist-info → x_transformers-1.31.14.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -597,6 +597,9 @@ class AdaptiveLayerNorm(Module):
         nn.init.zeros_(self.to_gamma.weight)

     def forward(self, x, *, condition):
+        if condition.ndim == 2:
+            condition = rearrange(condition, 'b d -> b 1 d')
+
         normed = self.ln(x)
         gamma = self.to_gamma(condition)
         return normed * (gamma + 1.)
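The three adaptive-conditioning modules gain the same guard: a condition passed as a per-sample vector of shape `(batch, dim_condition)` is now reshaped to `(batch, 1, dim_condition)`, so the projected gamma broadcasts across the sequence axis. A minimal, self-contained sketch of that broadcast (the shapes and the `to_gamma` projection here are illustrative, not taken from the package):

```python
import torch
from einops import rearrange

batch, seq, dim, dim_cond = 2, 5, 16, 8

x = torch.randn(batch, seq, dim)          # token representations
condition = torch.randn(batch, dim_cond)  # one condition vector per sample

# the reshape added by the diff: (b, d) -> (b, 1, d)
condition = rearrange(condition, 'b d -> b 1 d')

# a condition-derived gamma now broadcasts over the sequence axis
to_gamma = torch.nn.Linear(dim_cond, dim, bias = False)
gamma = to_gamma(condition)               # (batch, 1, dim)

normed = torch.nn.functional.layer_norm(x, (dim,))
out = normed * (gamma + 1.)               # -> (batch, seq, dim)
print(out.shape)                          # torch.Size([2, 5, 16])
```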
@@ -649,6 +652,9 @@ class AdaptiveRMSNorm(Module):
         nn.init.zeros_(self.to_gamma.weight)

     def forward(self, x, *, condition):
+        if condition.ndim == 2:
+            condition = rearrange(condition, 'b d -> b 1 d')
+
         normed = F.normalize(x, dim = -1)
         gamma = self.to_gamma(condition)
         return normed * self.scale * (gamma + 1.)
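`AdaptiveRMSNorm` receives the identical guard. Its forward pass, visible in the hunk, unit-normalizes each token vector and rescales it by a condition-derived gamma that is zero-initialized, so training starts from plain RMSNorm behaviour. A standalone sketch, assuming `self.scale` is the usual `dim ** 0.5` RMSNorm scale (the scale's definition is not shown in this diff):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def adaptive_rmsnorm(x, condition, to_gamma, scale):
    # accept a per-sample condition vector, as the new guard allows
    if condition.ndim == 2:
        condition = rearrange(condition, 'b d -> b 1 d')

    normed = F.normalize(x, dim = -1)      # unit-normalize each token vector
    gamma = to_gamma(condition)            # condition-dependent scale, zero-initialized
    return normed * scale * (gamma + 1.)   # gamma == 0 reduces to plain RMSNorm

dim, dim_cond = 16, 8
x = torch.randn(2, 5, dim)
cond = torch.randn(2, dim_cond)

to_gamma = torch.nn.Linear(dim_cond, dim, bias = False)
torch.nn.init.zeros_(to_gamma.weight)      # mirrors the zero init shown above

out = adaptive_rmsnorm(x, cond, to_gamma, scale = dim ** 0.5)
print(out.shape)  # torch.Size([2, 5, 16])
```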
@@ -775,6 +781,9 @@ class AdaptiveLayerScale(Module):
         nn.init.constant_(self.to_gamma.bias, init_bias_value)

     def forward(self, x, *, condition, **kwargs):
+        if condition.ndim == 2:
+            condition = rearrange(condition, 'b d -> b 1 d')
+
         out = self.fn(x, **kwargs)
         gamma = self.to_gamma(condition).sigmoid()

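`AdaptiveLayerScale` gets the same reshape before its sigmoid gate. The hunk shows the branch output `self.fn(x, **kwargs)` and the gate `self.to_gamma(condition).sigmoid()`; the return statement is outside the hunk, so the sketch below assumes the gate multiplies the branch output, and the `-2.` bias init is only an illustration of `init_bias_value`:

```python
import torch
from einops import rearrange

def adaptive_layer_scale(fn, x, condition, to_gamma, **kwargs):
    # the reshape added by the diff, so a (batch, dim_cond) condition works
    if condition.ndim == 2:
        condition = rearrange(condition, 'b d -> b 1 d')

    out = fn(x, **kwargs)
    gate = to_gamma(condition).sigmoid()   # per-dimension gate in (0, 1)
    return out * gate                      # assumption: the gate scales the branch output

dim, dim_cond = 16, 8
x = torch.randn(2, 5, dim)
cond = torch.randn(2, dim_cond)

to_gamma = torch.nn.Linear(dim_cond, dim)
torch.nn.init.constant_(to_gamma.bias, -2.)  # illustrative init_bias_value only

out = adaptive_layer_scale(torch.nn.SiLU(), x, cond, to_gamma)
print(out.shape)  # torch.Size([2, 5, 16])
```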
@@ -889,7 +898,7 @@ class Attention(Module):
         cope_soft_onehot_pos = False,
         cope_talking_heads = False,
         softclamp_logits = False,
-        logit_softclamp_value =
+        logit_softclamp_value = 50.,
         onnxable = False
     ):
         super().__init__()
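The other functional change gives `logit_softclamp_value` an explicit default of `50.` (the previous value is truncated in this diff). When `softclamp_logits = True`, attention logits are typically soft-clamped with a tanh squashing rather than hard clipping; a sketch of that standard formulation (not copied from the package):

```python
import torch

def softclamp(logits, value = 50.):
    # smoothly bound values to (-value, value) instead of hard clipping
    return value * torch.tanh(logits / value)

logits = torch.tensor([0., 10., 100., 1000.])
print(softclamp(logits))  # small logits pass through nearly unchanged; large ones saturate near 50
```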
{x_transformers-1.31.11.dist-info → x_transformers-1.31.14.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=1QG7zUe89h1R5VDMoKEAkvdRRDkzQ7h6npkqblxxR6g,76312
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.31.
-x_transformers-1.31.
-x_transformers-1.31.
-x_transformers-1.31.
-x_transformers-1.31.
+x_transformers-1.31.14.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.31.14.dist-info/METADATA,sha256=Qj5yRxhBmF87HtYeWuFTiiYVZf-eDdXabhW_P5McQ7w,662
+x_transformers-1.31.14.dist-info/WHEEL,sha256=-oYQCr74JF3a37z2nRlQays_SX2MqOANoqVjBBAP2yE,91
+x_transformers-1.31.14.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.31.14.dist-info/RECORD,,
{x_transformers-1.31.11.dist-info → x_transformers-1.31.14.dist-info}/LICENSE
File without changes
{x_transformers-1.31.11.dist-info → x_transformers-1.31.14.dist-info}/top_level.txt
File without changes