x-transformers 1.30.21__py3-none-any.whl → 1.30.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +14 -9
- {x_transformers-1.30.21.dist-info → x_transformers-1.30.23.dist-info}/METADATA +1 -1
- {x_transformers-1.30.21.dist-info → x_transformers-1.30.23.dist-info}/RECORD +6 -6
- {x_transformers-1.30.21.dist-info → x_transformers-1.30.23.dist-info}/LICENSE +0 -0
- {x_transformers-1.30.21.dist-info → x_transformers-1.30.23.dist-info}/WHEEL +0 -0
- {x_transformers-1.30.21.dist-info → x_transformers-1.30.23.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -710,7 +710,7 @@ class LayerScale(Module):
         out, *rest = out
         return out * self.gamma, *rest

-class
+class AdaptiveLayerScale(Module):
     def __init__(self, fn: Module, dim, dim_condition = None, init_bias_value = -2.):
         super().__init__()
         self.fn = fn
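For context (not part of the published diff): the hunk above only shows the constructor signature of the new `AdaptiveLayerScale` wrapper, and the removed line at 713 is cut off by the viewer. Below is a minimal sketch of what an ada-ln-zero style gate with this signature could look like; the forward body, the `to_gamma` projection, and the tuple handling are assumptions for illustration, not the library's actual code.

```python
from torch import nn
from torch.nn import Module

class AdaptiveLayerScaleSketch(Module):
    """Hypothetical reconstruction of an ada-ln-zero style gate (DiT).

    Only the __init__ signature appears in the diff; everything else here
    is an assumption for illustration.
    """
    def __init__(self, fn: Module, dim, dim_condition = None, init_bias_value = -2.):
        super().__init__()
        self.fn = fn
        dim_condition = dim_condition if dim_condition is not None else dim

        # condition -> per-dimension gate, bias-initialized to -2 so that
        # sigmoid(gate) starts small (~0.12) and each wrapped branch begins
        # close to an identity mapping
        self.to_gamma = nn.Linear(dim_condition, dim)
        nn.init.zeros_(self.to_gamma.weight)
        nn.init.constant_(self.to_gamma.bias, init_bias_value)

    def forward(self, x, *, condition, **kwargs):
        out = self.fn(x, **kwargs)
        gamma = self.to_gamma(condition).sigmoid()

        if not isinstance(out, tuple):
            return out * gamma

        # mirror LayerScale above: scale only the first return value
        out, *rest = out
        return out * gamma, *rest
```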
@@ -1181,7 +1181,7 @@ class AttentionLayers(Module):
         use_simple_rmsnorm = False,
         use_adaptive_layernorm = False,
         use_adaptive_rmsnorm = False,
-
+        use_adaptive_layerscale = False, # paired with use_adaptive_layernorm for ada-ln-zero from DiT paper
         dim_condition = None,
         adaptive_condition_mlp = False,
         adaptive_condition_mlp_expansion = 4,
@@ -1201,8 +1201,8 @@ class AttentionLayers(Module):
         rotary_xpos_scale_base = 512,
         rotary_base_rescale_factor = 1.,
         weight_tie_layers = False,
-        custom_layers: Tuple[str] | None = None,
-        layers_execute_order: Tuple[int] | None = None,
+        custom_layers: Tuple[str, ...] | None = None,
+        layers_execute_order: Tuple[int, ...] | None = None,
         sandwich_coef = None,
         par_ratio = None,
         residual_attn = False,
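A quick note on the annotation change above, since the two forms are easy to conflate: in `typing`, `Tuple[str]` means a tuple of exactly one string, while `Tuple[str, ...]` means a homogeneous tuple of any length, which is what `custom_layers` and `layers_execute_order` actually receive. The values below are illustrative only.

```python
from typing import Tuple

# exactly one element
single: Tuple[str] = ('a',)

# any number of elements, all the same type, e.g. an interleaved
# attention/feedforward layout such as ('a', 'f', 'a', 'f')
layout: Tuple[str, ...] = ('a', 'f', 'a', 'f')
```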
@@ -1332,15 +1332,15 @@ class AttentionLayers(Module):

         # determine post branch wrapper

-        assert at_most_one_of(use_layerscale,
+        assert at_most_one_of(use_layerscale, use_adaptive_layerscale)

         post_branch_fn = None
         post_branch_fn_needs_condition = False

         if use_layerscale:
             post_branch_fn = partial(LayerScale, dim = dim, init_value = layerscale_init_value)
-        elif
-            post_branch_fn = partial(
+        elif use_adaptive_layerscale:
+            post_branch_fn = partial(AdaptiveLayerScale, dim = dim, dim_condition = dim_condition * dim_condition_mult)
             post_branch_fn_needs_condition = True

         self.post_branch_fn_needs_condition = post_branch_fn_needs_condition
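Taken together with the `use_adaptive_layerscale` flag added at line 1184, this wiring enables an ada-ln-zero style setup. The sketch below shows how it could plausibly be turned on for a standard x-transformers `Decoder`; the flag names come from the diff, but the tensor shapes, the `dim_condition` value, and the assumption that the condition embedding is passed straight to `forward` are illustrative.

```python
import torch
from x_transformers import Decoder

# illustrative configuration: pair the new use_adaptive_layerscale flag with
# use_adaptive_layernorm, as the comment in the diff suggests (ada-ln-zero, DiT)
decoder = Decoder(
    dim = 512,
    depth = 6,
    heads = 8,
    use_adaptive_layernorm = True,
    use_adaptive_layerscale = True,   # added in this release
    dim_condition = 256,
)

x = torch.randn(2, 1024, 512)      # already-embedded tokens
condition = torch.randn(2, 256)    # e.g. a timestep or class embedding

out = decoder(x, condition = condition)   # assumed shape (2, 1024, 512)
```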
@@ -1484,7 +1484,8 @@ class AttentionLayers(Module):
         cache_age = 1,
         return_hiddens = False,
         rotary_pos_emb = None,
-        condition = None
+        condition = None,
+        layers_execute_order: Tuple[int, ...] | None = None
     ):
         assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
         assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'
@@ -1576,7 +1577,11 @@ class AttentionLayers(Module):
             self.layer_dropouts
         )

-
+        # able to override the layers execution order on forward, for trying to depth extrapolate
+
+        layers_execute_order = default(layers_execute_order, self.layers_execute_order)
+
+        layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)

         # go through the attention and feedforward layers

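The last hunk lets the layer execution order be overridden per forward call, which the in-code comment frames as a way to experiment with depth extrapolation. Below is a minimal sketch of such a call on a plain `Encoder`, simply running the interleaved attention/feedforward blocks through twice; the index arithmetic assumes the default ('a', 'f') block layout and is illustrative, not documented behavior.

```python
import torch
from x_transformers import Encoder

encoder = Encoder(dim = 512, depth = 6, heads = 8)

x = torch.randn(2, 256, 512)

# default ordering: each block runs once
out = encoder(x)

# hypothetical depth extrapolation: repeat the indices so the whole stack of
# interleaved blocks executes a second time (depth = 6 pairs -> 12 blocks,
# assuming the default ('a', 'f') layout per depth)
num_blocks = 6 * 2
order = tuple(range(num_blocks)) * 2

out_deeper = encoder(x, layers_execute_order = order)
```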
{x_transformers-1.30.21.dist-info → x_transformers-1.30.23.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=tZixUvlsaEj3CpB49KLDOJ2BwYSPjdWotDUjB9Rbf7g,74213
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
-x_transformers-1.30.
+x_transformers-1.30.23.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.30.23.dist-info/METADATA,sha256=LM8Y0bkOF259zCn_FE2A-Uw5Yjr8YrqCKNYuW4DqtQY,662
+x_transformers-1.30.23.dist-info/WHEEL,sha256=cpQTJ5IWu9CdaPViMhC9YzF8gZuS5-vlfoFihTBC86A,91
+x_transformers-1.30.23.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.30.23.dist-info/RECORD,,
{x_transformers-1.30.21.dist-info → x_transformers-1.30.23.dist-info}/LICENSE
File without changes
{x_transformers-1.30.21.dist-info → x_transformers-1.30.23.dist-info}/WHEEL
File without changes
{x_transformers-1.30.21.dist-info → x_transformers-1.30.23.dist-info}/top_level.txt
File without changes