x-transformers 1.30.20-py3-none-any.whl → 1.30.21-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -710,6 +710,27 @@ class LayerScale(Module):
         out, *rest = out
         return out * self.gamma, *rest
 
+class ConditionedLayerScale(Module):
+    def __init__(self, fn: Module, dim, dim_condition = None, init_bias_value = -2.):
+        super().__init__()
+        self.fn = fn
+
+        dim_condition = default(dim_condition, dim)
+        self.to_gamma = nn.Linear(dim_condition, dim)
+
+        nn.init.zeros_(self.to_gamma.weight)
+        nn.init.constant_(self.to_gamma.bias, init_bias_value)
+
+    def forward(self, x, *, condition, **kwargs):
+        out = self.fn(x, **kwargs)
+        gamma = self.to_gamma(condition).sigmoid()
+
+        if isinstance(out, Tensor):
+            return out * gamma
+
+        out, *rest = out
+        return out * gamma, *rest
+
 # feedforward
 
 class GLU(Module):
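The important detail in ConditionedLayerScale is its initialization: the gate projection starts with zero weights and a bias of -2, so at the start of training every branch output is scaled by roughly sigmoid(-2) ≈ 0.12, independent of the condition. A minimal sketch of that gating behaviour in plain PyTorch (the dimensions and tensor shapes below are illustrative, not taken from the package):

import torch
from torch import nn

dim, dim_condition = 512, 256

# same init scheme as ConditionedLayerScale: zero weights, bias at -2
to_gamma = nn.Linear(dim_condition, dim)
nn.init.zeros_(to_gamma.weight)
nn.init.constant_(to_gamma.bias, -2.)

condition  = torch.randn(2, 1, dim_condition)   # (batch, 1, dim_condition)
branch_out = torch.randn(2, 1024, dim)          # output of the wrapped attention or feedforward branch

gamma = to_gamma(condition).sigmoid()           # ~0.12 everywhere at init, regardless of the condition
gated = branch_out * gamma                      # branch contribution starts heavily damped, as in ada-ln-zero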
@@ -1160,6 +1181,7 @@ class AttentionLayers(Module):
         use_simple_rmsnorm = False,
         use_adaptive_layernorm = False,
         use_adaptive_rmsnorm = False,
+        use_conditioned_layerscale = False, # paired with use_adaptive_layernorm for ada-ln-zero from DiT paper
         dim_condition = None,
         adaptive_condition_mlp = False,
         adaptive_condition_mlp_expansion = 4,
@@ -1269,7 +1291,7 @@ class AttentionLayers(Module):
 
         assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
 
-        need_condition = False
+        norm_need_condition = False
         dim_condition = default(dim_condition, dim)
         dim_condition_mult = 1
 
@@ -1283,25 +1305,17 @@ class AttentionLayers(Module):
         elif use_simple_rmsnorm:
             norm_class = SimpleRMSNorm
         elif use_adaptive_layernorm:
-            need_condition = True
+            norm_need_condition = True
             norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
         elif use_adaptive_rmsnorm:
-            need_condition = True
+            norm_need_condition = True
             norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition * dim_condition_mult)
         else:
             norm_class = LayerNorm
 
         norm_fn = partial(norm_class, dim)
 
-        self.adaptive_mlp = nn.Identity()
-
-        if need_condition and adaptive_condition_mlp:
-            self.adaptive_mlp = nn.Sequential(
-                nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
-                nn.SiLU()
-            )
-
-        self.need_condition = need_condition
+        self.norm_need_condition = norm_need_condition
         self.dim_condition = dim_condition
 
         # determine default block layer type order
@@ -1318,10 +1332,30 @@ class AttentionLayers(Module):
 
         # determine post branch wrapper
 
+        assert at_most_one_of(use_layerscale, use_conditioned_layerscale)
+
         post_branch_fn = None
+        post_branch_fn_needs_condition = False
 
         if use_layerscale:
             post_branch_fn = partial(LayerScale, dim = dim, init_value = layerscale_init_value)
+        elif use_conditioned_layerscale:
+            post_branch_fn = partial(ConditionedLayerScale, dim = dim, dim_condition = dim_condition * dim_condition_mult)
+            post_branch_fn_needs_condition = True
+
+        self.post_branch_fn_needs_condition = post_branch_fn_needs_condition
+
+        # setup mlp for conditioning
+
+        self.need_condition = norm_need_condition or post_branch_fn_needs_condition
+
+        self.adaptive_mlp = nn.Identity()
+
+        if self.need_condition and adaptive_condition_mlp:
+            self.adaptive_mlp = nn.Sequential(
+                nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
+                nn.SiLU()
+            )
 
         # zero init
 
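Taken together, the constructor changes let the DiT-style ada-ln-zero recipe be enabled by pairing use_adaptive_layernorm with the new use_conditioned_layerscale flag, with an optional Linear + SiLU projection applied to the condition first. A hedged usage sketch; the keyword names come from the diff above, while dim, depth and the tensor shapes are illustrative assumptions:

import torch
from x_transformers.x_transformers import AttentionLayers

layers = AttentionLayers(
    dim = 512,
    depth = 6,
    use_adaptive_layernorm = True,       # condition modulates the pre-branch norm
    use_conditioned_layerscale = True,   # new in 1.30.21: condition gates each branch output
    dim_condition = 256,
    adaptive_condition_mlp = True        # optional Linear + SiLU projection of the condition
)

x = torch.randn(1, 1024, 512)
cond = torch.randn(1, 256)               # e.g. a diffusion timestep or class embedding

out = layers(x, condition = cond)        # condition is required once either adaptive option is on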
@@ -1455,24 +1489,32 @@ class AttentionLayers(Module):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
         assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
 
-        # setup maybe layernorm kwarg
+        # handle condition
 
-        norm_kwargs = dict()
+        if exists(condition):
+            assert condition.shape[-1] == self.dim_condition, f'expected condition dimension of {self.dim_condition} but received {condition.shape[-1]}'
 
-        if self.need_condition:
             assert condition.ndim in {2, 3}
 
             if condition.ndim == 2:
                 condition = rearrange(condition, 'b d -> b 1 d')
 
-            assert condition.shape[-1] == self.dim_condition, f'expected condition dimension of {self.dim_condition} but received {condition.shape[-1]}'
+            condition = self.adaptive_mlp(condition)
 
-            # maybe mlp
+        # setup maybe layernorm kwarg
 
-            condition = self.adaptive_mlp(condition)
+        norm_kwargs = dict()
 
+        if self.norm_need_condition:
             norm_kwargs.update(condition = condition)
 
+        # maybe post branch fn conditioning (DiT paper's ada-ln-zero)
+
+        block_forward_kwargs = dict()
+
+        if self.post_branch_fn_needs_condition:
+            block_forward_kwargs.update(condition = condition)
+
         # initialize accums
 
         hiddens = []
@@ -1573,6 +1615,8 @@ class AttentionLayers(Module):
             if layer_type == 'a' and exists(layer_mem):
                 layer_mem = pre_norm(layer_mem)
 
+            block = partial(block, **block_forward_kwargs)
+
             if layer_type == 'a':
                 out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, return_intermediates = True)
             elif layer_type == 'c':
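On the forward path the condition is now validated and projected once, then routed to both consumers: the adaptive norms via norm_kwargs and the conditioned layerscale gates via block_forward_kwargs, which are partially applied onto every block. Continuing the sketch above, both per-sample and per-token conditions should be accepted; a 2-d condition is rearranged to 'b 1 d' and broadcast over the sequence:

per_sample = torch.randn(1, 256)          # ndim == 2: broadcast to every position
per_token  = torch.randn(1, 1024, 256)    # ndim == 3: one condition vector per position

out_a = layers(x, condition = per_sample)
out_b = layers(x, condition = per_token)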
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.20
+Version: 1.30.21
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=kd0H1tsw3SynfQu7xjuzacnuYimVhVMVxLID_I_pM8A,72322
+x_transformers/x_transformers.py,sha256=YFycvL28Y-Y-sr-_P2umC9qZ0nAfKFIBR_5mI1oKog0,73974
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.20.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.20.dist-info/METADATA,sha256=OwTLO7xb31tvGd5vHK1XAU70ireyNSlrcDlg2zUJUuY,662
-x_transformers-1.30.20.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.30.20.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.20.dist-info/RECORD,,
+x_transformers-1.30.21.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.21.dist-info/METADATA,sha256=7HN3_fnSG0wGIGijv85bqfd8w4_K1m10Pv_PmI8kNrk,662
+x_transformers-1.30.21.dist-info/WHEEL,sha256=cpQTJ5IWu9CdaPViMhC9YzF8gZuS5-vlfoFihTBC86A,91
+x_transformers-1.30.21.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.21.dist-info/RECORD,,
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: bdist_wheel (0.43.0)
+Generator: setuptools (70.1.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 