x-transformers 1.30.17__py3-none-any.whl → 1.30.19__py3-none-any.whl

This diff shows the content changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
@@ -584,7 +584,7 @@ class AdaptiveLayerNorm(Module):

         nn.init.zeros_(self.to_gamma.weight)

-    def forward(self, x, condition):
+    def forward(self, x, *, condition):
         normed = self.ln(x)
         gamma = self.to_gamma(condition)
         return normed * (gamma + 1.)
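The only functional change to AdaptiveLayerNorm itself is that condition becomes keyword-only, so any caller that passed it positionally must now name it. A one-line illustration (norm and cond are illustrative names, not identifiers from the package):

# before this release: norm(x, cond) — the condition must now be passed by keyword
out = norm(x, condition = cond)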
@@ -607,6 +607,20 @@ class RMSNorm(Module):
     def forward(self, x):
         return F.normalize(x, dim = -1) * self.scale * self.g

+class AdaptiveRMSNorm(Module):
+    def __init__(self, dim, dim_condition = None):
+        super().__init__()
+        self.scale = dim ** 0.5
+        dim_condition = default(dim_condition, dim)
+
+        self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
+        nn.init.zeros_(self.to_gamma.weight)
+
+    def forward(self, x, *, condition):
+        normed = F.normalize(x, dim = -1)
+        gamma = self.to_gamma(condition)
+        return normed * self.scale * (gamma + 1.)
+
 class SimpleRMSNorm(Module):
     def __init__(self, dim):
         super().__init__()
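The new AdaptiveRMSNorm follows the same gating pattern as the existing AdaptiveLayerNorm: a zero-initialized linear projection maps the condition to a per-channel gamma, so at initialization the module reduces to a plain RMSNorm and only learns to modulate the normalized activations during training. A minimal standalone sketch of the same computation in plain PyTorch (not the packaged module; shapes are illustrative):

import torch
import torch.nn.functional as F
from torch import nn

class AdaptiveRMSNormSketch(nn.Module):
    def __init__(self, dim, dim_condition = None):
        super().__init__()
        self.scale = dim ** 0.5
        dim_condition = dim_condition if dim_condition is not None else dim

        # zero init => gamma starts at 0, so (gamma + 1.) starts as an identity gate
        self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
        nn.init.zeros_(self.to_gamma.weight)

    def forward(self, x, *, condition):
        normed = F.normalize(x, dim = -1)            # unit-normalize along the feature dimension
        gamma = self.to_gamma(condition)             # per-channel gate derived from the condition
        return normed * self.scale * (gamma + 1.)

x = torch.randn(2, 1024, 512)                        # (batch, seq, dim)
cond = torch.randn(2, 1, 512)                        # broadcast across the sequence
out = AdaptiveRMSNormSketch(512)(x, condition = cond)    # -> (2, 1024, 512)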
@@ -1145,7 +1159,10 @@ class AttentionLayers(Module):
         use_rmsnorm = False,
         use_simple_rmsnorm = False,
         use_adaptive_layernorm = False,
+        use_adaptive_rmsnorm = False,
         dim_condition = None,
+        adaptive_condition_mlp = False,
+        adaptive_condition_mlp_expansion = 4,
         alibi_pos_bias = False,
         alibi_num_heads = None,
         rel_pos_bias = False,
@@ -1250,11 +1267,22 @@ class AttentionLayers(Module):

         # determine norm

-        assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm) <= 1, 'you can only use either scalenorm, rmsnorm, or simple rmsnorm'
+        assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'

         need_condition = False
         dim_condition = default(dim_condition, dim)

+        self.adaptive_mlp = nn.Identity()
+        dim_condition_mult = 1
+
+        if adaptive_condition_mlp:
+            dim_condition_mult = adaptive_condition_mlp_expansion
+
+            self.adaptive_mlp = nn.Sequential(
+                nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
+                nn.SiLU()
+            )
+
         if use_scalenorm:
             norm_class = ScaleNorm
         elif use_rmsnorm:
@@ -1263,7 +1291,10 @@ class AttentionLayers(Module):
             norm_class = SimpleRMSNorm
         elif use_adaptive_layernorm:
             need_condition = True
-            norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition)
+            norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
+        elif use_adaptive_rmsnorm:
+            need_condition = True
+            norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition * dim_condition_mult)
         else:
             norm_class = LayerNorm

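When adaptive_condition_mlp is enabled, the raw condition is first widened by a bias-free Linear followed by SiLU, and the partial(...) wiring above sizes each adaptive norm's projection against the widened dimension (dim_condition * dim_condition_mult). A shape-bookkeeping sketch under the defaults shown in this diff (values are illustrative, not library code):

import torch
from torch import nn

dim, dim_condition, expansion = 512, 256, 4

# mirrors the Sequential built when adaptive_condition_mlp = True
adaptive_mlp = nn.Sequential(
    nn.Linear(dim_condition, dim_condition * expansion, bias = False),
    nn.SiLU()
)

condition = torch.randn(2, 1, dim_condition)
widened = adaptive_mlp(condition)                        # (2, 1, 1024)

# the adaptive norms are constructed with dim_condition * expansion, matching the widened input
to_gamma = nn.Linear(dim_condition * expansion, dim, bias = False)
gamma = to_gamma(widened)                                # (2, 1, 512)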
@@ -1428,7 +1459,16 @@ class AttentionLayers(Module):
         norm_kwargs = dict()

         if self.need_condition:
-            assert condition.shape[-1] == self.dim_condition
+            assert condition.ndim in {2, 3}
+
+            if condition.ndim == 2:
+                condition = rearrange(condition, 'b d -> b 1 d')
+
+            assert condition.shape[-1] == self.dim_condition, f'expected condition dimension of {self.dim_condition} but received {condition.shape[-1]}'
+
+            # maybe mlp
+
+            condition = self.adaptive_mlp(condition)

             norm_kwargs.update(condition = condition)

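In the forward path, the condition may now be a per-sample vector of shape (batch, dim_condition) or a per-token tensor of shape (batch, seq, dim_condition); a 2-D condition is rearranged to (batch, 1, dim_condition) so it broadcasts across the sequence, and the (possibly identity) adaptive MLP is applied before the condition reaches each adaptive norm. A hedged usage sketch, assuming an Encoder built with the new flags and that its forward accepts condition as a keyword, as the hunk above suggests:

import torch
from x_transformers import Encoder

encoder = Encoder(
    dim = 512,
    depth = 6,
    heads = 8,
    use_adaptive_rmsnorm = True,       # new flag in this release
    dim_condition = 256,
    adaptive_condition_mlp = True      # optional Linear + SiLU widening of the condition
)

x = torch.randn(2, 1024, 512)          # token embeddings fed directly to the attention layers
time_cond = torch.randn(2, 256)        # 2-D condition, broadcast to (b, 1, d) by the rearrange above

out = encoder(x, condition = time_cond)    # -> (2, 1024, 512)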
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.17
+Version: 1.30.19
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=kvGX4Ib1gSYV6pmVXF6P9bcutRsx5bif_XhkbG4DOZ8,70738
+x_transformers/x_transformers.py,sha256=NxNSAMgyHZEE-toXNcEO-unGwKfoqooN9oEQMDPPGfY,72268
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.17.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.17.dist-info/METADATA,sha256=m_d5lvKUbiN8xS7Dx4gI5I8dHtzEa1ccp4MuKcG5O9w,662
-x_transformers-1.30.17.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.30.17.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.17.dist-info/RECORD,,
+x_transformers-1.30.19.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.19.dist-info/METADATA,sha256=2fQP9SmEys1grPsY2xHhB3sqkODtlaIz4UWtOR9pCM8,662
+x_transformers-1.30.19.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.19.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.19.dist-info/RECORD,,