x-transformers 1.30.11__py3-none-any.whl → 1.30.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +20 -3
- {x_transformers-1.30.11.dist-info → x_transformers-1.30.14.dist-info}/METADATA +1 -1
- {x_transformers-1.30.11.dist-info → x_transformers-1.30.14.dist-info}/RECORD +6 -6
- {x_transformers-1.30.11.dist-info → x_transformers-1.30.14.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.11.dist-info → x_transformers-1.30.14.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.11.dist-info → x_transformers-1.30.14.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
```diff
@@ -80,6 +80,13 @@ class equals():
     def __call__(self, x, *args, **kwargs):
         return x == self.val
 
+class Identity(Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+    def __call__(self, x, *args, **kwargs):
+        return x
+
 def Sequential(*modules):
     return nn.Sequential(*filter(exists, modules))
 
```
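The added `Identity` module mirrors `torch.nn.Identity`, but it accepts and discards arbitrary constructor and call arguments, so it can be instantiated and called exactly like the norm classes it stands in for. A minimal sketch of the behavior (the import path follows where the class is defined in this diff; everything else is illustrative):

```python
# Illustrative only: Identity accepts any args/kwargs and returns its input unchanged,
# which lets it be swapped in wherever a norm constructor would otherwise be called.
import torch
from x_transformers.x_transformers import Identity  # defined at module level in this diff

ident = Identity('ignored', unused = True)   # constructor arguments are discarded
x = torch.randn(2, 1024, 512)
assert torch.equal(ident(x), x)              # the tensor passes through untouched
```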
```diff
@@ -1129,6 +1136,7 @@ class AttentionLayers(Module):
         use_scalenorm = False,
         use_rmsnorm = False,
         use_simple_rmsnorm = False,
+        no_pre_or_postnorm = False,
         alibi_pos_bias = False,
         alibi_num_heads = None,
         rel_pos_bias = False,
```
```diff
@@ -1358,9 +1366,18 @@ class AttentionLayers(Module):
             residual_fn = GRUGating if gate_residual else Residual
             residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant)
 
-            pre_branch_norm = norm_fn() if pre_norm else None
-            post_branch_norm = norm_fn() if sandwich_norm else None
-            post_main_norm = norm_fn() if not pre_norm else None
+            # for the peri-layernorm config from https://arxiv.org/abs/2405.16039
+            # must be paired with qk norm
+
+            layer_norm_fn = norm_fn
+            if no_pre_or_postnorm:
+                layer_norm_fn = Identity
+
+            # all normalizations of the layer
+
+            pre_branch_norm = layer_norm_fn() if pre_norm else None
+            post_branch_norm = layer_norm_fn() if sandwich_norm else None
+            post_main_norm = layer_norm_fn() if not pre_norm else None
 
             norms = ModuleList([
                 pre_branch_norm,
```
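Taken together, setting `no_pre_or_postnorm = True` makes each layer's pre-, post-, and sandwich-norm the no-op `Identity`; the diff's own comment ties this to the peri-layernorm configuration of https://arxiv.org/abs/2405.16039 and notes it must be paired with qk norm. A hedged usage sketch, assuming the library's usual `TransformerWrapper`/`Decoder` entry points and `attn_qk_norm` keyword; only the `no_pre_or_postnorm` flag comes from this diff:

```python
# Hedged usage sketch: TransformerWrapper, Decoder and attn_qk_norm are the library's
# existing public API; no_pre_or_postnorm is the option added in 1.30.14.
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_qk_norm = True,         # diff comment: "must be paired with qk norm"
        no_pre_or_postnorm = True    # new flag: per-layer norms become Identity
    )
)

tokens = torch.randint(0, 256, (1, 1024))
logits = model(tokens)               # (1, 1024, 256)
```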
{x_transformers-1.30.11.dist-info → x_transformers-1.30.14.dist-info}/RECORD
RENAMED
```diff
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=yCLWe5g0iFDBE0lSnoB9Up0bYHzfIFL0f7zhNTbuoS0,69513
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.14.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.14.dist-info/METADATA,sha256=0Y4a5bhVub1knobIGu7EnygkAGeiXV7EviJDkDaktBg,662
+x_transformers-1.30.14.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.30.14.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.14.dist-info/RECORD,,
```
{x_transformers-1.30.11.dist-info → x_transformers-1.30.14.dist-info}/LICENSE
File without changes

{x_transformers-1.30.11.dist-info → x_transformers-1.30.14.dist-info}/WHEEL
File without changes

{x_transformers-1.30.11.dist-info → x_transformers-1.30.14.dist-info}/top_level.txt
File without changes