x-transformers 1.30.17__py3-none-any.whl → 1.30.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +19 -1
- {x_transformers-1.30.17.dist-info → x_transformers-1.30.18.dist-info}/METADATA +1 -1
- {x_transformers-1.30.17.dist-info → x_transformers-1.30.18.dist-info}/RECORD +6 -6
- {x_transformers-1.30.17.dist-info → x_transformers-1.30.18.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.17.dist-info → x_transformers-1.30.18.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.17.dist-info → x_transformers-1.30.18.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -607,6 +607,20 @@ class RMSNorm(Module):
|
|
607
607
|
def forward(self, x):
    """RMS-normalize `x` along the last dim, then apply the fixed scale and learned gain."""
    unit = F.normalize(x, dim = -1)
    return unit * self.scale * self.g
|
609
609
|
|
610
|
+
class AdaptiveRMSNorm(Module):
    """RMSNorm whose per-channel gain is predicted from a conditioning vector.

    The input is unit-normalized along the last dimension, multiplied by the
    fixed scale sqrt(dim), and then by (1 + gamma), where gamma is a linear
    projection of `condition`. The projection weight is zero-initialized, so
    at init gamma == 0 and the module reduces to a plain RMSNorm.
    """

    def __init__(self, dim, dim_condition = None):
        super().__init__()
        self.scale = dim ** 0.5
        # conditioning dimension falls back to the model dimension when unset
        # (inlined equivalent of the file's `default(dim_condition, dim)` helper)
        cond_dim = dim if dim_condition is None else dim_condition

        # bias-free projection, zeroed so conditioning starts as a no-op
        self.to_gamma = nn.Linear(cond_dim, dim, bias = False)
        nn.init.zeros_(self.to_gamma.weight)

    def forward(self, x, condition):
        gamma = self.to_gamma(condition)
        return F.normalize(x, dim = -1) * self.scale * (gamma + 1.)
|
623
|
+
|
610
624
|
class SimpleRMSNorm(Module):
|
611
625
|
def __init__(self, dim):
|
612
626
|
super().__init__()
|
@@ -1145,6 +1159,7 @@ class AttentionLayers(Module):
|
|
1145
1159
|
use_rmsnorm = False,
|
1146
1160
|
use_simple_rmsnorm = False,
|
1147
1161
|
use_adaptive_layernorm = False,
|
1162
|
+
use_adaptive_rmsnorm = False,
|
1148
1163
|
dim_condition = None,
|
1149
1164
|
alibi_pos_bias = False,
|
1150
1165
|
alibi_num_heads = None,
|
@@ -1250,7 +1265,7 @@ class AttentionLayers(Module):
|
|
1250
1265
|
|
1251
1266
|
# determine norm
|
1252
1267
|
|
1253
|
-
assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm)
|
1268
|
+
assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
|
1254
1269
|
|
1255
1270
|
need_condition = False
|
1256
1271
|
dim_condition = default(dim_condition, dim)
|
@@ -1264,6 +1279,9 @@ class AttentionLayers(Module):
|
|
1264
1279
|
elif use_adaptive_layernorm:
|
1265
1280
|
need_condition = True
|
1266
1281
|
norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition)
|
1282
|
+
elif use_adaptive_rmsnorm:
|
1283
|
+
need_condition = True
|
1284
|
+
norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition)
|
1267
1285
|
else:
|
1268
1286
|
norm_class = LayerNorm
|
1269
1287
|
|
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
|
|
4
4
|
x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
|
5
5
|
x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
6
6
|
x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
|
7
|
-
x_transformers/x_transformers.py,sha256=
|
7
|
+
x_transformers/x_transformers.py,sha256=8XuiUXFOD7KAmopmf66mCq-HRs1g5Wd5tHcTTpm9JeM,71460
|
8
8
|
x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
|
9
9
|
x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
|
10
|
-
x_transformers-1.30.
|
11
|
-
x_transformers-1.30.
|
12
|
-
x_transformers-1.30.
|
13
|
-
x_transformers-1.30.
|
14
|
-
x_transformers-1.30.
|
10
|
+
x_transformers-1.30.18.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
11
|
+
x_transformers-1.30.18.dist-info/METADATA,sha256=6aA6OcLnBMlxZKJeRShv9UMT1BUYZlq-jfrj68nv5yU,662
|
12
|
+
x_transformers-1.30.18.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
13
|
+
x_transformers-1.30.18.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
14
|
+
x_transformers-1.30.18.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|