x-transformers 1.31.5__py3-none-any.whl → 1.31.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +9 -5
- {x_transformers-1.31.5.dist-info → x_transformers-1.31.7.dist-info}/METADATA +1 -1
- {x_transformers-1.31.5.dist-info → x_transformers-1.31.7.dist-info}/RECORD +6 -6
- {x_transformers-1.31.5.dist-info → x_transformers-1.31.7.dist-info}/WHEEL +1 -1
- {x_transformers-1.31.5.dist-info → x_transformers-1.31.7.dist-info}/LICENSE +0 -0
- {x_transformers-1.31.5.dist-info → x_transformers-1.31.7.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -580,7 +580,7 @@ class LayerNorm(Module):
 
     def forward(self, x):
         normed = self.ln(x)
-        gamma = self.gamma + self.unit_offset
+        gamma = self.gamma + float(self.unit_offset)
         return normed * gamma
 
 class AdaptiveLayerNorm(Module):
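The only change here is casting `self.unit_offset` (stored as a bool) to a float before adding it to the gamma parameter. The unit-offset pattern initializes gamma at zero and adds the offset back at forward time, so weight decay pulls the effective scale toward 1 (identity) rather than toward 0. A minimal self-contained sketch of the idea, with an illustrative class name rather than the library's exact code:

import torch
from torch import nn

class UnitOffsetLayerNorm(nn.Module):
    # sketch of the unit-offset pattern from the diff above,
    # not the exact x-transformers class
    def __init__(self, dim, unit_offset = False):
        super().__init__()
        self.unit_offset = unit_offset
        self.ln = nn.LayerNorm(dim, elementwise_affine = False)
        # with unit_offset = True the parameter starts at 0 and the
        # effective scale is (gamma + 1), so weight decay regresses
        # the layer toward identity instead of toward zero output
        self.gamma = nn.Parameter(torch.ones(dim) - float(unit_offset))

    def forward(self, x):
        normed = self.ln(x)
        # cast the bool flag to float before mixing it into a tensor op
        gamma = self.gamma + float(self.unit_offset)
        return normed * gamma

x = torch.randn(2, 8, 16)
print(UnitOffsetLayerNorm(16, unit_offset = True)(x).shape)  # torch.Size([2, 8, 16])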
@@ -615,7 +615,8 @@ class ScaleNorm(Module):
         nn.init.constant_(self.g, 1. - float(unit_offset))
 
     def forward(self, x):
-        return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
+        gamma = self.g + float(self.unit_offset)
+        return F.normalize(x, dim = -1) * self.scale * gamma
 
 class RMSNorm(Module):
     def __init__(
@@ -631,7 +632,8 @@ class RMSNorm(Module):
         nn.init.constant_(self.g, 1. - float(unit_offset))
 
     def forward(self, x):
-        return F.normalize(x, dim = -1) * self.scale * (self.g + self.unit_offset)
+        gamma = self.g + float(self.unit_offset)
+        return F.normalize(x, dim = -1) * self.scale * gamma
 
 class AdaptiveRMSNorm(Module):
     def __init__(
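ScaleNorm and RMSNorm receive the same treatment as LayerNorm: the one-line return is split so the float-cast unit offset is applied to `self.g` explicitly. In both classes the forward pass is `F.normalize(x, dim = -1) * self.scale * gamma`, and since `self.scale` is `dim ** 0.5` this is exactly RMS normalization: the L2 norm of a d-dimensional vector is sqrt(d) times its RMS. A quick numerical check of that identity:

import torch
import torch.nn.functional as F

# ||x||_2 = sqrt(d) * rms(x), hence x / ||x||_2 * sqrt(d) == x / rms(x)
x = torch.randn(4, 512)
scale = x.shape[-1] ** 0.5
a = F.normalize(x, dim = -1) * scale
b = x / x.pow(2).mean(dim = -1, keepdim = True).sqrt()
print(torch.allclose(a, b, atol = 1e-5))  # True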
@@ -1267,7 +1269,8 @@ class AttentionLayers(Module):
         scale_residual_constant = 1.,
         shift_tokens = 0,
         sandwich_norm = False,
-        softclamp_output_value = …
+        softclamp_output = False,
+        softclamp_output_value = 50.,
         resi_dual = False,
         resi_dual_scale = 1.,
         zero_init_branch_output = False,
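The two new constructor arguments gate the Gemma-2-style output soft cap wired up below. A usage sketch through the public entry points, assuming the usual kwarg pass-through from `Decoder` to `AttentionLayers`:

from x_transformers import TransformerWrapper, Decoder

# sketch, assuming x-transformers >= 1.31.7
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        softclamp_output = True,        # opt in to the Gemma-2-style cap
        softclamp_output_value = 50.    # tanh bound applied before the final norm
    )
)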
@@ -1484,6 +1487,7 @@ class AttentionLayers(Module):
         # optional soft clamping just before the final norm
         # used in gemma 2
 
+        self.softclamp_output = softclamp_output
         self.softclamp_output_value = softclamp_output_value
 
         # whether it has post norm
@@ -1717,7 +1721,7 @@ class AttentionLayers(Module):
         if return_hiddens:
             layer_hiddens.append(x)
 
-        if …
+        if self.softclamp_output:
             x = softclamp(x, self.softclamp_output_value)
 
         final_norm = self.final_norm
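With the boolean flag in place, the clamp only runs when explicitly enabled. `softclamp` is a tanh soft cap in the style of Gemma 2: it bounds activations to (-value, value) while staying near-identity for inputs small relative to the bound. A minimal sketch of such a helper (the library's own definition may differ in detail):

import torch

def softclamp(t, value):
    # near-identity for |t| << value, asymptotically bounded by ±value
    return (t / value).tanh() * value

x = torch.tensor([-200., -10., 0., 10., 200.])
print(softclamp(x, 50.))  # ≈ [-49.97, -9.87, 0.00, 9.87, 49.97]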
{x_transformers-1.31.5.dist-info → x_transformers-1.31.7.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=…
+x_transformers/x_transformers.py,sha256=xvomb5imna2kCG_Kp-PQYsA6JGyiTx_1Dx5cD-YDlH4,75986
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.31.…
-x_transformers-1.31.…
-x_transformers-1.31.…
-x_transformers-1.31.…
-x_transformers-1.31.…
+x_transformers-1.31.7.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.31.7.dist-info/METADATA,sha256=Z3FjZ-v02tRiDNEYG5Bqw6Yg_gjzNRBKeCWxGrmib84,661
+x_transformers-1.31.7.dist-info/WHEEL,sha256=y4mX-SOX4fYIkonsAGA5N0Oy-8_gI4FXw5HNI1xqvWg,91
+x_transformers-1.31.7.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.31.7.dist-info/RECORD,,
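Each RECORD entry is `path,sha256=<digest>,<size>`, where the digest is an unpadded urlsafe base64 sha256 per the wheel spec. A small sketch for recomputing an entry's hash from a file on disk (the example path assumes an installed copy of the package):

import base64
import hashlib

def record_hash(path):
    # sha256 digest, urlsafe base64 with '=' padding stripped (the RECORD encoding)
    digest = hashlib.sha256(open(path, 'rb').read()).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b'=').decode()

# e.g. record_hash('x_transformers/x_transformers.py') should print
# 'xvomb5imna2kCG_Kp-PQYsA6JGyiTx_1Dx5cD-YDlH4' for the 1.31.7 wheel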
{x_transformers-1.31.5.dist-info → x_transformers-1.31.7.dist-info}/LICENSE
File without changes
{x_transformers-1.31.5.dist-info → x_transformers-1.31.7.dist-info}/top_level.txt
File without changes