x-transformers 1.30.11-py3-none-any.whl → 1.30.14-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -80,6 +80,13 @@ class equals():
     def __call__(self, x, *args, **kwargs):
         return x == self.val
 
+class Identity(Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+    def __call__(self, x, *args, **kwargs):
+        return x
+
 def Sequential(*modules):
     return nn.Sequential(*filter(exists, modules))
 
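The added `Identity` module is a no-op stand-in: unlike `torch.nn.Identity`, both its constructor and its `__call__` accept and discard arbitrary extra arguments, so it can be instantiated and invoked with the same signatures as the norm classes it substitutes for below. A minimal sketch of the behavior (the dim argument and tensor shapes are illustrative, not from the source):

```python
import torch
from torch.nn import Module

class Identity(Module):
    # mirrors the diff: constructor args are accepted and ignored,
    # so it can be built exactly like a norm class, e.g. Identity(dim)
    def __init__(self, *args, **kwargs):
        super().__init__()

    # extra call args are likewise ignored; the input passes through unchanged
    def __call__(self, x, *args, **kwargs):
        return x

norm = Identity(512)                # norm-like constructor call, args discarded
x = torch.randn(2, 16, 512)         # illustrative batch of token embeddings
assert torch.equal(norm(x), x)      # forwards the input untouched
```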
@@ -1129,6 +1136,7 @@ class AttentionLayers(Module):
         use_scalenorm = False,
         use_rmsnorm = False,
         use_simple_rmsnorm = False,
+        no_pre_or_postnorm = False,
         alibi_pos_bias = False,
         alibi_num_heads = None,
         rel_pos_bias = False,
@@ -1358,9 +1366,18 @@ class AttentionLayers(Module):
             residual_fn = GRUGating if gate_residual else Residual
             residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
 
-            pre_branch_norm = norm_fn() if pre_norm else None
-            post_branch_norm = norm_fn() if sandwich_norm else None
-            post_main_norm = norm_fn() if not pre_norm else None
+            # for the peri-layernorm config from https://arxiv.org/abs/2405.16039
+            # must be paired with qk norm
+
+            layer_norm_fn = norm_fn
+            if no_pre_or_postnorm:
+                layer_norm_fn = Identity
+
+            # all normalizations of the layer
+
+            pre_branch_norm = layer_norm_fn() if pre_norm else None
+            post_branch_norm = layer_norm_fn() if sandwich_norm else None
+            post_main_norm = layer_norm_fn() if not pre_norm else None
 
             norms = ModuleList([
                 pre_branch_norm,
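The new `no_pre_or_postnorm` flag swaps the pre-branch, sandwich, and post-main norms for the `Identity` module above, giving the peri-layernorm configuration of https://arxiv.org/abs/2405.16039; per the source comment it must be paired with QK norm. A hedged usage sketch, assuming the library's usual `TransformerWrapper`/`Decoder` entry points and that `attn_qk_norm` is the routed kwarg that enables QK norm:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        no_pre_or_postnorm = True,  # new flag from this diff: branch norms become Identity
        attn_qk_norm = True         # assumed pairing: the comment says qk norm is required
    )
)

tokens = torch.randint(0, 20000, (1, 1024))  # illustrative token ids
logits = model(tokens)                       # (1, 1024, 20000)
```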
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.30.11
+Version: 1.30.14
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=97WREYHfGljhrveGNUUSEWk2xbFvKdM52QXW7cnoBpk,69019
+x_transformers/x_transformers.py,sha256=yCLWe5g0iFDBE0lSnoB9Up0bYHzfIFL0f7zhNTbuoS0,69513
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.11.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.30.11.dist-info/METADATA,sha256=kpv8_mEb4DU4zL_oyLwzbWJ7z8WV7mf7c466Hehn-6c,662
-x_transformers-1.30.11.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.30.11.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.30.11.dist-info/RECORD,,
+x_transformers-1.30.14.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.14.dist-info/METADATA,sha256=0Y4a5bhVub1knobIGu7EnygkAGeiXV7EviJDkDaktBg,662
+x_transformers-1.30.14.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.14.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.14.dist-info/RECORD,,