x-transformers 1.30.18__py3-none-any.whl → 1.30.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- x_transformers/x_transformers.py (1.30.18)
+++ x_transformers/x_transformers.py (1.30.20)
@@ -584,7 +584,7 @@ class AdaptiveLayerNorm(Module):
 
         nn.init.zeros_(self.to_gamma.weight)
 
-    def forward(self, x, condition):
+    def forward(self, x, *, condition):
         normed = self.ln(x)
         gamma = self.to_gamma(condition)
         return normed * (gamma + 1.)
@@ -616,7 +616,7 @@ class AdaptiveRMSNorm(Module):
         self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
         nn.init.zeros_(self.to_gamma.weight)
 
-    def forward(self, x, condition):
+    def forward(self, x, *, condition):
         normed = F.normalize(x, dim = -1)
         gamma = self.to_gamma(condition)
         return normed * self.scale * (gamma + 1.)
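
Call-site impact of the keyword-only change above: condition can no longer be passed positionally to AdaptiveLayerNorm.forward or AdaptiveRMSNorm.forward. A minimal sketch; the constructor argument names are inferred from the partial(...) calls later in this diff:

    import torch
    from x_transformers.x_transformers import AdaptiveRMSNorm

    norm = AdaptiveRMSNorm(512, dim_condition = 256)

    x = torch.randn(2, 1024, 512)
    cond = torch.randn(2, 1024, 256)

    out = norm(x, condition = cond)   # keyword required as of 1.30.20
    # norm(x, cond)                   # now raises TypeError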
@@ -1161,6 +1161,8 @@ class AttentionLayers(Module):
         use_adaptive_layernorm = False,
         use_adaptive_rmsnorm = False,
         dim_condition = None,
+        adaptive_condition_mlp = False,
+        adaptive_condition_mlp_expansion = 4,
         alibi_pos_bias = False,
         alibi_num_heads = None,
         rel_pos_bias = False,
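
The two new constructor flags hang off the existing adaptive-norm options. A hedged construction sketch, assuming Decoder forwards these keyword arguments through to AttentionLayers the same way the other options are forwarded:

    from x_transformers import Decoder

    layers = Decoder(
        dim = 512,
        depth = 4,
        heads = 8,
        use_adaptive_rmsnorm = True,
        dim_condition = 256,
        adaptive_condition_mlp = True,          # new flag
        adaptive_condition_mlp_expansion = 4    # new flag, default shown
    )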
@@ -1269,6 +1271,10 @@ class AttentionLayers(Module):
 
         need_condition = False
         dim_condition = default(dim_condition, dim)
+        dim_condition_mult = 1
+
+        if adaptive_condition_mlp:
+            dim_condition_mult = adaptive_condition_mlp_expansion
 
         if use_scalenorm:
             norm_class = ScaleNorm
@@ -1278,15 +1284,23 @@ class AttentionLayers(Module):
             norm_class = SimpleRMSNorm
         elif use_adaptive_layernorm:
             need_condition = True
-            norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition)
+            norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
         elif use_adaptive_rmsnorm:
             need_condition = True
-            norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition)
+            norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition * dim_condition_mult)
         else:
             norm_class = LayerNorm
 
         norm_fn = partial(norm_class, dim)
 
+        self.adaptive_mlp = nn.Identity()
+
+        if need_condition and adaptive_condition_mlp:
+            self.adaptive_mlp = nn.Sequential(
+                nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
+                nn.SiLU()
+            )
+
         self.need_condition = need_condition
         self.dim_condition = dim_condition
 
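
With adaptive_condition_mlp = True the adaptive norms are constructed against the widened dimension, so each layer's gamma projection consumes the MLP output rather than the raw condition. A standalone sketch of just that pathway (shapes only, not the library's full module graph):

    import torch
    from torch import nn

    dim, dim_condition, expansion = 512, 256, 4

    # shared: widen the condition once, as in the hunk above
    adaptive_mlp = nn.Sequential(
        nn.Linear(dim_condition, dim_condition * expansion, bias = False),
        nn.SiLU()
    )

    # per-norm: project the widened condition to a per-channel gamma,
    # zero-initialized so every norm starts out as a plain (RMS)norm
    to_gamma = nn.Linear(dim_condition * expansion, dim, bias = False)
    nn.init.zeros_(to_gamma.weight)

    cond = torch.randn(2, 1, dim_condition)
    gamma = to_gamma(adaptive_mlp(cond))    # shape (2, 1, 512), zeros at init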
@@ -1446,7 +1460,16 @@ class AttentionLayers(Module):
         norm_kwargs = dict()
 
         if self.need_condition:
-            assert condition.shape[-1] == self.dim_condition
+            assert condition.ndim in {2, 3}
+
+            if condition.ndim == 2:
+                condition = rearrange(condition, 'b d -> b 1 d')
+
+            assert condition.shape[-1] == self.dim_condition, f'expected condition dimension of {self.dim_condition} but received {condition.shape[-1]}'
+
+            # maybe mlp
+
+            condition = self.adaptive_mlp(condition)
 
             norm_kwargs.update(condition = condition)
 
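
Forward-time effect of the hunk above: a per-sample condition of shape (batch, dim_condition) is now accepted and broadcast over the sequence, and the dimension check reports a readable error. Continuing the layers sketch from earlier, and assuming condition is exposed as a forward keyword, as the need_condition branch implies:

    x = torch.randn(2, 1024, 512)
    cond = torch.randn(2, 256)            # 2D condition, one vector per sample

    out = layers(x, condition = cond)     # rearranged internally to (2, 1, 256)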
--- x_transformers-1.30.18.dist-info/METADATA
+++ x_transformers-1.30.20.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.18
+Version: 1.30.20
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- x_transformers-1.30.18.dist-info/RECORD
+++ x_transformers-1.30.20.dist-info/RECORD
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=8XuiUXFOD7KAmopmf66mCq-HRs1g5Wd5tHcTTpm9JeM,71460
+x_transformers/x_transformers.py,sha256=kd0H1tsw3SynfQu7xjuzacnuYimVhVMVxLID_I_pM8A,72322
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.18.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.18.dist-info/METADATA,sha256=6aA6OcLnBMlxZKJeRShv9UMT1BUYZlq-jfrj68nv5yU,662
-x_transformers-1.30.18.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.30.18.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.18.dist-info/RECORD,,
+x_transformers-1.30.20.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.20.dist-info/METADATA,sha256=OwTLO7xb31tvGd5vHK1XAU70ireyNSlrcDlg2zUJUuY,662
+x_transformers-1.30.20.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.20.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.20.dist-info/RECORD,,