x-transformers 1.21.3__py3-none-any.whl → 1.21.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +28 -1
- {x_transformers-1.21.3.dist-info → x_transformers-1.21.4.dist-info}/METADATA +1 -1
- {x_transformers-1.21.3.dist-info → x_transformers-1.21.4.dist-info}/RECORD +6 -6
- {x_transformers-1.21.3.dist-info → x_transformers-1.21.4.dist-info}/LICENSE +0 -0
- {x_transformers-1.21.3.dist-info → x_transformers-1.21.4.dist-info}/WHEEL +0 -0
- {x_transformers-1.21.3.dist-info → x_transformers-1.21.4.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -952,6 +952,8 @@ class AttentionLayers(nn.Module):
         custom_layers = None,
         sandwich_coef = None,
         par_ratio = None,
+        weight_tie_layers = False,   # Albert - https://arxiv.org/abs/1909.11942
+        layers_execute_order = None, # generalizes weight tying, can do arbitrary layer execution orders
         residual_attn = False,
         cross_residual_attn = False,
         macaron = False,
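The two new keyword arguments pass straight through `Encoder`/`Decoder`, which subclass `AttentionLayers`. A minimal usage sketch, assuming version 1.21.4 and illustrative hyperparameters (model sizes and token counts are not taken from the package):

import torch
from x_transformers import TransformerWrapper, Decoder

# ALBERT-style weight tying: one attention/feedforward block of parameters,
# executed `depth` times through the new execution-order mechanism
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,                  # six passes over the single shared block
        heads = 8,
        weight_tie_layers = True    # new in 1.21.4
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)                   # (1, 1024, 20000)

With `weight_tie_layers = True`, only one set of block weights is allocated, but the forward pass still runs `depth` blocks' worth of computation.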
@@ -1057,6 +1059,15 @@ class AttentionLayers(nn.Module):
             attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
             ff_kwargs = {**ff_kwargs, 'zero_init_output': True}

+        # setup weight tying, which is a special case of `layer_execute_order`
+
+        assert not (weight_tie_layers and any([*map(exists, (custom_layers, par_ratio, sandwich_coef))]))
+
+        if weight_tie_layers:
+            assert not exists(layers_execute_order)
+            layers_execute_order = tuple(range(len(default_block))) * depth
+            depth = 1
+
         # calculate layer block order

         if exists(custom_layers):
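In plain terms, weight tying keeps a single physical block and replays it through the execution order. A standalone sketch of the expansion performed above, with illustrative values rather than anything pulled from a real model:

# hypothetical decoder block of self-attention + feedforward, requested depth of 3
default_block = ('a', 'f')
depth = 3

# the same expansion the constructor performs when weight_tie_layers = True
layers_execute_order = tuple(range(len(default_block))) * depth
depth = 1

print(layers_execute_order)  # (0, 1, 0, 1, 0, 1): blocks 0 and 1 each executed three times
print(depth)                 # 1: only one physical block of weights is built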
@@ -1079,6 +1090,10 @@ class AttentionLayers(nn.Module):
             layer_types = default_block * depth

         self.layer_types = layer_types
+        self.layers_execute_order = default(layers_execute_order, tuple(range(len(layer_types))))
+
+        assert all([i < len(self.layer_types) for i in self.layers_execute_order])
+
         self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

         # stochastic depth
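The execution order defaults to the identity order over `layer_types`, and the assertion only checks that every index is in range, so indices may be repeated or permuted freely. A self-contained sketch of that default-and-validate step; the `exists`/`default` helpers are re-declared here to mirror the library's small utility functions:

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

layer_types = ('a', 'f', 'a', 'f')      # illustrative: two attention/feedforward pairs
layers_execute_order = (2, 3, 0, 1)     # hypothetical custom order: run the second pair first

layers_execute_order = default(layers_execute_order, tuple(range(len(layer_types))))
assert all([i < len(layer_types) for i in layers_execute_order])  # every index must address a real layer

print(layers_execute_order)  # (2, 3, 0, 1)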
@@ -1187,7 +1202,19 @@ class AttentionLayers(nn.Module):

         outer_residual = x * self.resi_dual_scale

-        for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(self.layer_types, self.layers, self.layer_dropouts)):
+        # get layers to be executed
+
+        layer_variables = (
+            self.layer_types,
+            self.layers,
+            self.layer_dropouts
+        )
+
+        layer_variables = tuple(tuple(layer_variable[i] for i in self.layers_execute_order) for layer_variable in layer_variables)
+
+        # go through the attention and feedforward layers
+
+        for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
             is_last = ind == (len(self.layers) - 1)

             if self.training and layer_dropout > 0. and random() < layer_dropout:
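The forward pass materializes the schedule by reindexing the per-layer tuples with `layers_execute_order` before zipping them, so one physical layer can appear several times in the loop. A self-contained sketch with plain tuples standing in for the actual module lists:

# stand-ins for self.layer_types, self.layers and self.layer_dropouts
layer_types    = ('a', 'f')
layers         = ('attn_block_0', 'ff_block_0')   # placeholders for the (norm, block, residual_fn) triples
layer_dropouts = (0.0, 0.0)

layers_execute_order = (0, 1, 0, 1)               # run the single tied block twice

layer_variables = (layer_types, layers, layer_dropouts)
layer_variables = tuple(tuple(v[i] for i in layers_execute_order) for v in layer_variables)

for ind, (layer_type, block, layer_dropout) in enumerate(zip(*layer_variables)):
    print(ind, layer_type, block)
# 0 a attn_block_0
# 1 f ff_block_0
# 2 a attn_block_0
# 3 f ff_block_0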
{x_transformers-1.21.3.dist-info → x_transformers-1.21.4.dist-info}/RECORD
CHANGED
@@ -3,10 +3,10 @@ x_transformers/attend.py,sha256=_LepMOwph_o3jio6tur8LEUBPM-2YIn7NpuBmqhU47E,1238
 x_transformers/autoregressive_wrapper.py,sha256=djA4nfE6_92SAzX1JI0KaC7krdLz1mvnZlaVOaerHDg,5372
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=…
+x_transformers/x_transformers.py,sha256=vss6ISABCV74wgEapIt8nPK50j9-QU54hugRiaBh-sw,58088
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers-1.21.3.dist-info/LICENSE,sha256=…
-x_transformers-1.21.3.dist-info/METADATA,sha256=…
-x_transformers-1.21.3.dist-info/WHEEL,sha256=…
-x_transformers-1.21.3.dist-info/top_level.txt,sha256=…
-x_transformers-1.21.3.dist-info/RECORD,,
+x_transformers-1.21.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.21.4.dist-info/METADATA,sha256=HzzezgisQhEH2H6D2tI-JDqNghpuUi6pDlHg0AI976U,661
+x_transformers-1.21.4.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+x_transformers-1.21.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.21.4.dist-info/RECORD,,
{x_transformers-1.21.3.dist-info → x_transformers-1.21.4.dist-info}/LICENSE
File without changes

{x_transformers-1.21.3.dist-info → x_transformers-1.21.4.dist-info}/WHEEL
File without changes

{x_transformers-1.21.3.dist-info → x_transformers-1.21.4.dist-info}/top_level.txt
File without changes