x-transformers 1.30.18__py3-none-any.whl → 1.30.19__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
@@ -584,7 +584,7 @@ class AdaptiveLayerNorm(Module):
 
         nn.init.zeros_(self.to_gamma.weight)
 
-    def forward(self, x, condition):
+    def forward(self, x, *, condition):
         normed = self.ln(x)
         gamma = self.to_gamma(condition)
         return normed * (gamma + 1.)
@@ -616,7 +616,7 @@ class AdaptiveRMSNorm(Module):
         self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
         nn.init.zeros_(self.to_gamma.weight)
 
-    def forward(self, x, condition):
+    def forward(self, x, *, condition):
         normed = F.normalize(x, dim = -1)
         gamma = self.to_gamma(condition)
         return normed * self.scale * (gamma + 1.)
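
Note on the two hunks above: forward now takes condition as a keyword-only argument in both AdaptiveLayerNorm and AdaptiveRMSNorm, so positional call sites break. A minimal call-site sketch, assuming the constructors take (dim, dim_condition) as suggested by the to_gamma projection above (the constructor signature itself is not shown in this diff):

import torch
from x_transformers.x_transformers import AdaptiveRMSNorm

norm = AdaptiveRMSNorm(512, dim_condition = 256)   # gamma projection maps condition (256) to a per-channel scale (512)
x = torch.randn(2, 1024, 512)
cond = torch.randn(2, 1, 256)

out = norm(x, condition = cond)                    # keyword-only as of 1.30.19; norm(x, cond) now raises a TypeError
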
@@ -1161,6 +1161,8 @@ class AttentionLayers(Module):
         use_adaptive_layernorm = False,
         use_adaptive_rmsnorm = False,
         dim_condition = None,
+        adaptive_condition_mlp = False,
+        adaptive_condition_mlp_expansion = 4,
         alibi_pos_bias = False,
         alibi_num_heads = None,
         rel_pos_bias = False,
@@ -1270,6 +1272,17 @@ class AttentionLayers(Module):
         need_condition = False
         dim_condition = default(dim_condition, dim)
 
+        self.adaptive_mlp = nn.Identity()
+        dim_condition_mult = 1
+
+        if adaptive_condition_mlp:
+            dim_condition_mult = adaptive_condition_mlp_expansion
+
+            self.adaptive_mlp = nn.Sequential(
+                nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
+                nn.SiLU()
+            )
+
         if use_scalenorm:
             norm_class = ScaleNorm
         elif use_rmsnorm:
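
The hunk above registers the new optional condition MLP: when adaptive_condition_mlp = True, the condition is widened by adaptive_condition_mlp_expansion through a bias-free Linear followed by SiLU, and the following hunk builds the adaptive norms against that widened dimension. A standalone sketch of the widening with illustrative sizes (not part of the diff):

import torch
from torch import nn

dim_condition = 256
expansion = 4          # adaptive_condition_mlp_expansion

adaptive_mlp = nn.Sequential(
    nn.Linear(dim_condition, dim_condition * expansion, bias = False),
    nn.SiLU()
)

cond = torch.randn(2, 1, dim_condition)
print(adaptive_mlp(cond).shape)   # torch.Size([2, 1, 1024]), i.e. dim_condition * expansion
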
@@ -1278,10 +1291,10 @@ class AttentionLayers(Module):
             norm_class = SimpleRMSNorm
         elif use_adaptive_layernorm:
             need_condition = True
-            norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition)
+            norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
         elif use_adaptive_rmsnorm:
             need_condition = True
-            norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition)
+            norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition * dim_condition_mult)
         else:
             norm_class = LayerNorm
 
@@ -1446,7 +1459,16 @@ class AttentionLayers(Module):
         norm_kwargs = dict()
 
         if self.need_condition:
-            assert condition.shape[-1] == self.dim_condition
+            assert condition.ndim in {2, 3}
+
+            if condition.ndim == 2:
+                condition = rearrange(condition, 'b d -> b 1 d')
+
+            assert condition.shape[-1] == self.dim_condition, f'expected condition dimension of {self.dim_condition} but received {condition.shape[-1]}'
+
+            # maybe mlp
+
+            condition = self.adaptive_mlp(condition)
 
             norm_kwargs.update(condition = condition)
 
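
Taken together, the AttentionLayers changes accept a condition of shape (batch, dim_condition) or (batch, seq, dim_condition), broadcast the 2D case to (batch, 1, dim_condition), optionally pass it through the new condition MLP, and hand the result to the adaptive norms. A minimal usage sketch, assuming the stock Encoder wrapper forwards these constructor kwargs and the condition keyword down to AttentionLayers (dimensions are illustrative):

import torch
from x_transformers import Encoder

attn_layers = Encoder(
    dim = 512,
    depth = 4,
    heads = 8,
    use_adaptive_rmsnorm = True,            # per-layer AdaptiveRMSNorm conditioned on the passed-in condition
    dim_condition = 256,
    adaptive_condition_mlp = True,          # new in 1.30.19: Linear + SiLU expansion of the condition
    adaptive_condition_mlp_expansion = 4
)

x = torch.randn(2, 1024, 512)               # token embeddings
condition = torch.randn(2, 256)             # per-sample condition; rearranged internally to (2, 1, 256)

out = attn_layers(x, condition = condition) # (2, 1024, 512)
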
 
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.18
+Version: 1.30.19
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=8XuiUXFOD7KAmopmf66mCq-HRs1g5Wd5tHcTTpm9JeM,71460
+x_transformers/x_transformers.py,sha256=NxNSAMgyHZEE-toXNcEO-unGwKfoqooN9oEQMDPPGfY,72268
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.18.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.18.dist-info/METADATA,sha256=6aA6OcLnBMlxZKJeRShv9UMT1BUYZlq-jfrj68nv5yU,662
-x_transformers-1.30.18.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.30.18.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.18.dist-info/RECORD,,
+x_transformers-1.30.19.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.19.dist-info/METADATA,sha256=2fQP9SmEys1grPsY2xHhB3sqkODtlaIz4UWtOR9pCM8,662
+x_transformers-1.30.19.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.19.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.19.dist-info/RECORD,,