x-transformers 1.30.18__py3-none-any.whl → 1.30.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +28 -5
- {x_transformers-1.30.18.dist-info → x_transformers-1.30.20.dist-info}/METADATA +1 -1
- {x_transformers-1.30.18.dist-info → x_transformers-1.30.20.dist-info}/RECORD +6 -6
- {x_transformers-1.30.18.dist-info → x_transformers-1.30.20.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.18.dist-info → x_transformers-1.30.20.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.18.dist-info → x_transformers-1.30.20.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
```diff
@@ -584,7 +584,7 @@ class AdaptiveLayerNorm(Module):
 
         nn.init.zeros_(self.to_gamma.weight)
 
-    def forward(self, x, condition):
+    def forward(self, x, *, condition):
        normed = self.ln(x)
        gamma = self.to_gamma(condition)
        return normed * (gamma + 1.)
```
```diff
@@ -616,7 +616,7 @@ class AdaptiveRMSNorm(Module):
         self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
         nn.init.zeros_(self.to_gamma.weight)
 
-    def forward(self, x, condition):
+    def forward(self, x, *, condition):
        normed = F.normalize(x, dim = -1)
        gamma = self.to_gamma(condition)
        return normed * self.scale * (gamma + 1.)
```
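In both adaptive norms, `condition` becomes keyword-only, so any caller that passed it positionally must be updated. A minimal sketch of the new call site, using `AdaptiveRMSNorm` directly from `x_transformers.x_transformers`; the dimensions and shapes below are illustrative, not taken from the diff:

```python
import torch
from x_transformers.x_transformers import AdaptiveRMSNorm

norm = AdaptiveRMSNorm(512, dim_condition = 256)

x = torch.randn(2, 1024, 512)     # (batch, seq, dim)
cond = torch.randn(2, 1, 256)     # (batch, 1, dim_condition), broadcast over the sequence

out = norm(x, condition = cond)   # keyword-only as of this change
# norm(x, cond)                   # would now raise a TypeError
```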
```diff
@@ -1161,6 +1161,8 @@ class AttentionLayers(Module):
         use_adaptive_layernorm = False,
         use_adaptive_rmsnorm = False,
         dim_condition = None,
+        adaptive_condition_mlp = False,
+        adaptive_condition_mlp_expansion = 4,
         alibi_pos_bias = False,
         alibi_num_heads = None,
         rel_pos_bias = False,
```
```diff
@@ -1269,6 +1271,10 @@ class AttentionLayers(Module):
 
         need_condition = False
         dim_condition = default(dim_condition, dim)
+        dim_condition_mult = 1
+
+        if adaptive_condition_mlp:
+            dim_condition_mult = adaptive_condition_mlp_expansion
 
         if use_scalenorm:
             norm_class = ScaleNorm
```
```diff
@@ -1278,15 +1284,23 @@ class AttentionLayers(Module):
             norm_class = SimpleRMSNorm
         elif use_adaptive_layernorm:
             need_condition = True
-            norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition)
+            norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
         elif use_adaptive_rmsnorm:
             need_condition = True
-            norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition)
+            norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition * dim_condition_mult)
         else:
             norm_class = LayerNorm
 
         norm_fn = partial(norm_class, dim)
 
+        self.adaptive_mlp = nn.Identity()
+
+        if need_condition and adaptive_condition_mlp:
+            self.adaptive_mlp = nn.Sequential(
+                nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
+                nn.SiLU()
+            )
+
         self.need_condition = need_condition
         self.dim_condition = dim_condition
 
```
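The new `adaptive_condition_mlp` flag widens the incoming condition by `adaptive_condition_mlp_expansion` through a bias-free `Linear` followed by `SiLU`, and the adaptive norms are then built against the widened width. A standalone sketch of that conditioning path (not the library's own API; dimensions are illustrative):

```python
import torch
from torch import nn

dim, dim_condition, expansion = 512, 256, 4

# condition MLP as added above: bias-free projection + SiLU
adaptive_mlp = nn.Sequential(
    nn.Linear(dim_condition, dim_condition * expansion, bias = False),
    nn.SiLU()
)

# the adaptive norm's gamma projection now expects the widened condition
to_gamma = nn.Linear(dim_condition * expansion, dim, bias = False)
nn.init.zeros_(to_gamma.weight)     # zero init, so the gating term (gamma + 1.) starts at 1.

cond = torch.randn(2, 1, dim_condition)
widened = adaptive_mlp(cond)        # (2, 1, 1024)
gamma = to_gamma(widened)           # (2, 1, 512), all zeros at initialization
```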
```diff
@@ -1446,7 +1460,16 @@ class AttentionLayers(Module):
         norm_kwargs = dict()
 
         if self.need_condition:
-            assert condition.
+            assert condition.ndim in {2, 3}
+
+            if condition.ndim == 2:
+                condition = rearrange(condition, 'b d -> b 1 d')
+
+            assert condition.shape[-1] == self.dim_condition, f'expected condition dimension of {self.dim_condition} but received {condition.shape[-1]}'
+
+            # maybe mlp
+
+            condition = self.adaptive_mlp(condition)
 
             norm_kwargs.update(condition = condition)
 
```
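Taken together, the forward pass now accepts a 2-D or 3-D `condition`, broadcasts 2-D conditions over the sequence, checks the last dimension against `dim_condition`, and runs the tensor through `self.adaptive_mlp` before it reaches the norms. A hedged end-to-end sketch, assuming `Encoder` (an `AttentionLayers` subclass) passes these constructor kwargs and the `condition` tensor straight through; sizes are illustrative:

```python
import torch
from x_transformers import Encoder

attn_layers = Encoder(
    dim = 512,
    depth = 6,
    heads = 8,
    use_adaptive_rmsnorm = True,
    dim_condition = 256,
    adaptive_condition_mlp = True,           # new flag
    adaptive_condition_mlp_expansion = 4     # condition is widened to 256 * 4 internally
)

x = torch.randn(2, 128, 512)
cond = torch.randn(2, 256)                   # 2-D: rearranged to (2, 1, 256) inside forward

out = attn_layers(x, condition = cond)       # (2, 128, 512)
```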
{x_transformers-1.30.18.dist-info → x_transformers-1.30.20.dist-info}/RECORD
CHANGED
```diff
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=kd0H1tsw3SynfQu7xjuzacnuYimVhVMVxLID_I_pM8A,72322
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.20.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.20.dist-info/METADATA,sha256=OwTLO7xb31tvGd5vHK1XAU70ireyNSlrcDlg2zUJUuY,662
+x_transformers-1.30.20.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.20.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.20.dist-info/RECORD,,
```
{x_transformers-1.30.18.dist-info → x_transformers-1.30.20.dist-info}/LICENSE, /WHEEL, /top_level.txt: files without changes.