x-transformers 1.30.18__py3-none-any.whl → 1.30.19__py3-none-any.whl
- x_transformers/x_transformers.py +27 -5
- {x_transformers-1.30.18.dist-info → x_transformers-1.30.19.dist-info}/METADATA +1 -1
- {x_transformers-1.30.18.dist-info → x_transformers-1.30.19.dist-info}/RECORD +6 -6
- {x_transformers-1.30.18.dist-info → x_transformers-1.30.19.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.18.dist-info → x_transformers-1.30.19.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.18.dist-info → x_transformers-1.30.19.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -584,7 +584,7 @@ class AdaptiveLayerNorm(Module):
         self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
         nn.init.zeros_(self.to_gamma.weight)
 
-    def forward(self, x, condition):
+    def forward(self, x, *, condition):
         normed = self.ln(x)
         gamma = self.to_gamma(condition)
         return normed * (gamma + 1.)
@@ -616,7 +616,7 @@ class AdaptiveRMSNorm(Module):
         self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
         nn.init.zeros_(self.to_gamma.weight)
 
-    def forward(self, x, condition):
+    def forward(self, x, *, condition):
         normed = F.normalize(x, dim = -1)
         gamma = self.to_gamma(condition)
         return normed * self.scale * (gamma + 1.)
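The only behavioral change in these two hunks is that condition becomes a keyword-only argument of the adaptive norm modules. A minimal sketch of the 1.30.19 calling convention; the sizes are made up for illustration, while the constructor shape AdaptiveRMSNorm(dim, dim_condition) follows the surrounding source:

    import torch
    from x_transformers.x_transformers import AdaptiveRMSNorm

    # illustrative sizes only, not taken from the diff
    norm = AdaptiveRMSNorm(dim = 512, dim_condition = 256)

    x = torch.randn(2, 1024, 512)
    cond = torch.randn(2, 1, 256)

    out = norm(x, condition = cond)   # 1.30.19: condition must be passed by keyword
    # norm(x, cond)                   # accepted in 1.30.18, now raises a TypeError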
@@ -1161,6 +1161,8 @@ class AttentionLayers(Module):
         use_adaptive_layernorm = False,
         use_adaptive_rmsnorm = False,
         dim_condition = None,
+        adaptive_condition_mlp = False,
+        adaptive_condition_mlp_expansion = 4,
         alibi_pos_bias = False,
         alibi_num_heads = None,
         rel_pos_bias = False,
@@ -1270,6 +1272,17 @@ class AttentionLayers(Module):
         need_condition = False
         dim_condition = default(dim_condition, dim)
 
+        self.adaptive_mlp = nn.Identity()
+        dim_condition_mult = 1
+
+        if adaptive_condition_mlp:
+            dim_condition_mult = adaptive_condition_mlp_expansion
+
+            self.adaptive_mlp = nn.Sequential(
+                nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
+                nn.SiLU()
+            )
+
         if use_scalenorm:
             norm_class = ScaleNorm
         elif use_rmsnorm:
@@ -1278,10 +1291,10 @@ class AttentionLayers(Module):
             norm_class = SimpleRMSNorm
         elif use_adaptive_layernorm:
             need_condition = True
-            norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition)
+            norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
         elif use_adaptive_rmsnorm:
             need_condition = True
-            norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition)
+            norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition * dim_condition_mult)
         else:
             norm_class = LayerNorm
 
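These two hunks add the adaptive_condition_mlp option: when enabled, the condition is projected from dim_condition to dim_condition * adaptive_condition_mlp_expansion through a bias-free Linear followed by SiLU, and the adaptive norms are built against that expanded dimension. A hedged sketch of how the option might be wired up through an Encoder; the other keyword arguments and tensor shapes are assumptions for illustration, not part of this diff:

    import torch
    from x_transformers import Encoder

    encoder = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_adaptive_rmsnorm = True,          # adaptive norms require a condition
        dim_condition = 256,
        adaptive_condition_mlp = True,        # new in 1.30.19
        adaptive_condition_mlp_expansion = 4  # gamma projections now see 256 * 4 features
    )

    x = torch.randn(2, 1024, 512)
    time_cond = torch.randn(2, 256)           # e.g. a diffusion timestep embedding

    out = encoder(x, condition = time_cond)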
@@ -1446,7 +1459,16 @@ class AttentionLayers(Module):
         norm_kwargs = dict()
 
         if self.need_condition:
-            assert condition.
+            assert condition.ndim in {2, 3}
+
+            if condition.ndim == 2:
+                condition = rearrange(condition, 'b d -> b 1 d')
+
+            assert condition.shape[-1] == self.dim_condition, f'expected condition dimension of {self.dim_condition} but received {condition.shape[-1]}'
+
+            # maybe mlp
+
+            condition = self.adaptive_mlp(condition)
 
             norm_kwargs.update(condition = condition)
 
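On the forward side, the single condition assert is replaced by shape handling: condition may be 2-D (batch, dim_condition) or 3-D (batch, seq, dim_condition), a 2-D tensor is broadcast to a singleton sequence axis, the feature dimension is checked against self.dim_condition, and the result is passed through adaptive_mlp (an nn.Identity() unless adaptive_condition_mlp is set). A standalone sketch of that preprocessing, with assumed shapes:

    import torch
    from einops import rearrange

    dim_condition = 256                                      # assumed size
    cond_per_sample = torch.randn(2, dim_condition)          # (b, d)
    cond_per_token = torch.randn(2, 1024, dim_condition)     # (b, n, d)

    for condition in (cond_per_sample, cond_per_token):
        assert condition.ndim in {2, 3}

        if condition.ndim == 2:
            # same rearrange the new code applies before the feature-dim check
            condition = rearrange(condition, 'b d -> b 1 d')

        assert condition.shape[-1] == dim_condition
        print(condition.shape)   # (2, 1, 256), then (2, 1024, 256)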
{x_transformers-1.30.18.dist-info → x_transformers-1.30.19.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=NxNSAMgyHZEE-toXNcEO-unGwKfoqooN9oEQMDPPGfY,72268
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.19.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.19.dist-info/METADATA,sha256=2fQP9SmEys1grPsY2xHhB3sqkODtlaIz4UWtOR9pCM8,662
+x_transformers-1.30.19.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.19.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.19.dist-info/RECORD,,
{x_transformers-1.30.18.dist-info → x_transformers-1.30.19.dist-info}/LICENSE
File without changes
{x_transformers-1.30.18.dist-info → x_transformers-1.30.19.dist-info}/WHEEL
File without changes
{x_transformers-1.30.18.dist-info → x_transformers-1.30.19.dist-info}/top_level.txt
File without changes