x-transformers 1.30.19__py3-none-any.whl → 1.30.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/x_transformers.py

@@ -710,6 +710,27 @@ class LayerScale(Module):
         out, *rest = out
         return out * self.gamma, *rest
 
+class ConditionedLayerScale(Module):
+    def __init__(self, fn: Module, dim, dim_condition = None, init_bias_value = -2.):
+        super().__init__()
+        self.fn = fn
+
+        dim_condition = default(dim_condition, dim)
+        self.to_gamma = nn.Linear(dim_condition, dim)
+
+        nn.init.zeros_(self.to_gamma.weight)
+        nn.init.constant_(self.to_gamma.bias, init_bias_value)
+
+    def forward(self, x, *, condition, **kwargs):
+        out = self.fn(x, **kwargs)
+        gamma = self.to_gamma(condition).sigmoid()
+
+        if isinstance(out, Tensor):
+            return out * gamma
+
+        out, *rest = out
+        return out * gamma, *rest
+
 # feedforward
 
 class GLU(Module):
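The new ConditionedLayerScale mirrors LayerScale, but derives its per-channel gate from an external conditioning vector, and the gate is biased toward zero at init (sigmoid(-2.) is roughly 0.12), in the spirit of ada-LN-zero from the DiT paper. A minimal standalone sketch of the same gating idea, using only torch and illustrative sizes (nothing below is part of the package API):

    import torch
    from torch import nn

    dim = dim_condition = 512

    branch = nn.Linear(dim, dim)                     # stand-in for an attention or feedforward branch
    to_gamma = nn.Linear(dim_condition, dim)
    nn.init.zeros_(to_gamma.weight)                  # zero weight plus negative bias ...
    nn.init.constant_(to_gamma.bias, -2.)            # ... so the gate starts near sigmoid(-2) ~= 0.12

    x = torch.randn(2, 1024, dim)                    # (batch, seq, dim)
    condition = torch.randn(2, 1, dim_condition)     # broadcasts over the sequence

    out = branch(x) * to_gamma(condition).sigmoid()  # branch output gated per channel by the condition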
@@ -1160,6 +1181,7 @@ class AttentionLayers(Module):
         use_simple_rmsnorm = False,
         use_adaptive_layernorm = False,
         use_adaptive_rmsnorm = False,
+        use_conditioned_layerscale = False, # paired with use_adaptive_layernorm for ada-ln-zero from DiT paper
         dim_condition = None,
         adaptive_condition_mlp = False,
         adaptive_condition_mlp_expansion = 4,
@@ -1269,20 +1291,13 @@ class AttentionLayers(Module):
 
         assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
 
-        need_condition = False
+        norm_need_condition = False
         dim_condition = default(dim_condition, dim)
-
-        self.adaptive_mlp = nn.Identity()
         dim_condition_mult = 1
 
         if adaptive_condition_mlp:
             dim_condition_mult = adaptive_condition_mlp_expansion
 
-            self.adaptive_mlp = nn.Sequential(
-                nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
-                nn.SiLU()
-            )
-
         if use_scalenorm:
             norm_class = ScaleNorm
         elif use_rmsnorm:
@@ -1290,17 +1305,17 @@ class AttentionLayers(Module):
         elif use_simple_rmsnorm:
             norm_class = SimpleRMSNorm
         elif use_adaptive_layernorm:
-            need_condition = True
+            norm_need_condition = True
             norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
         elif use_adaptive_rmsnorm:
-            need_condition = True
+            norm_need_condition = True
             norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition * dim_condition_mult)
         else:
             norm_class = LayerNorm
 
         norm_fn = partial(norm_class, dim)
 
-        self.need_condition = need_condition
+        self.norm_need_condition = norm_need_condition
         self.dim_condition = dim_condition
 
         # determine default block layer type order
@@ -1317,10 +1332,30 @@ class AttentionLayers(Module):
 
         # determine post branch wrapper
 
+        assert at_most_one_of(use_layerscale, use_conditioned_layerscale)
+
         post_branch_fn = None
+        post_branch_fn_needs_condition = False
 
         if use_layerscale:
             post_branch_fn = partial(LayerScale, dim = dim, init_value = layerscale_init_value)
+        elif use_conditioned_layerscale:
+            post_branch_fn = partial(ConditionedLayerScale, dim = dim, dim_condition = dim_condition * dim_condition_mult)
+            post_branch_fn_needs_condition = True
+
+        self.post_branch_fn_needs_condition = post_branch_fn_needs_condition
+
+        # setup mlp for conditioning
+
+        self.need_condition = norm_need_condition or post_branch_fn_needs_condition
+
+        self.adaptive_mlp = nn.Identity()
+
+        if self.need_condition and adaptive_condition_mlp:
+            self.adaptive_mlp = nn.Sequential(
+                nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
+                nn.SiLU()
+            )
 
         # zero init
 
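With this reshuffle, need_condition is true when either the adaptive norms or the new conditioned layerscale consume the condition, and the expansion MLP is only instantiated in that case. A small shape sketch of the expansion path, with an illustrative dim_condition of 256 and the default expansion factor of 4 (the sizes are assumptions for the example, not values from the diff):

    import torch
    from torch import nn

    dim_condition, expansion = 256, 4    # 4 matches the adaptive_condition_mlp_expansion default above

    adaptive_mlp = nn.Sequential(
        nn.Linear(dim_condition, dim_condition * expansion, bias = False),
        nn.SiLU()
    )

    condition = torch.randn(2, 1, dim_condition)
    print(adaptive_mlp(condition).shape)    # torch.Size([2, 1, 1024]); the widened condition feeds the adaptive norms and ConditionedLayerScale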
@@ -1454,24 +1489,32 @@ class AttentionLayers(Module):
1454
1489
  assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
1455
1490
  assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
1456
1491
 
1457
- # setup maybe layernorm kwarg
1492
+ # handle condition
1458
1493
 
1459
- norm_kwargs = dict()
1494
+ if exists(condition):
1495
+ assert condition.shape[-1] == self.dim_condition, f'expected condition dimension of {self.dim_condition} but received {condition.shape[-1]}'
1460
1496
 
1461
- if self.need_condition:
1462
1497
  assert condition.ndim in {2, 3}
1463
1498
 
1464
1499
  if condition.ndim == 2:
1465
1500
  condition = rearrange(condition, 'b d -> b 1 d')
1466
1501
 
1467
- assert condition.shape[-1] == self.dim_condition, f'expected condition dimension of {self.dim_condition} but received {condition.shape[-1]}'
1502
+ condition = self.adaptive_mlp(condition)
1468
1503
 
1469
- # maybe mlp
1504
+ # setup maybe layernorm kwarg
1470
1505
 
1471
- condition = self.adaptive_mlp(condition)
1506
+ norm_kwargs = dict()
1472
1507
 
1508
+ if self.norm_need_condition:
1473
1509
  norm_kwargs.update(condition = condition)
1474
1510
 
1511
+ # maybe post branch fn conditioning (DiT paper's ada-ln-zero)
1512
+
1513
+ block_forward_kwargs = dict()
1514
+
1515
+ if self.post_branch_fn_needs_condition:
1516
+ block_forward_kwargs.update(condition = condition)
1517
+
1475
1518
  # initialize accums
1476
1519
 
1477
1520
  hiddens = []
@@ -1572,6 +1615,8 @@ class AttentionLayers(Module):
             if layer_type == 'a' and exists(layer_mem):
                 layer_mem = pre_norm(layer_mem)
 
+            block = partial(block, **block_forward_kwargs)
+
             if layer_type == 'a':
                 out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, return_intermediates = True)
             elif layer_type == 'c':
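Taken together, the forward path now validates and widens the condition once, then routes it to the norms via norm_kwargs and, when use_conditioned_layerscale is set, to every block via block_forward_kwargs. A hedged usage sketch of the new option, assuming the Encoder subclass forwards these keyword arguments to AttentionLayers unchanged:

    import torch
    from x_transformers import Encoder

    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8,
        use_adaptive_layernorm = True,       # adaptive norm ...
        use_conditioned_layerscale = True,   # ... plus conditioned layerscale, i.e. ada-ln-zero style conditioning
        dim_condition = 256
    )

    x = torch.randn(2, 1024, 512)
    condition = torch.randn(2, 256)          # one conditioning vector per sequence; (b, n, d) is also accepted

    out = attn_layers(x, condition = condition)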
x_transformers-1.30.21.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.19
+Version: 1.30.21
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.30.21.dist-info/RECORD

@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=NxNSAMgyHZEE-toXNcEO-unGwKfoqooN9oEQMDPPGfY,72268
+x_transformers/x_transformers.py,sha256=YFycvL28Y-Y-sr-_P2umC9qZ0nAfKFIBR_5mI1oKog0,73974
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.19.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.19.dist-info/METADATA,sha256=2fQP9SmEys1grPsY2xHhB3sqkODtlaIz4UWtOR9pCM8,662
-x_transformers-1.30.19.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.30.19.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.19.dist-info/RECORD,,
+x_transformers-1.30.21.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.21.dist-info/METADATA,sha256=7HN3_fnSG0wGIGijv85bqfd8w4_K1m10Pv_PmI8kNrk,662
+x_transformers-1.30.21.dist-info/WHEEL,sha256=cpQTJ5IWu9CdaPViMhC9YzF8gZuS5-vlfoFihTBC86A,91
+x_transformers-1.30.21.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.21.dist-info/RECORD,,
x_transformers-1.30.21.dist-info/WHEEL

@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: bdist_wheel (0.43.0)
+Generator: setuptools (70.1.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 