x-transformers 1.30.20__py3-none-any.whl → 1.30.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +62 -18
- {x_transformers-1.30.20.dist-info → x_transformers-1.30.22.dist-info}/METADATA +1 -1
- {x_transformers-1.30.20.dist-info → x_transformers-1.30.22.dist-info}/RECORD +6 -6
- {x_transformers-1.30.20.dist-info → x_transformers-1.30.22.dist-info}/WHEEL +1 -1
- {x_transformers-1.30.20.dist-info → x_transformers-1.30.22.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.20.dist-info → x_transformers-1.30.22.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -710,6 +710,27 @@ class LayerScale(Module):
         out, *rest = out
         return out * self.gamma, *rest
 
+class AdaptiveLayerScale(Module):
+    def __init__(self, fn: Module, dim, dim_condition = None, init_bias_value = -2.):
+        super().__init__()
+        self.fn = fn
+
+        dim_condition = default(dim_condition, dim)
+        self.to_gamma = nn.Linear(dim_condition, dim)
+
+        nn.init.zeros_(self.to_gamma.weight)
+        nn.init.constant_(self.to_gamma.bias, init_bias_value)
+
+    def forward(self, x, *, condition, **kwargs):
+        out = self.fn(x, **kwargs)
+        gamma = self.to_gamma(condition).sigmoid()
+
+        if isinstance(out, Tensor):
+            return out * gamma
+
+        out, *rest = out
+        return out * gamma, *rest
+
 # feedforward
 
 class GLU(Module):
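For orientation, the new AdaptiveLayerScale gates whatever branch it wraps with sigmoid(to_gamma(condition)); since to_gamma starts with zeroed weights and a bias of init_bias_value = -2., the gate opens at roughly sigmoid(-2) ≈ 0.12 regardless of the condition, which is the near-zero branch initialization behind DiT-style ada-ln-zero. A minimal usage sketch, assuming x-transformers 1.30.22 is installed and using a plain nn.Linear as a stand-in for the wrapped attention/feedforward branch (illustrative only, not code from the package):

import torch
from torch import nn
from x_transformers.x_transformers import AdaptiveLayerScale

branch = nn.Linear(512, 512)                  # stand-in for an attention or feedforward block
gated = AdaptiveLayerScale(branch, dim = 512) # default init_bias_value = -2. -> gate starts near 0.12

x = torch.randn(2, 16, 512)                   # (batch, seq, dim)
condition = torch.randn(2, 1, 512)            # conditioning vector, broadcast over the sequence

out = gated(x, condition = condition)         # branch(x) * sigmoid(to_gamma(condition))
print(out.shape)                              # torch.Size([2, 16, 512])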
@@ -1160,6 +1181,7 @@ class AttentionLayers(Module):
         use_simple_rmsnorm = False,
         use_adaptive_layernorm = False,
         use_adaptive_rmsnorm = False,
+        use_adaptive_layerscale = False, # paired with use_adaptive_layernorm for ada-ln-zero from DiT paper
         dim_condition = None,
         adaptive_condition_mlp = False,
         adaptive_condition_mlp_expansion = 4,
@@ -1269,7 +1291,7 @@ class AttentionLayers(Module):
 
         assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
 
-        need_condition = False
+        norm_need_condition = False
         dim_condition = default(dim_condition, dim)
         dim_condition_mult = 1
 
@@ -1283,25 +1305,17 @@ class AttentionLayers(Module):
         elif use_simple_rmsnorm:
             norm_class = SimpleRMSNorm
         elif use_adaptive_layernorm:
-            need_condition = True
+            norm_need_condition = True
             norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
         elif use_adaptive_rmsnorm:
-            need_condition = True
+            norm_need_condition = True
             norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition * dim_condition_mult)
         else:
             norm_class = LayerNorm
 
         norm_fn = partial(norm_class, dim)
 
-        self.adaptive_mlp = nn.Identity()
-
-        if need_condition and adaptive_condition_mlp:
-            self.adaptive_mlp = nn.Sequential(
-                nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
-                nn.SiLU()
-            )
-
-        self.need_condition = need_condition
+        self.norm_need_condition = norm_need_condition
         self.dim_condition = dim_condition
 
         # determine default block layer type order
@@ -1318,10 +1332,30 @@ class AttentionLayers(Module):
 
         # determine post branch wrapper
 
+        assert at_most_one_of(use_layerscale, use_adaptive_layerscale)
+
         post_branch_fn = None
+        post_branch_fn_needs_condition = False
 
         if use_layerscale:
             post_branch_fn = partial(LayerScale, dim = dim, init_value = layerscale_init_value)
+        elif use_adaptive_layerscale:
+            post_branch_fn = partial(AdaptiveLayerScale, dim = dim, dim_condition = dim_condition * dim_condition_mult)
+            post_branch_fn_needs_condition = True
+
+        self.post_branch_fn_needs_condition = post_branch_fn_needs_condition
+
+        # setup mlp for conditioning
+
+        self.need_condition = norm_need_condition or post_branch_fn_needs_condition
+
+        self.adaptive_mlp = nn.Identity()
+
+        if self.need_condition and adaptive_condition_mlp:
+            self.adaptive_mlp = nn.Sequential(
+                nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
+                nn.SiLU()
+            )
 
         # zero init
 
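To summarize the wiring above: norm_need_condition is set by the adaptive norms, post_branch_fn_needs_condition by the new adaptive layerscale, and their OR (need_condition) now decides whether the optional SiLU conditioning MLP is built (otherwise adaptive_mlp stays an nn.Identity()). A hedged construction sketch with illustrative dimensions, assuming the public Decoder class forwards these kwargs to AttentionLayers as usual:

from x_transformers import Decoder

# the ada-ln-zero pairing called out in the kwarg comment above
layers = Decoder(
    dim = 512,
    depth = 2,
    heads = 8,
    use_adaptive_layernorm = True,   # adaptive norms need the condition
    use_adaptive_layerscale = True,  # new in 1.30.22: conditioned per-branch gate
    dim_condition = 256
)

print(layers.norm_need_condition)             # True
print(layers.post_branch_fn_needs_condition)  # True
print(layers.need_condition)                  # True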
@@ -1455,24 +1489,32 @@ class AttentionLayers(Module):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
         assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
 
-        #
+        # handle condition
 
-
+        if exists(condition):
+            assert condition.shape[-1] == self.dim_condition, f'expected condition dimension of {self.dim_condition} but received {condition.shape[-1]}'
 
-        if self.need_condition:
             assert condition.ndim in {2, 3}
 
             if condition.ndim == 2:
                 condition = rearrange(condition, 'b d -> b 1 d')
 
-
+            condition = self.adaptive_mlp(condition)
 
-
+        # setup maybe layernorm kwarg
 
-
+        norm_kwargs = dict()
 
+        if self.norm_need_condition:
             norm_kwargs.update(condition = condition)
 
+        # maybe post branch fn conditioning (DiT paper's ada-ln-zero)
+
+        block_forward_kwargs = dict()
+
+        if self.post_branch_fn_needs_condition:
+            block_forward_kwargs.update(condition = condition)
+
         # initialize accums
 
         hiddens = []
@@ -1573,6 +1615,8 @@ class AttentionLayers(Module):
             if layer_type == 'a' and exists(layer_mem):
                 layer_mem = pre_norm(layer_mem)
 
+            block = partial(block, **block_forward_kwargs)
+
             if layer_type == 'a':
                 out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, return_intermediates = True)
             elif layer_type == 'c':
{x_transformers-1.30.20.dist-info → x_transformers-1.30.22.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=z1LzQMpJSNgalsnZJB-EsxqQk6aF0rQrv98VkiBTHnc,73959
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.22.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.22.dist-info/METADATA,sha256=aU33r8OyivazXN7QW5np0LdRaZGJY3x7pBSx55iFGdA,662
+x_transformers-1.30.22.dist-info/WHEEL,sha256=cpQTJ5IWu9CdaPViMhC9YzF8gZuS5-vlfoFihTBC86A,91
+x_transformers-1.30.22.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.22.dist-info/RECORD,,
{x_transformers-1.30.20.dist-info → x_transformers-1.30.22.dist-info}/LICENSE
File without changes

{x_transformers-1.30.20.dist-info → x_transformers-1.30.22.dist-info}/top_level.txt
File without changes