x-transformers 1.30.16__py3-none-any.whl → 1.30.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +57 -13
- {x_transformers-1.30.16.dist-info → x_transformers-1.30.17.dist-info}/METADATA +1 -1
- {x_transformers-1.30.16.dist-info → x_transformers-1.30.17.dist-info}/RECORD +6 -6
- {x_transformers-1.30.16.dist-info → x_transformers-1.30.17.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.16.dist-info → x_transformers-1.30.17.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.16.dist-info → x_transformers-1.30.17.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -574,6 +574,21 @@ class LayerNorm(Module):
 if version.parse(torch.__version__) >= version.parse('2.1.0'):
     LayerNorm = partial(nn.LayerNorm, bias = False)
 
+class AdaptiveLayerNorm(Module):
+    def __init__(self, dim, dim_condition = None):
+        super().__init__()
+        dim_condition = default(dim_condition, dim)
+
+        self.ln = nn.LayerNorm(dim, elementwise_affine = False)
+        self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
+
+        nn.init.zeros_(self.to_gamma.weight)
+
+    def forward(self, x, condition):
+        normed = self.ln(x)
+        gamma = self.to_gamma(condition)
+        return normed * (gamma + 1.)
+
 class ScaleNorm(Module):
     def __init__(self, dim):
         super().__init__()
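The new AdaptiveLayerNorm drops the learned affine of a standard LayerNorm and instead derives a per-channel gain from an external condition; because to_gamma is zero-initialized, gamma + 1. starts at 1 and the module initially behaves like a plain non-affine LayerNorm. A minimal standalone usage sketch (shapes and sizes are illustrative, not taken from the diff):

# illustrative usage sketch, not part of the packaged diff
import torch
from x_transformers.x_transformers import AdaptiveLayerNorm

norm = AdaptiveLayerNorm(dim = 512, dim_condition = 256)

x = torch.randn(2, 1024, 512)          # hidden states
condition = torch.randn(2, 1024, 256)  # per-token conditioning signal

out = norm(x, condition)               # at init gamma == 0, so this equals a non-affine LayerNorm of x
assert out.shape == x.shape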
@@ -1129,7 +1144,8 @@ class AttentionLayers(Module):
         use_scalenorm = False,
         use_rmsnorm = False,
         use_simple_rmsnorm = False,
-
+        use_adaptive_layernorm = False,
+        dim_condition = None,
         alibi_pos_bias = False,
         alibi_num_heads = None,
         rel_pos_bias = False,
@@ -1198,9 +1214,10 @@ class AttentionLayers(Module):
         # relative positional bias
 
         flash_attn = attn_kwargs.get('flash', False)
-        assert (
+        assert at_most_one_of(rel_pos_bias, dynamic_pos_bias, alibi_pos_bias), 'you can only choose up to one of t5, alibi, or dynamic positional bias'
 
         self.rel_pos = None
+
         if rel_pos_bias:
             assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
             self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
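at_most_one_of, which replaces the previous hand-rolled assertion, is not shown in this diff. It presumably collapses a set of flags into a single boolean along the lines of the sketch below; the packaged definition may differ:

# hypothetical reconstruction of the helper referenced by the new asserts
def at_most_one_of(*bools):
    # True when no more than one of the given flags is truthy
    return sum(map(int, bools)) <= 1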
@@ -1212,7 +1229,7 @@ class AttentionLayers(Module):
             assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
             self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
 
-        assert (
+        assert at_most_one_of(sandwich_norm, resi_dual), 'either sandwich norm or resiDual is selected, but not both'
         assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
 
         if resi_dual:
@@ -1233,7 +1250,10 @@ class AttentionLayers(Module):
 
         # determine norm
 
-        assert (
+        assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm), 'you can only use either scalenorm, rmsnorm, or simple rmsnorm'
+
+        need_condition = False
+        dim_condition = default(dim_condition, dim)
 
         if use_scalenorm:
             norm_class = ScaleNorm
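default(dim_condition, dim) makes the conditioning width fall back to the model dimension. The helper itself lies outside this hunk; it presumably acts as a simple null-coalescing function, as in this assumed sketch:

# assumed behaviour of the default/exists helpers used above
def exists(val):
    return val is not None

def default(val, d):
    # return val unless it is None, otherwise the fallback d
    return val if exists(val) else d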
@@ -1241,11 +1261,17 @@ class AttentionLayers(Module):
             norm_class = RMSNorm
         elif use_simple_rmsnorm:
             norm_class = SimpleRMSNorm
+        elif use_adaptive_layernorm:
+            need_condition = True
+            norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition)
         else:
             norm_class = LayerNorm
 
         norm_fn = partial(norm_class, dim)
 
+        self.need_condition = need_condition
+        self.dim_condition = dim_condition
+
         # determine default block layer type order
 
         if cross_attend and not only_cross:
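Since norm_fn is built as partial(norm_class, dim), choosing the adaptive branch means every norm_fn() call yields an AdaptiveLayerNorm bound to both dim and dim_condition. A small sketch of how the two partials compose (512/256 are placeholder sizes):

# sketch of the partial composition only
from functools import partial
from x_transformers.x_transformers import AdaptiveLayerNorm

norm_class = partial(AdaptiveLayerNorm, dim_condition = 256)
norm_fn = partial(norm_class, 512)

norm = norm_fn()  # equivalent to AdaptiveLayerNorm(512, dim_condition = 256)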
@@ -1361,12 +1387,9 @@ class AttentionLayers(Module):
 
             # all normalizations of the layer
 
-            pre_branch_norm =
-
-            if not
-                pre_branch_norm = norm_fn() if pre_norm else None
-                post_branch_norm = norm_fn() if sandwich_norm else None
-                post_main_norm = norm_fn() if not pre_norm else None
+            pre_branch_norm = norm_fn() if pre_norm else None
+            post_branch_norm = norm_fn() if sandwich_norm else None
+            post_main_norm = norm_fn() if not pre_norm else None
 
             norms = ModuleList([
                 pre_branch_norm,
@@ -1394,9 +1417,20 @@ class AttentionLayers(Module):
         cache: LayerIntermediates | None = None,
         cache_age = 1,
         return_hiddens = False,
-        rotary_pos_emb = None
+        rotary_pos_emb = None,
+        condition = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
+        assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
+
+        # setup maybe layernorm kwarg
+
+        norm_kwargs = dict()
+
+        if self.need_condition:
+            assert condition.shape[-1] == self.dim_condition
+
+            norm_kwargs.update(condition = condition)
 
         # initialize accums
 
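With need_condition set, forward now requires a condition tensor whose last dimension equals dim_condition, and rejects a condition when adaptive layernorm is not in use (the xor asserts above). A hedged end-to-end sketch (shapes and hyperparameters are illustrative):

# sketch only; shapes and hyperparameters are illustrative
import torch
from x_transformers import Decoder

attn_layers = Decoder(dim = 512, depth = 2, use_adaptive_layernorm = True, dim_condition = 256)

x = torch.randn(1, 128, 512)
condition = torch.randn(1, 128, 256)  # last dim must equal dim_condition

out = attn_layers(x, condition = condition)  # omitting condition trips the new assert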
@@ -1487,6 +1521,11 @@ class AttentionLayers(Module):
 
             pre_norm, post_branch_norm, post_main_norm = norm
 
+            if self.need_condition:
+                pre_norm = maybe(partial)(pre_norm, **norm_kwargs)
+                post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
+                post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)
+
             if exists(pre_norm):
                 x = pre_norm(x)
 
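maybe(partial) keeps optional norms optional: a norm slot that is None (for example post_branch_norm when sandwich_norm is off) passes through untouched, while an actual module gets the condition kwarg pre-bound for when it is finally called. The maybe helper is not shown in this diff; it is assumed to be a None-propagating wrapper roughly like this:

# assumed behaviour of maybe; the packaged helper may differ in details
from functools import partial

def maybe(fn):
    def inner(x, *args, **kwargs):
        if x is None:
            return None
        return fn(x, *args, **kwargs)
    return inner

post_branch_norm = None
post_branch_norm = maybe(partial)(post_branch_norm, condition = None)
assert post_branch_norm is None  # a missing norm stays None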
@@ -1523,10 +1562,15 @@ class AttentionLayers(Module):
             if return_hiddens:
                 layer_hiddens.append(x)
 
+        final_norm = self.final_norm
+
+        if self.need_condition:
+            final_norm = maybe(partial)(final_norm, **norm_kwargs)
+
         if self.resi_dual:
-            x = x +
+            x = x + final_norm(outer_residual)
         else:
-            x =
+            x = final_norm(x)
 
         if not return_hiddens:
             return x
{x_transformers-1.30.16.dist-info → x_transformers-1.30.17.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=kvGX4Ib1gSYV6pmVXF6P9bcutRsx5bif_XhkbG4DOZ8,70738
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.17.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.17.dist-info/METADATA,sha256=m_d5lvKUbiN8xS7Dx4gI5I8dHtzEa1ccp4MuKcG5O9w,662
+x_transformers-1.30.17.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.17.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.17.dist-info/RECORD,,
{x_transformers-1.30.16.dist-info → x_transformers-1.30.17.dist-info}/LICENSE
File without changes
{x_transformers-1.30.16.dist-info → x_transformers-1.30.17.dist-info}/WHEEL
File without changes
{x_transformers-1.30.16.dist-info → x_transformers-1.30.17.dist-info}/top_level.txt
File without changes