x-transformers 1.28.5__py3-none-any.whl → 1.29.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +7 -1
- {x_transformers-1.28.5.dist-info → x_transformers-1.29.0.dist-info}/METADATA +1 -1
- {x_transformers-1.28.5.dist-info → x_transformers-1.29.0.dist-info}/RECORD +6 -6
- {x_transformers-1.28.5.dist-info → x_transformers-1.29.0.dist-info}/LICENSE +0 -0
- {x_transformers-1.28.5.dist-info → x_transformers-1.29.0.dist-info}/WHEEL +0 -0
- {x_transformers-1.28.5.dist-info → x_transformers-1.29.0.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1001,7 +1001,7 @@ class AttentionLayers(Module):
     def __init__(
         self,
         dim,
-        depth,
+        depth = None,
         heads = 8,
         causal = False,
         cross_attend = False,
@@ -1054,6 +1054,8 @@ class AttentionLayers(Module):
         attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
         cross_attn_kwargs, kwargs = groupby_prefix_and_trim('cross_attn_', kwargs)
 
+        assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'
+
         dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
 
         self.dim = dim
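A note on the hunk above: the new assertion turns misspelled keyword arguments, previously swallowed by **kwargs, into an immediate failure at construction time. A minimal sketch of the behavior change, assuming the public Encoder wrapper and an installed x-transformers >= 1.29.0; the misspelled casual flag below is a hypothetical typo for causal:

# Sketch only: strict kwarg validation added in 1.29.0.
from x_transformers import Encoder

# Prefixed kwargs such as attn_dropout are grouped off by
# groupby_prefix_and_trim and forwarded to the attention blocks,
# so they still pass the new len(kwargs) == 0 check.
enc = Encoder(dim = 512, depth = 6, heads = 8, attn_dropout = 0.1)

# An unprefixed typo ('casual' instead of 'causal') used to be silently
# ignored; the new assert now surfaces it.
try:
    Encoder(dim = 512, depth = 6, heads = 8, casual = True)
except AssertionError as err:
    print(err)  # expected: unrecognized kwargs passed in dict_keys(['casual'])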
@@ -1138,9 +1140,12 @@ class AttentionLayers(Module):
 
         # setup weight tying, which is a special case of `layer_execute_order`
 
+        assert not (exists(layers_execute_order) and exists(custom_layers) and exists(depth)), 'depth should not be passed in if using custom layers and custom layer execution order'
+
         assert not (weight_tie_layers and any([*map(exists, (custom_layers, par_ratio, sandwich_coef))]))
 
         if weight_tie_layers:
+            assert exists(depth), 'depth must be passed in with `weight_tie_layers` = True'
             assert not exists(layers_execute_order)
             layers_execute_order = tuple(range(len(default_block))) * depth
             depth = 1
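The two assertions above pin down a new contract for depth: it is required when weight_tie_layers = True (the shared block must know how many times it is unrolled), and it must be omitted when custom_layers and layers_execute_order already determine the layer sequence. A hedged sketch of both sides, assuming the AttentionLayers kwargs visible in this diff:

# Sketch only: how `depth` now interacts with weight tying and custom orders.
from x_transformers import Encoder

# Weight tying builds one block and re-executes it `depth` times,
# so depth stays mandatory here:
tied = Encoder(dim = 512, depth = 6, heads = 8, weight_tie_layers = True)

# With custom layers plus an explicit execution order, the executed length
# comes from layers_execute_order itself; passing depth as well would trip
# the first assertion above, so it is left out:
custom = Encoder(
    dim = 512,
    heads = 8,
    custom_layers = ('a', 'f'),                # hypothetical attention/ff layout
    layers_execute_order = (0, 1, 0, 1, 0, 1)  # indices into custom_layers
)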
@@ -1164,6 +1169,7 @@ class AttentionLayers(Module):
             assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
             layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
         else:
+            assert exists(depth), '`depth` must be passed in for `Decoder` or `Encoder`'
             layer_types = default_block * depth
 
         self.layer_types = layer_types
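Defaulting depth to None in the signature is what lets the custom-layers path omit it; this last assertion keeps the old requirement on the default path, replacing the former TypeError for a missing positional argument with a pointed message. A hypothetical repro:

# Sketch only: the plain Encoder/Decoder path still requires depth.
from x_transformers import Encoder

try:
    Encoder(dim = 512, heads = 8)  # no depth, no custom_layers
except AssertionError as err:
    print(err)  # expected: `depth` must be passed in for `Decoder` or `Encoder`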
{x_transformers-1.28.5.dist-info → x_transformers-1.29.0.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRT
 x_transformers/continuous.py,sha256=dpHK4NSMDQAJQ_N3Uj9rip0fYGXyu0QCCO_OfEdbRGs,6192
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=ub1QXJIXfoK5Bm8poZ1oJC99hbt9QitAuKmmmfBtxUY,65111
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.29.0.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.29.0.dist-info/METADATA,sha256=6ivD0nnIvXz057mJdIeHYNt2s9E0fN69eqSPGtSbcXg,661
+x_transformers-1.29.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.29.0.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.29.0.dist-info/RECORD,,
{x_transformers-1.28.5.dist-info → x_transformers-1.29.0.dist-info}/LICENSE
File without changes
{x_transformers-1.28.5.dist-info → x_transformers-1.29.0.dist-info}/WHEEL
File without changes
{x_transformers-1.28.5.dist-info → x_transformers-1.29.0.dist-info}/top_level.txt
File without changes