x-transformers 1.30.16-py3-none-any.whl → 1.30.17-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/x_transformers.py

@@ -574,6 +574,21 @@ class LayerNorm(Module):
 if version.parse(torch.__version__) >= version.parse('2.1.0'):
     LayerNorm = partial(nn.LayerNorm, bias = False)
 
+class AdaptiveLayerNorm(Module):
+    def __init__(self, dim, dim_condition = None):
+        super().__init__()
+        dim_condition = default(dim_condition, dim)
+
+        self.ln = nn.LayerNorm(dim, elementwise_affine = False)
+        self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
+
+        nn.init.zeros_(self.to_gamma.weight)
+
+    def forward(self, x, condition):
+        normed = self.ln(x)
+        gamma = self.to_gamma(condition)
+        return normed * (gamma + 1.)
+
 class ScaleNorm(Module):
     def __init__(self, dim):
         super().__init__()
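The new AdaptiveLayerNorm wraps an affine-free nn.LayerNorm and scales the normalized output by a gain projected from a conditioning vector. Because to_gamma is zero-initialized, the module behaves exactly like a plain LayerNorm at initialization (gamma + 1 == 1). A minimal standalone sketch; the tensor shapes are illustrative assumptions, not taken from the diff:

import torch
from x_transformers.x_transformers import AdaptiveLayerNorm

# hypothetical shapes: batch of 2, 1024 tokens, model width 512,
# one conditioning vector of width 256 per batch element
adaln = AdaptiveLayerNorm(dim = 512, dim_condition = 256)

x = torch.randn(2, 1024, 512)
condition = torch.randn(2, 1, 256)   # broadcasts over the sequence dimension

out = adaln(x, condition)
assert out.shape == x.shape

# to_gamma starts at zero, so the initial output equals the plain layernorm
assert torch.allclose(out, adaln.ln(x))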
@@ -1129,7 +1144,8 @@ class AttentionLayers(Module):
         use_scalenorm = False,
         use_rmsnorm = False,
         use_simple_rmsnorm = False,
-        no_pre_or_postnorm = False,
+        use_adaptive_layernorm = False,
+        dim_condition = None,
         alibi_pos_bias = False,
         alibi_num_heads = None,
         rel_pos_bias = False,
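AttentionLayers drops the no_pre_or_postnorm escape hatch and instead gains use_adaptive_layernorm and dim_condition. A hedged construction sketch; the width, depth and head count are arbitrary placeholder values, only the last two keyword arguments are introduced by this release:

from x_transformers.x_transformers import AttentionLayers

layers = AttentionLayers(
    dim = 512,
    depth = 6,
    heads = 8,
    use_adaptive_layernorm = True,   # selects AdaptiveLayerNorm as the norm class
    dim_condition = 256              # width of the conditioning vector
)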
@@ -1198,9 +1214,10 @@ class AttentionLayers(Module):
         # relative positional bias
 
         flash_attn = attn_kwargs.get('flash', False)
-        assert (int(rel_pos_bias) + int(dynamic_pos_bias) + int(alibi_pos_bias)) <= 1, 'you can only choose up to one of t5, alibi, or dynamic positional bias'
+        assert at_most_one_of(rel_pos_bias, dynamic_pos_bias, alibi_pos_bias), 'you can only choose up to one of t5, alibi, or dynamic positional bias'
 
         self.rel_pos = None
+
         if rel_pos_bias:
             assert not flash_attn, 'flash attention not compatible with t5 relative positional bias'
             self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
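The boolean-sum assertions are now expressed through an at_most_one_of helper. Its definition is not part of this diff; a plausible implementation, consistent with the expressions it replaces, would be:

def at_most_one_of(*bools):
    # true when zero or one flag is set, mirroring the old
    # (int(a) + int(b) + ...) <= 1 pattern
    return sum(map(int, bools)) <= 1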
@@ -1212,7 +1229,7 @@ class AttentionLayers(Module):
             assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
             self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads)
 
-        assert (int(sandwich_norm) + int(resi_dual)) <= 1, 'either sandwich norm or resiDual is selected, but not both'
+        assert at_most_one_of(sandwich_norm, resi_dual), 'either sandwich norm or resiDual is selected, but not both'
         assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
 
         if resi_dual:
@@ -1233,7 +1250,10 @@ class AttentionLayers(Module):
 
         # determine norm
 
-        assert (int(use_scalenorm) + int(use_rmsnorm) + int(use_simple_rmsnorm)) <= 1, 'you can only use either scalenorm, rmsnorm, or simple rmsnorm'
+        assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm), 'you can only use either scalenorm, rmsnorm, or simple rmsnorm'
+
+        need_condition = False
+        dim_condition = default(dim_condition, dim)
 
         if use_scalenorm:
             norm_class = ScaleNorm
@@ -1241,11 +1261,17 @@ class AttentionLayers(Module):
             norm_class = RMSNorm
         elif use_simple_rmsnorm:
             norm_class = SimpleRMSNorm
+        elif use_adaptive_layernorm:
+            need_condition = True
+            norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition)
         else:
             norm_class = LayerNorm
 
         norm_fn = partial(norm_class, dim)
 
+        self.need_condition = need_condition
+        self.dim_condition = dim_condition
+
         # determine default block layer type order
 
         if cross_attend and not only_cross:
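When use_adaptive_layernorm is set, norm_class is itself a partial carrying dim_condition, so the existing norm_fn = partial(norm_class, dim) line composes into a factory that knows both widths. A small illustration of that composition, reusing the placeholder widths from the sketches above:

from functools import partial
from x_transformers.x_transformers import AdaptiveLayerNorm

norm_class = partial(AdaptiveLayerNorm, dim_condition = 256)
norm_fn = partial(norm_class, 512)

norm = norm_fn()   # equivalent to AdaptiveLayerNorm(512, dim_condition = 256)
assert norm.to_gamma.weight.shape == (512, 256)   # nn.Linear stores (out, in)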
@@ -1361,12 +1387,9 @@ class AttentionLayers(Module):
 
             # all normalizations of the layer
 
-            pre_branch_norm = post_branch_norm = post_main_norm = None
-
-            if not no_pre_or_postnorm:
-                pre_branch_norm = norm_fn() if pre_norm else None
-                post_branch_norm = norm_fn() if sandwich_norm else None
-                post_main_norm = norm_fn() if not pre_norm else None
+            pre_branch_norm = norm_fn() if pre_norm else None
+            post_branch_norm = norm_fn() if sandwich_norm else None
+            post_main_norm = norm_fn() if not pre_norm else None
 
             norms = ModuleList([
                 pre_branch_norm,
@@ -1394,9 +1417,20 @@ class AttentionLayers(Module):
         cache: LayerIntermediates | None = None,
         cache_age = 1,
         return_hiddens = False,
-        rotary_pos_emb = None
+        rotary_pos_emb = None,
+        condition = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
+        assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
+
+        # setup maybe layernorm kwarg
+
+        norm_kwargs = dict()
+
+        if self.need_condition:
+            assert condition.shape[-1] == self.dim_condition
+
+            norm_kwargs.update(condition = condition)
 
         # initialize accums
 
@@ -1487,6 +1521,11 @@ class AttentionLayers(Module):
 
             pre_norm, post_branch_norm, post_main_norm = norm
 
+            if self.need_condition:
+                pre_norm = maybe(partial)(pre_norm, **norm_kwargs)
+                post_branch_norm = maybe(partial)(post_branch_norm, **norm_kwargs)
+                post_main_norm = maybe(partial)(post_main_norm, **norm_kwargs)
+
             if exists(pre_norm):
                 x = pre_norm(x)
 
@@ -1523,10 +1562,15 @@ class AttentionLayers(Module):
             if return_hiddens:
                 layer_hiddens.append(x)
 
+        final_norm = self.final_norm
+
+        if self.need_condition:
+            final_norm = maybe(partial)(final_norm, **norm_kwargs)
+
         if self.resi_dual:
-            x = x + self.final_norm(outer_residual)
+            x = x + final_norm(outer_residual)
         else:
-            x = self.final_norm(x)
+            x = final_norm(x)
 
         if not return_hiddens:
             return x
x_transformers-1.30.17.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.16
+Version: 1.30.17
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.30.17.dist-info/RECORD

@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=G4aF6SHVvhcu8OHFwCj0ZlJahEeGR1lUaRglNHcK74k,69225
+x_transformers/x_transformers.py,sha256=kvGX4Ib1gSYV6pmVXF6P9bcutRsx5bif_XhkbG4DOZ8,70738
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.16.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.16.dist-info/METADATA,sha256=6BifB-CeW-wD7SjuXrKCC5dbJUnq_iLEMl0DXcszvt0,662
-x_transformers-1.30.16.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.30.16.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.16.dist-info/RECORD,,
+x_transformers-1.30.17.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.17.dist-info/METADATA,sha256=m_d5lvKUbiN8xS7Dx4gI5I8dHtzEa1ccp4MuKcG5O9w,662
+x_transformers-1.30.17.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.17.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.17.dist-info/RECORD,,