x-transformers 1.30.17__py3-none-any.whl → 1.30.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +44 -4
- {x_transformers-1.30.17.dist-info → x_transformers-1.30.19.dist-info}/METADATA +1 -1
- {x_transformers-1.30.17.dist-info → x_transformers-1.30.19.dist-info}/RECORD +6 -6
- {x_transformers-1.30.17.dist-info → x_transformers-1.30.19.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.17.dist-info → x_transformers-1.30.19.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.17.dist-info → x_transformers-1.30.19.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -584,7 +584,7 @@ class AdaptiveLayerNorm(Module):
 
         nn.init.zeros_(self.to_gamma.weight)
 
-    def forward(self, x, condition):
+    def forward(self, x, *, condition):
         normed = self.ln(x)
         gamma = self.to_gamma(condition)
         return normed * (gamma + 1.)
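The only change here is making `condition` keyword-only, which breaks any caller that passed it positionally. A minimal before/after sketch (shapes illustrative, and assuming `AdaptiveLayerNorm(dim, dim_condition = None)` as the constructor, mirroring `AdaptiveRMSNorm` below):

import torch
from x_transformers.x_transformers import AdaptiveLayerNorm

norm = AdaptiveLayerNorm(512, dim_condition = 768)

x = torch.randn(2, 1024, 512)
cond = torch.randn(2, 1, 768)

out = norm(x, condition = cond)  # required form as of 1.30.19
# norm(x, cond)                  # accepted in 1.30.17, now raises a TypeError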
@@ -607,6 +607,20 @@ class RMSNorm(Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * self.g
 
+class AdaptiveRMSNorm(Module):
+    def __init__(self, dim, dim_condition = None):
+        super().__init__()
+        self.scale = dim ** 0.5
+        dim_condition = default(dim_condition, dim)
+
+        self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
+        nn.init.zeros_(self.to_gamma.weight)
+
+    def forward(self, x, *, condition):
+        normed = F.normalize(x, dim = -1)
+        gamma = self.to_gamma(condition)
+        return normed * self.scale * (gamma + 1.)
+
 class SimpleRMSNorm(Module):
     def __init__(self, dim):
         super().__init__()
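Since `to_gamma` is zero-initialized, `gamma + 1.` starts at exactly 1 and the new module behaves as a plain RMSNorm until the conditioning is learned. A quick sketch of that property (shapes illustrative):

import torch
import torch.nn.functional as F
from x_transformers.x_transformers import AdaptiveRMSNorm

norm = AdaptiveRMSNorm(512, dim_condition = 256)

x = torch.randn(2, 1024, 512)
cond = torch.randn(2, 1, 256)  # one condition vector per sequence

out = norm(x, condition = cond)

# at initialization gamma is all zeros, so this is just l2-normalize * sqrt(dim)
assert torch.allclose(out, F.normalize(x, dim = -1) * 512 ** 0.5, atol = 1e-6)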
@@ -1145,7 +1159,10 @@ class AttentionLayers(Module):
         use_rmsnorm = False,
         use_simple_rmsnorm = False,
         use_adaptive_layernorm = False,
+        use_adaptive_rmsnorm = False,
         dim_condition = None,
+        adaptive_condition_mlp = False,
+        adaptive_condition_mlp_expansion = 4,
         alibi_pos_bias = False,
         alibi_num_heads = None,
         rel_pos_bias = False,
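Three new constructor knobs: `use_adaptive_rmsnorm` selects the class added above, and the two `adaptive_condition_mlp*` arguments optionally project the condition through a small SiLU MLP first. A construction sketch (the surrounding kwargs are standard `AttentionLayers` arguments):

from x_transformers.x_transformers import AttentionLayers

layers = AttentionLayers(
    dim = 512,
    depth = 8,
    heads = 8,
    use_adaptive_rmsnorm = True,          # new in 1.30.19
    dim_condition = 768,
    adaptive_condition_mlp = True,        # new: SiLU MLP over the condition
    adaptive_condition_mlp_expansion = 4  # new: MLP widening factor
)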
@@ -1250,11 +1267,22 @@ class AttentionLayers(Module):
 
         # determine norm
 
-        assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm)
+        assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
 
         need_condition = False
         dim_condition = default(dim_condition, dim)
 
+        self.adaptive_mlp = nn.Identity()
+        dim_condition_mult = 1
+
+        if adaptive_condition_mlp:
+            dim_condition_mult = adaptive_condition_mlp_expansion
+
+            self.adaptive_mlp = nn.Sequential(
+                nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
+                nn.SiLU()
+            )
+
         if use_scalenorm:
             norm_class = ScaleNorm
         elif use_rmsnorm:
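When the MLP is enabled, the condition the norms actually see is widened by the expansion factor; `nn.Identity()` keeps the no-MLP path uniform. A standalone sketch of the shape flow with `dim_condition = 64` and the default expansion of 4:

import torch
import torch.nn as nn

dim_condition, expansion = 64, 4

adaptive_mlp = nn.Sequential(
    nn.Linear(dim_condition, dim_condition * expansion, bias = False),
    nn.SiLU()
)

cond = torch.randn(2, 1, dim_condition)
print(adaptive_mlp(cond).shape)  # torch.Size([2, 1, 256]), the width each norm is built against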
@@ -1263,7 +1291,10 @@ class AttentionLayers(Module):
             norm_class = SimpleRMSNorm
         elif use_adaptive_layernorm:
             need_condition = True
-            norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition)
+            norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
+        elif use_adaptive_rmsnorm:
+            need_condition = True
+            norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition * dim_condition_mult)
         else:
             norm_class = LayerNorm
 
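`norm_class` is invoked later with only the model width, so `partial` pre-binds the (possibly expanded) condition width. The pattern in isolation:

from functools import partial
from x_transformers.x_transformers import AdaptiveRMSNorm

dim, dim_condition, dim_condition_mult = 512, 768, 4

norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition * dim_condition_mult)
norm = norm_class(dim)  # same as AdaptiveRMSNorm(512, dim_condition = 3072)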
@@ -1428,7 +1459,16 @@ class AttentionLayers(Module):
         norm_kwargs = dict()
 
         if self.need_condition:
-            assert condition.
+            assert condition.ndim in {2, 3}
+
+            if condition.ndim == 2:
+                condition = rearrange(condition, 'b d -> b 1 d')
+
+            assert condition.shape[-1] == self.dim_condition, f'expected condition dimension of {self.dim_condition} but received {condition.shape[-1]}'
+
+            # maybe mlp
+
+            condition = self.adaptive_mlp(condition)
 
             norm_kwargs.update(condition = condition)
 
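Taken together, `forward` now accepts either a per-sample (2D) or per-token (3D) condition, broadcasting the 2D case across the sequence before the optional MLP. A usage sketch, assuming `AttentionLayers.forward` exposes `condition` as a keyword argument (as the asserts above imply):

import torch
from x_transformers.x_transformers import AttentionLayers

layers = AttentionLayers(
    dim = 512,
    depth = 2,
    heads = 8,
    use_adaptive_rmsnorm = True,
    dim_condition = 768
)

x = torch.randn(2, 1024, 512)
cond = torch.randn(2, 768)         # 2D: rearranged internally to (2, 1, 768)

out = layers(x, condition = cond)  # (2, 1024, 512)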
{x_transformers-1.30.17.dist-info → x_transformers-1.30.19.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=NxNSAMgyHZEE-toXNcEO-unGwKfoqooN9oEQMDPPGfY,72268
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.19.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.19.dist-info/METADATA,sha256=2fQP9SmEys1grPsY2xHhB3sqkODtlaIz4UWtOR9pCM8,662
+x_transformers-1.30.19.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.19.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.19.dist-info/RECORD,,
{x_transformers-1.30.17.dist-info → x_transformers-1.30.19.dist-info}/LICENSE
File without changes

{x_transformers-1.30.17.dist-info → x_transformers-1.30.19.dist-info}/WHEEL
File without changes

{x_transformers-1.30.17.dist-info → x_transformers-1.30.19.dist-info}/top_level.txt
File without changes