x-transformers 1.30.16__py3-none-any.whl → 1.30.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +75 -13
- {x_transformers-1.30.16.dist-info → x_transformers-1.30.18.dist-info}/METADATA +1 -1
- {x_transformers-1.30.16.dist-info → x_transformers-1.30.18.dist-info}/RECORD +6 -6
- {x_transformers-1.30.16.dist-info → x_transformers-1.30.18.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.16.dist-info → x_transformers-1.30.18.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.16.dist-info → x_transformers-1.30.18.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -574,6 +574,21 @@ class LayerNorm(Module):
 if version.parse(torch.__version__) >= version.parse('2.1.0'):
     LayerNorm = partial(nn.LayerNorm, bias = False)

+class AdaptiveLayerNorm(Module):
+    def __init__(self, dim, dim_condition = None):
+        super().__init__()
+        dim_condition = default(dim_condition, dim)
+
+        self.ln = nn.LayerNorm(dim, elementwise_affine = False)
+        self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
+
+        nn.init.zeros_(self.to_gamma.weight)
+
+    def forward(self, x, condition):
+        normed = self.ln(x)
+        gamma = self.to_gamma(condition)
+        return normed * (gamma + 1.)
+
 class ScaleNorm(Module):
     def __init__(self, dim):
         super().__init__()
@@ -592,6 +607,20 @@ class RMSNorm(Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * self.g

+class AdaptiveRMSNorm(Module):
+    def __init__(self, dim, dim_condition = None):
+        super().__init__()
+        self.scale = dim ** 0.5
+        dim_condition = default(dim_condition, dim)
+
+        self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
+        nn.init.zeros_(self.to_gamma.weight)
+
+    def forward(self, x, condition):
+        normed = F.normalize(x, dim = -1)
+        gamma = self.to_gamma(condition)
+        return normed * self.scale * (gamma + 1.)
+
 class SimpleRMSNorm(Module):
     def __init__(self, dim):
         super().__init__()
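The two adaptive norms added above condition the per-channel gain on an external embedding; because `to_gamma` is zero-initialized, `(gamma + 1.)` starts at 1 and each module initially behaves like its unconditioned counterpart. A minimal usage sketch, assuming x-transformers >= 1.30.18 and importing straight from the module shown in this diff (tensor shapes are illustrative):

```python
# Sketch only: import path and tensor shapes are assumptions based on this diff.
import torch
from x_transformers.x_transformers import AdaptiveLayerNorm, AdaptiveRMSNorm

norm = AdaptiveRMSNorm(dim = 512, dim_condition = 256)   # AdaptiveLayerNorm has the same interface

x = torch.randn(2, 1024, 512)        # token features
cond = torch.randn(2, 1024, 256)     # per-token conditioning signal

out = norm(x, cond)                  # F.normalize(x) * scale * (gamma + 1.)
print(out.shape)                     # torch.Size([2, 1024, 512])
```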
@@ -1129,7 +1158,9 @@ class AttentionLayers(Module):
         use_scalenorm = False,
         use_rmsnorm = False,
         use_simple_rmsnorm = False,
-
+        use_adaptive_layernorm = False,
+        use_adaptive_rmsnorm = False,
+        dim_condition = None,
         alibi_pos_bias = False,
         alibi_num_heads = None,
         rel_pos_bias = False,
@@ -1198,9 +1229,10 @@ class AttentionLayers(Module):
         # relative positional bias

         flash_attn = attn_kwargs.get('flash', False)
-        assert (
+        assert at_most_one_of(rel_pos_bias, dynamic_pos_bias, alibi_pos_bias), 'you can only choose up to one of t5, alibi, or dynamic positional bias'

         self.rel_pos = None
+
         if rel_pos_bias:
             assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
             self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
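The rewritten asserts in this and the following hunks rely on an `at_most_one_of` helper whose definition lives elsewhere in `x_transformers.py` and is not part of this diff. A hypothetical sketch of the behavior implied by the assert messages (at most one flag may be truthy):

```python
# Assumed behavior only; the real helper is defined elsewhere in
# x_transformers/x_transformers.py and is not shown in this diff.
def at_most_one_of(*flags) -> bool:
    return sum(map(bool, flags)) <= 1

assert at_most_one_of(False, False, False)
assert at_most_one_of(True, False, False)
assert not at_most_one_of(True, True, False)
```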
@@ -1212,7 +1244,7 @@ class AttentionLayers(Module):
             assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
             self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)

-        assert (
+        assert at_most_one_of(sandwich_norm, resi_dual), 'either sandwich norm or resiDual is selected, but not both'
         assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'

         if resi_dual:
@@ -1233,7 +1265,10 @@ class AttentionLayers(Module):

         # determine norm

-        assert (
+        assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
+
+        need_condition = False
+        dim_condition = default(dim_condition, dim)

         if use_scalenorm:
             norm_class = ScaleNorm
@@ -1241,11 +1276,20 @@ class AttentionLayers(Module):
             norm_class = RMSNorm
         elif use_simple_rmsnorm:
             norm_class = SimpleRMSNorm
+        elif use_adaptive_layernorm:
+            need_condition = True
+            norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition)
+        elif use_adaptive_rmsnorm:
+            need_condition = True
+            norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition)
         else:
             norm_class = LayerNorm

         norm_fn = partial(norm_class, dim)

+        self.need_condition = need_condition
+        self.dim_condition = dim_condition
+
         # determine default block layer type order

         if cross_attend and not only_cross:
@@ -1361,12 +1405,9 @@ class AttentionLayers(Module):

             # all normalizations of the layer

-            pre_branch_norm =
-
-            if not
-                pre_branch_norm = norm_fn() if pre_norm else None
-                post_branch_norm = norm_fn() if sandwich_norm else None
-                post_main_norm = norm_fn() if not pre_norm else None
+            pre_branch_norm = norm_fn() if pre_norm else None
+            post_branch_norm = norm_fn() if sandwich_norm else None
+            post_main_norm = norm_fn() if not pre_norm else None

             norms = ModuleList([
                 pre_branch_norm,
@@ -1394,9 +1435,20 @@ class AttentionLayers(Module):
         cache: LayerIntermediates | None = None,
         cache_age = 1,
         return_hiddens = False,
-        rotary_pos_emb = None
+        rotary_pos_emb = None,
+        condition = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
+        assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
+
+        # setup maybe layernorm kwarg
+
+        norm_kwargs = dict()
+
+        if self.need_condition:
+            assert condition.shape[-1] == self.dim_condition
+
+            norm_kwargs.update(condition = condition)

         # initialize accums

@@ -1487,6 +1539,11 @@ class AttentionLayers(Module):

             pre_norm, post_branch_norm, post_main_norm = norm

+            if self.need_condition:
+                pre_norm = maybe(partial)(pre_norm, **norm_kwargs)
+                post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
+                post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)
+
             if exists(pre_norm):
                 x = pre_norm(x)

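The `maybe(partial)(...)` calls above bind the conditioning kwargs into each norm only when that norm exists (pre, post-branch, and post-main norms can each be `None`). `maybe` is a null-propagating decorator defined elsewhere in `x_transformers.py`; a sketch of the idiom under that assumption:

```python
# Sketch of the null-propagating idiom; mirrors the presumed behavior of the
# `maybe` helper defined elsewhere in x_transformers/x_transformers.py.
from functools import partial, wraps

def maybe(fn):
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if x is None:
            return None          # propagate the missing norm untouched
        return fn(x, *args, **kwargs)
    return inner

def fake_norm(x, condition = None):
    return x

post_branch_norm = None
post_branch_norm = maybe(partial)(post_branch_norm)       # stays None, no error

pre_norm = maybe(partial)(fake_norm, condition = 'cond')  # condition pre-bound
print(pre_norm('x'))                                      # 'x'
```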
@@ -1523,10 +1580,15 @@ class AttentionLayers(Module):
             if return_hiddens:
                 layer_hiddens.append(x)

+        final_norm = self.final_norm
+
+        if self.need_condition:
+            final_norm = maybe(partial)(final_norm, **norm_kwargs)
+
         if self.resi_dual:
-            x = x +
+            x = x + final_norm(outer_residual)
         else:
-            x =
+            x = final_norm(x)

         if not return_hiddens:
             return x
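Taken together, the new constructor flags and the `condition` keyword let a whole attention stack be conditioned through its norms. A usage sketch, assuming x-transformers >= 1.30.18 and that `Decoder` (the usual `AttentionLayers` subclass) passes these arguments through unchanged:

```python
# Usage sketch, not official documentation.
import torch
from x_transformers import Decoder

layers = Decoder(
    dim = 512,
    depth = 4,
    heads = 8,
    use_adaptive_rmsnorm = True,   # or use_adaptive_layernorm = True
    dim_condition = 256            # defaults to dim when omitted
)

x = torch.randn(1, 128, 512)
cond = torch.randn(1, 128, 256)    # last dim must equal dim_condition

out = layers(x, condition = cond)  # condition is required once need_condition is set
print(out.shape)                   # torch.Size([1, 128, 512])
```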
{x_transformers-1.30.16.dist-info → x_transformers-1.30.18.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=8XuiUXFOD7KAmopmf66mCq-HRs1g5Wd5tHcTTpm9JeM,71460
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.18.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.18.dist-info/METADATA,sha256=6aA6OcLnBMlxZKJeRShv9UMT1BUYZlq-jfrj68nv5yU,662
+x_transformers-1.30.18.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.18.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.18.dist-info/RECORD,,
{x_transformers-1.30.16.dist-info → x_transformers-1.30.18.dist-info}/LICENSE
File without changes
{x_transformers-1.30.16.dist-info → x_transformers-1.30.18.dist-info}/WHEEL
File without changes
{x_transformers-1.30.16.dist-info → x_transformers-1.30.18.dist-info}/top_level.txt
File without changes