x-transformers 1.30.19__py3-none-any.whl → 1.30.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +62 -17
- {x_transformers-1.30.19.dist-info → x_transformers-1.30.21.dist-info}/METADATA +1 -1
- {x_transformers-1.30.19.dist-info → x_transformers-1.30.21.dist-info}/RECORD +6 -6
- {x_transformers-1.30.19.dist-info → x_transformers-1.30.21.dist-info}/WHEEL +1 -1
- {x_transformers-1.30.19.dist-info → x_transformers-1.30.21.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.19.dist-info → x_transformers-1.30.21.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py CHANGED
@@ -710,6 +710,27 @@ class LayerScale(Module):
         out, *rest = out
         return out * self.gamma, *rest
 
+class ConditionedLayerScale(Module):
+    def __init__(self, fn: Module, dim, dim_condition = None, init_bias_value = -2.):
+        super().__init__()
+        self.fn = fn
+
+        dim_condition = default(dim_condition, dim)
+        self.to_gamma = nn.Linear(dim_condition, dim)
+
+        nn.init.zeros_(self.to_gamma.weight)
+        nn.init.constant_(self.to_gamma.bias, init_bias_value)
+
+    def forward(self, x, *, condition, **kwargs):
+        out = self.fn(x, **kwargs)
+        gamma = self.to_gamma(condition).sigmoid()
+
+        if isinstance(out, Tensor):
+            return out * gamma
+
+        out, *rest = out
+        return out * gamma, *rest
+
 # feedforward
 
 class GLU(Module):
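For orientation, a minimal sketch of the new wrapper used on its own. The wrapped module, dimensions, and tensor shapes below are illustrative assumptions rather than anything shipped with the package; only the ConditionedLayerScale class itself comes from the hunk above.

import torch
from torch import nn
from x_transformers.x_transformers import ConditionedLayerScale

# hypothetical branch module; anything mapping (batch, seq, dim) -> (batch, seq, dim) works
branch = nn.Linear(512, 512)
wrapped = ConditionedLayerScale(branch, dim = 512, dim_condition = 256)

x = torch.randn(2, 1024, 512)        # token features
cond = torch.randn(2, 1, 256)        # per-sample condition, broadcast over the sequence

out = wrapped(x, condition = cond)   # branch output gated by sigmoid(to_gamma(condition))

# to_gamma starts with zero weights and a bias of -2, so the gate opens near sigmoid(-2) ≈ 0.12
assert out.shape == (2, 1024, 512)

Inside AttentionLayers the same class is applied as the post-branch wrapper, mirroring how LayerScale is applied when use_layerscale is set.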
@@ -1160,6 +1181,7 @@ class AttentionLayers(Module):
         use_simple_rmsnorm = False,
         use_adaptive_layernorm = False,
         use_adaptive_rmsnorm = False,
+        use_conditioned_layerscale = False, # paired with use_adaptive_layernorm for ada-ln-zero from DiT paper
         dim_condition = None,
         adaptive_condition_mlp = False,
         adaptive_condition_mlp_expansion = 4,
@@ -1269,20 +1291,13 @@ class AttentionLayers(Module):
 
         assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
 
-
+        norm_need_condition = False
         dim_condition = default(dim_condition, dim)
-
-        self.adaptive_mlp = nn.Identity()
         dim_condition_mult = 1
 
         if adaptive_condition_mlp:
             dim_condition_mult = adaptive_condition_mlp_expansion
 
-            self.adaptive_mlp = nn.Sequential(
-                nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
-                nn.SiLU()
-            )
-
         if use_scalenorm:
             norm_class = ScaleNorm
         elif use_rmsnorm:
@@ -1290,17 +1305,17 @@ class AttentionLayers(Module):
         elif use_simple_rmsnorm:
             norm_class = SimpleRMSNorm
         elif use_adaptive_layernorm:
-
+            norm_need_condition = True
             norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
         elif use_adaptive_rmsnorm:
-
+            norm_need_condition = True
             norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition * dim_condition_mult)
         else:
             norm_class = LayerNorm
 
         norm_fn = partial(norm_class, dim)
 
-        self.
+        self.norm_need_condition = norm_need_condition
         self.dim_condition = dim_condition
 
         # determine default block layer type order
@@ -1317,10 +1332,30 @@ class AttentionLayers(Module):
 
         # determine post branch wrapper
 
+        assert at_most_one_of(use_layerscale, use_conditioned_layerscale)
+
         post_branch_fn = None
+        post_branch_fn_needs_condition = False
 
         if use_layerscale:
             post_branch_fn = partial(LayerScale, dim = dim, init_value = layerscale_init_value)
+        elif use_conditioned_layerscale:
+            post_branch_fn = partial(ConditionedLayerScale, dim = dim, dim_condition = dim_condition * dim_condition_mult)
+            post_branch_fn_needs_condition = True
+
+        self.post_branch_fn_needs_condition = post_branch_fn_needs_condition
+
+        # setup mlp for conditioning
+
+        self.need_condition = norm_need_condition or post_branch_fn_needs_condition
+
+        self.adaptive_mlp = nn.Identity()
+
+        if self.need_condition and adaptive_condition_mlp:
+            self.adaptive_mlp = nn.Sequential(
+                nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
+                nn.SiLU()
+            )
 
         # zero init
 
@@ -1454,24 +1489,32 @@ class AttentionLayers(Module):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
         assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
 
-        #
+        # handle condition
 
-
+        if exists(condition):
+            assert condition.shape[-1] == self.dim_condition, f'expected condition dimension of {self.dim_condition} but received {condition.shape[-1]}'
 
-        if self.need_condition:
             assert condition.ndim in {2, 3}
 
             if condition.ndim == 2:
                 condition = rearrange(condition, 'b d -> b 1 d')
 
-
+            condition = self.adaptive_mlp(condition)
 
-
+        # setup maybe layernorm kwarg
 
-
+        norm_kwargs = dict()
 
+        if self.norm_need_condition:
             norm_kwargs.update(condition = condition)
 
+        # maybe post branch fn conditioning (DiT paper's ada-ln-zero)
+
+        block_forward_kwargs = dict()
+
+        if self.post_branch_fn_needs_condition:
+            block_forward_kwargs.update(condition = condition)
+
         # initialize accums
 
         hiddens = []
@@ -1572,6 +1615,8 @@ class AttentionLayers(Module):
             if layer_type == 'a' and exists(layer_mem):
                 layer_mem = pre_norm(layer_mem)
 
+            block = partial(block, **block_forward_kwargs)
+
             if layer_type == 'a':
                 out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, return_intermediates = True)
             elif layer_type == 'c':
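Taken together, the changes let a single conditioning signal drive both the adaptive norms and the new per-branch gates. A hedged usage sketch follows; the Encoder wrapper, hyperparameters, and tensor shapes are assumptions chosen for illustration, while the keyword names come from the hunks above.

import torch
from x_transformers import Encoder

attn_layers = Encoder(
    dim = 512,
    depth = 6,
    heads = 8,
    use_adaptive_layernorm = True,      # adaptive norm conditioned on an external signal
    use_conditioned_layerscale = True,  # new flag: ada-ln-zero style gating of each branch
    dim_condition = 256,
    adaptive_condition_mlp = True       # optional Linear + SiLU expansion of the condition
)

x = torch.randn(1, 128, 512)
cond = torch.randn(1, 256)              # (batch, dim_condition); a (batch, seq, dim_condition) tensor also passes the shape checks

out = attn_layers(x, condition = cond)  # condition is required once either conditioned module is enabled

With the zero-initialized to_gamma weights and negative bias, every residual branch starts close to shut and opens as training adjusts the projection, which is the ada-ln-zero behavior referenced in the comment on the new argument.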
{x_transformers-1.30.19.dist-info → x_transformers-1.30.21.dist-info}/RECORD CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=YFycvL28Y-Y-sr-_P2umC9qZ0nAfKFIBR_5mI1oKog0,73974
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.21.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.21.dist-info/METADATA,sha256=7HN3_fnSG0wGIGijv85bqfd8w4_K1m10Pv_PmI8kNrk,662
+x_transformers-1.30.21.dist-info/WHEEL,sha256=cpQTJ5IWu9CdaPViMhC9YzF8gZuS5-vlfoFihTBC86A,91
+x_transformers-1.30.21.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.21.dist-info/RECORD,,
{x_transformers-1.30.19.dist-info → x_transformers-1.30.21.dist-info}/LICENSE: File without changes
{x_transformers-1.30.19.dist-info → x_transformers-1.30.21.dist-info}/top_level.txt: File without changes