x-transformers 1.28.5__py3-none-any.whl → 1.29.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +14 -2
- {x_transformers-1.28.5.dist-info → x_transformers-1.29.1.dist-info}/METADATA +1 -1
- {x_transformers-1.28.5.dist-info → x_transformers-1.29.1.dist-info}/RECORD +6 -6
- {x_transformers-1.28.5.dist-info → x_transformers-1.29.1.dist-info}/LICENSE +0 -0
- {x_transformers-1.28.5.dist-info → x_transformers-1.29.1.dist-info}/WHEEL +0 -0
- {x_transformers-1.28.5.dist-info → x_transformers-1.29.1.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1001,7 +1001,7 @@ class AttentionLayers(Module):
     def __init__(
         self,
         dim,
-        depth,
+        depth = None,
         heads = 8,
         causal = False,
         cross_attend = False,
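The only signature change is that `depth` becomes an optional keyword defaulting to `None`. It keeps its slot in the signature, so existing positional and keyword calls are unaffected. A minimal sketch of both call styles (hypothetical usage; `AttentionLayers` is normally constructed through its `Encoder`/`Decoder` subclasses):

```python
# a minimal sketch, assuming AttentionLayers is importable from x_transformers.x_transformers;
# `depth` stays in the same position, so both existing call styles are unchanged
from x_transformers.x_transformers import AttentionLayers

layers_positional = AttentionLayers(512, 6)                # dim = 512, depth = 6, as before
layers_keyword    = AttentionLayers(dim = 512, depth = 6)  # equivalent keyword form
```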
@@ -1054,10 +1054,11 @@ class AttentionLayers(Module):
         attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
         cross_attn_kwargs, kwargs = groupby_prefix_and_trim('cross_attn_', kwargs)
 
+        assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'
+
         dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
 
         self.dim = dim
-        self.depth = depth
         self.causal = causal
         self.layers = ModuleList([])
 
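With the new assert, any keyword left over after the `attn_`/`cross_attn_` prefix grouping now fails fast and names the offending keys instead of being silently dropped. A hedged sketch of what that looks like for a misspelled keyword (hypothetical call; the message text comes from the assert above):

```python
# hypothetical illustration: a typo such as `head` (instead of `heads`) is no
# longer silently ignored; the constructor now raises and names the bad key
from x_transformers import Encoder

try:
    enc = Encoder(dim = 512, depth = 6, head = 8)   # typo: should be `heads = 8`
except AssertionError as err:
    print(err)   # e.g. "unrecognized kwargs passed in dict_keys(['head'])"
```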
@@ -1138,9 +1139,12 @@ class AttentionLayers(Module):
 
         # setup weight tying, which is a special case of `layer_execute_order`
 
+        assert not (exists(layers_execute_order) and exists(custom_layers) and exists(depth)), 'depth should not be passed in if using custom layers and custom layer execution order'
+
         assert not (weight_tie_layers and any([*map(exists, (custom_layers, par_ratio, sandwich_coef))]))
 
         if weight_tie_layers:
+            assert exists(depth), 'depth must be passed in with `weight_tie_layers` = True'
             assert not exists(layers_execute_order)
             layers_execute_order = tuple(range(len(default_block))) * depth
             depth = 1
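Weight tying builds a single block of parameters and replays it, so the constructor still needs an explicit `depth` to know how many replays to schedule; the new asserts spell out both that requirement and the rule that `depth` must not accompany a custom layer layout plus a custom execution order. A sketch of the weight-tied case, assuming the usual `Encoder` wrapper forwards these keywords:

```python
# a minimal sketch, assuming `weight_tie_layers` is forwarded to AttentionLayers
from x_transformers import Encoder

tied = Encoder(
    dim = 512,
    depth = 6,                # still required when weight_tie_layers = True
    weight_tie_layers = True  # one shared block of weights, executed 6 times
)
```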
@@ -1164,6 +1168,7 @@ class AttentionLayers(Module):
             assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
             layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
         else:
+            assert exists(depth), '`depth` must be passed in for `Decoder` or `Encoder`'
             layer_types = default_block * depth
 
         self.layer_types = layer_types
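For the default layer layout (no `custom_layers`, `par_ratio`, or `sandwich_coef`), `depth` remains mandatory; the new assert simply replaces the missing-positional-argument `TypeError` the old required parameter produced with a clearer message. A hedged illustration:

```python
# hypothetical illustration of the clearer error when depth is omitted
from x_transformers import Decoder

try:
    dec = Decoder(dim = 512)   # no depth, no custom layer layout
except AssertionError as err:
    print(err)   # "`depth` must be passed in for `Decoder` or `Encoder`"
```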
@@ -1173,6 +1178,13 @@ class AttentionLayers(Module):
 
         self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
 
+        # validate and set the depth
+
+        depth = default(depth, len(self.layers_execute_order))
+        assert depth == len(self.layers_execute_order)
+
+        self.depth = depth
+
         # stochastic depth
 
         self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))
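This final block is what the optional `depth` is for: when a custom execution order is supplied, the depth is inferred as `len(layers_execute_order)` and stored on `self.depth`, while the earlier assert rejects passing `depth` redundantly alongside `custom_layers` and `layers_execute_order`. A sketch with hypothetical layer values:

```python
# a minimal sketch: depth is omitted and inferred from the execution order
from x_transformers import Encoder

enc = Encoder(
    dim = 512,
    custom_layers = ('a', 'f'),          # one attention block, one feedforward block
    layers_execute_order = (0, 1, 0, 1)  # replay the two blocks twice
)

assert enc.depth == 4   # len(layers_execute_order)
```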
{x_transformers-1.28.5.dist-info → x_transformers-1.29.1.dist-info}/RECORD
CHANGED
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRT
 x_transformers/continuous.py,sha256=dpHK4NSMDQAJQ_N3Uj9rip0fYGXyu0QCCO_OfEdbRGs,6192
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=jj87ALpQpHGgvG1oHn4Z6UDmc1pqkoO6dY7YtY038w8,65269
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
-x_transformers-1.
+x_transformers-1.29.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.29.1.dist-info/METADATA,sha256=4Nnxc5THUI-d21Szj2mPLTlZYF0A9xVjHN4laFiLCIE,661
+x_transformers-1.29.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.29.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.29.1.dist-info/RECORD,,
{x_transformers-1.28.5.dist-info → x_transformers-1.29.1.dist-info}/LICENSE
File without changes

{x_transformers-1.28.5.dist-info → x_transformers-1.29.1.dist-info}/WHEEL
File without changes

{x_transformers-1.28.5.dist-info → x_transformers-1.29.1.dist-info}/top_level.txt
File without changes