x-transformers 1.30.17__py3-none-any.whl → 1.30.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -607,6 +607,20 @@ class RMSNorm(Module):
607
607
  def forward(self, x):
608
608
  return F.normalize(x, dim = -1) * self.scale * self.g
609
609
 
610
class AdaptiveRMSNorm(Module):
    """RMSNorm whose per-channel gain is predicted from a conditioning tensor.

    The input is L2-normalized over its last dimension and rescaled by
    ``sqrt(dim)``, which gives it unit root-mean-square. It is then multiplied
    by ``1 + gamma``, where ``gamma`` is a bias-free linear projection of the
    conditioning tensor.
    """

    def __init__(self, dim, dim_condition = None):
        super().__init__()
        # `default` presumably substitutes `dim` when dim_condition is None
        dim_condition = default(dim_condition, dim)

        # an L2-normalized vector has RMS 1 / sqrt(dim); multiplying by
        # sqrt(dim) restores unit RMS
        self.scale = dim ** 0.5

        # zero-initialised projection: gamma starts at 0, so the modulation
        # factor (1 + gamma) is the identity at the start of training
        self.to_gamma = nn.Linear(dim_condition, dim, bias = False)
        nn.init.zeros_(self.to_gamma.weight)

    def forward(self, x, condition):
        """Normalize ``x`` and modulate it with a gain computed from ``condition``.

        Args:
            x: tensor whose last dimension is ``dim``.
            condition: tensor whose last dimension is ``dim_condition``. The
                projected gain must broadcast against ``x``.
                NOTE(review): presumably callers reshape a (batch, dim_condition)
                condition to (batch, 1, dim_condition) first; confirm.

        Returns:
            The unit-RMS input scaled elementwise by ``1 + to_gamma(condition)``.
        """
        modulation = self.to_gamma(condition) + 1.
        unit_rms = F.normalize(x, dim = -1) * self.scale
        return unit_rms * modulation
623
+
610
624
  class SimpleRMSNorm(Module):
611
625
  def __init__(self, dim):
612
626
  super().__init__()
@@ -1145,6 +1159,7 @@ class AttentionLayers(Module):
1145
1159
  use_rmsnorm = False,
1146
1160
  use_simple_rmsnorm = False,
1147
1161
  use_adaptive_layernorm = False,
1162
+ use_adaptive_rmsnorm = False,
1148
1163
  dim_condition = None,
1149
1164
  alibi_pos_bias = False,
1150
1165
  alibi_num_heads = None,
@@ -1250,7 +1265,7 @@ class AttentionLayers(Module):
1250
1265
 
1251
1266
  # determine norm
1252
1267
 
1253
- assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm) <= 1, 'you can only use either scalenorm, rmsnorm, or simple rmsnorm'
1268
+ assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'
1254
1269
 
1255
1270
  need_condition = False
1256
1271
  dim_condition = default(dim_condition, dim)
@@ -1264,6 +1279,9 @@ class AttentionLayers(Module):
1264
1279
  elif use_adaptive_layernorm:
1265
1280
  need_condition = True
1266
1281
  norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition)
1282
+ elif use_adaptive_rmsnorm:
1283
+ need_condition = True
1284
+ norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition)
1267
1285
  else:
1268
1286
  norm_class = LayerNorm
1269
1287
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.30.17
3
+ Version: 1.30.18
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
4
4
  x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
5
5
  x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
6
6
  x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
7
- x_transformers/x_transformers.py,sha256=kvGX4Ib1gSYV6pmVXF6P9bcutRsx5bif_XhkbG4DOZ8,70738
7
+ x_transformers/x_transformers.py,sha256=8XuiUXFOD7KAmopmf66mCq-HRs1g5Wd5tHcTTpm9JeM,71460
8
8
  x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
9
9
  x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
10
- x_transformers-1.30.17.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
11
- x_transformers-1.30.17.dist-info/METADATA,sha256=m_d5lvKUbiN8xS7Dx4gI5I8dHtzEa1ccp4MuKcG5O9w,662
12
- x_transformers-1.30.17.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
13
- x_transformers-1.30.17.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
14
- x_transformers-1.30.17.dist-info/RECORD,,
10
+ x_transformers-1.30.18.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
11
+ x_transformers-1.30.18.dist-info/METADATA,sha256=6aA6OcLnBMlxZKJeRShv9UMT1BUYZlq-jfrj68nv5yU,662
12
+ x_transformers-1.30.18.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
13
+ x_transformers-1.30.18.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
14
+ x_transformers-1.30.18.dist-info/RECORD,,