x-transformers 1.30.16__py3-none-any.whl → 1.30.18__py3-none-any.whl

This diff shows the content changes between publicly released versions of the package, as they appear in the supported public registries. It is provided for informational purposes only.
x_transformers/x_transformers.py

@@ -574,6 +574,21 @@ class LayerNorm(Module):
 if version.parse(torch.__version__) >= version.parse('2.1.0'):
     LayerNorm = partial(nn.LayerNorm, bias = False)
 
+class AdaptiveLayerNorm(Module):
+    def __init__(self, dim, dim_condition = None):
+        super().__init__()
+        dim_condition = default(dim_condition, dim)
+
+        self.ln = nn.LayerNorm(dim, elementwise_affine = False)
+        self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
+
+        nn.init.zeros_(self.to_gamma.weight)
+
+    def forward(self, x, condition):
+        normed = self.ln(x)
+        gamma = self.to_gamma(condition)
+        return normed * (gamma + 1.)
+
 class ScaleNorm(Module):
     def __init__(self, dim):
         super().__init__()
@@ -592,6 +607,20 @@ class RMSNorm(Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * self.g
 
+class AdaptiveRMSNorm(Module):
+    def __init__(self, dim, dim_condition = None):
+        super().__init__()
+        self.scale = dim ** 0.5
+        dim_condition = default(dim_condition, dim)
+
+        self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
+        nn.init.zeros_(self.to_gamma.weight)
+
+    def forward(self, x, condition):
+        normed = F.normalize(x, dim = -1)
+        gamma = self.to_gamma(condition)
+        return normed * self.scale * (gamma + 1.)
+
 class SimpleRMSNorm(Module):
     def __init__(self, dim):
         super().__init__()
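The two classes above are the substance of this release: norm layers whose gain is predicted from an external conditioning vector, with to_gamma zero-initialized so that at the start of training they behave like plain unconditioned norms. Below is a minimal standalone sketch of AdaptiveRMSNorm usage; the dimensions and tensors are illustrative, not taken from the package's tests.

    import torch
    from x_transformers.x_transformers import AdaptiveRMSNorm

    dim, dim_condition = 512, 256             # illustrative sizes

    norm = AdaptiveRMSNorm(dim, dim_condition = dim_condition)

    x = torch.randn(2, 1024, dim)             # token features
    cond = torch.randn(2, 1, dim_condition)   # e.g. a time or class embedding, broadcast over the sequence

    out = norm(x, cond)                       # gamma starts at zero, so initially this is plain RMSNorm
    assert out.shape == x.shape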
@@ -1129,7 +1158,9 @@ class AttentionLayers(Module):
         use_scalenorm = False,
         use_rmsnorm = False,
         use_simple_rmsnorm = False,
-        no_pre_or_postnorm = False,
+        use_adaptive_layernorm = False,
+        use_adaptive_rmsnorm = False,
+        dim_condition = None,
         alibi_pos_bias = False,
         alibi_num_heads = None,
         rel_pos_bias = False,
@@ -1198,9 +1229,10 @@ class AttentionLayers(Module):
         # relative positional bias
 
         flash_attn = attn_kwargs.get('flash', False)
-        assert (int(rel_pos_bias) + int(dynamic_pos_bias) + int(alibi_pos_bias)) <= 1, 'you can only choose up to one of t5, alibi, or dynamic positional bias'
+        assert at_most_one_of(rel_pos_bias, dynamic_pos_bias, alibi_pos_bias), 'you can only choose up to one of t5, alibi, or dynamic positional bias'
 
         self.rel_pos = None
+
         if rel_pos_bias:
             assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
             self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
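The int()-summing assertions are replaced by an at_most_one_of helper throughout. The helper itself is not part of this diff; presumably it is defined elsewhere in the module and does something equivalent to the hypothetical sketch below.

    def at_most_one_of(*flags):
        # True when at most one of the given flags is truthy
        return sum(int(bool(f)) for f in flags) <= 1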
@@ -1212,7 +1244,7 @@ class AttentionLayers(Module):
             assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
             self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
 
-        assert (int(sandwich_norm) + int(resi_dual)) <= 1, 'either sandwich norm or resiDual is selected, but not both'
+        assert at_most_one_of(sandwich_norm, resi_dual), 'either sandwich norm or resiDual is selected, but not both'
         assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
 
         if resi_dual:
@@ -1233,7 +1265,10 @@ class AttentionLayers(Module):
 
         # determine norm
 
-        assert (int(use_scalenorm) + int(use_rmsnorm) + int(use_simple_rmsnorm)) <= 1, 'you can only use either scalenorm, rmsnorm, or simple rmsnorm'
+        assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
+
+        need_condition = False
+        dim_condition = default(dim_condition, dim)
 
         if use_scalenorm:
             norm_class = ScaleNorm
@@ -1241,11 +1276,20 @@ class AttentionLayers(Module):
             norm_class = RMSNorm
         elif use_simple_rmsnorm:
             norm_class = SimpleRMSNorm
+        elif use_adaptive_layernorm:
+            need_condition = True
+            norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition)
+        elif use_adaptive_rmsnorm:
+            need_condition = True
+            norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition)
         else:
             norm_class = LayerNorm
 
         norm_fn = partial(norm_class, dim)
 
+        self.need_condition = need_condition
+        self.dim_condition = dim_condition
+
         # determine default block layer type order
 
         if cross_attend and not only_cross:
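With the new constructor flags, adaptive conditioning is selected when the attention layers are built, and dim_condition defaults to dim when omitted. A hedged usage sketch, assuming the usual Decoder subclass forwards these keyword arguments to AttentionLayers unchanged (sizes are illustrative):

    from x_transformers import Decoder

    decoder = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_adaptive_rmsnorm = True,   # or use_adaptive_layernorm = True, but not both
        dim_condition = 256            # defaults to dim if not given
    )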
@@ -1361,12 +1405,9 @@ class AttentionLayers(Module):
 
             # all normalizations of the layer
 
-            pre_branch_norm = post_branch_norm = post_main_norm = None
-
-            if not no_pre_or_postnorm:
-                pre_branch_norm = norm_fn() if pre_norm else None
-                post_branch_norm = norm_fn() if sandwich_norm else None
-                post_main_norm = norm_fn() if not pre_norm else None
+            pre_branch_norm = norm_fn() if pre_norm else None
+            post_branch_norm = norm_fn() if sandwich_norm else None
+            post_main_norm = norm_fn() if not pre_norm else None
 
             norms = ModuleList([
                 pre_branch_norm,
@@ -1394,9 +1435,20 @@ class AttentionLayers(Module):
         cache: LayerIntermediates | None = None,
         cache_age = 1,
         return_hiddens = False,
-        rotary_pos_emb = None
+        rotary_pos_emb = None,
+        condition = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
+        assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
+
+        # setup maybe layernorm kwarg
+
+        norm_kwargs = dict()
+
+        if self.need_condition:
+            assert condition.shape[-1] == self.dim_condition
+
+            norm_kwargs.update(condition = condition)
 
         # initialize accums
 
@@ -1487,6 +1539,11 @@ class AttentionLayers(Module):
 
             pre_norm, post_branch_norm, post_main_norm = norm
 
+            if self.need_condition:
+                pre_norm = maybe(partial)(pre_norm, **norm_kwargs)
+                post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
+                post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)
+
             if exists(pre_norm):
                 x = pre_norm(x)
 
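maybe(partial) is how the condition reaches the per-layer norms without touching the unconditioned call sites: each entry in the norms ModuleList may be None, and maybe wraps partial so that a None norm passes through while a real norm gets condition pre-bound before being called, as before, with just x. The maybe helper is not shown in this diff; in this codebase it is conventionally a None-passthrough decorator, roughly as sketched below (an assumption, not the exact implementation).

    from functools import partial, wraps

    def exists(val):
        return val is not None

    def maybe(fn):
        # call fn only when the first argument exists, otherwise pass the None through
        @wraps(fn)
        def inner(x, *args, **kwargs):
            if not exists(x):
                return x
            return fn(x, *args, **kwargs)
        return inner

    # maybe(partial)(pre_norm, condition = c) -> None if pre_norm is None,
    # otherwise pre_norm with condition bound, later invoked as pre_norm(x)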
@@ -1523,10 +1580,15 @@ class AttentionLayers(Module):
             if return_hiddens:
                 layer_hiddens.append(x)
 
+        final_norm = self.final_norm
+
+        if self.need_condition:
+            final_norm = maybe(partial)(final_norm, **norm_kwargs)
+
         if self.resi_dual:
-            x = x + self.final_norm(outer_residual)
+            x = x + final_norm(outer_residual)
         else:
-            x = self.final_norm(x)
+            x = final_norm(x)
 
         if not return_hiddens:
             return x
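On the forward side, the condition is validated against dim_condition and then threaded through every pre-branch, post-branch, post-main, and final norm. An end-to-end sketch continuing the illustrative decoder above; the shapes are assumptions, the only hard requirements added by this diff are that condition.shape[-1] == dim_condition and that a condition is passed exactly when an adaptive norm was selected.

    import torch

    x = torch.randn(2, 1024, 512)            # feature tensor fed to the attention layers
    condition = torch.randn(2, 1, 256)       # last dim must equal dim_condition

    out = decoder(x, condition = condition)  # omitting condition trips the new assert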
x_transformers-1.30.18.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.16
+Version: 1.30.18
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.30.18.dist-info/RECORD

@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=G4aF6SHVvhcu8OHFwCj0ZlJahEeGR1lUaRglNHcK74k,69225
+x_transformers/x_transformers.py,sha256=8XuiUXFOD7KAmopmf66mCq-HRs1g5Wd5tHcTTpm9JeM,71460
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.16.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.16.dist-info/METADATA,sha256=6BifB-CeW-wD7SjuXrKCC5dbJUnq_iLEMl0DXcszvt0,662
-x_transformers-1.30.16.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.30.16.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.16.dist-info/RECORD,,
+x_transformers-1.30.18.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.18.dist-info/METADATA,sha256=6aA6OcLnBMlxZKJeRShv9UMT1BUYZlq-jfrj68nv5yU,662
+x_transformers-1.30.18.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.18.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.18.dist-info/RECORD,,