x-transformers 1.21.3-py3-none-any.whl → 1.21.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -952,6 +952,8 @@ class AttentionLayers(nn.Module):
         custom_layers = None,
         sandwich_coef = None,
         par_ratio = None,
+        weight_tie_layers = False,   # Albert - https://arxiv.org/abs/1909.11942
+        layers_execute_order = None, # generalizes weight tying, can do arbitrary layer execution orders
         residual_attn = False,
         cross_residual_attn = False,
         macaron = False,
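
Both new kwargs surface through any `AttentionLayers` subclass. A minimal usage sketch, not part of the diff, assuming the public `Encoder` wrapper forwards its kwargs to `AttentionLayers` the way its other constructor options are:

import torch
from x_transformers import Encoder

# Albert-style weight tying: one block of weights, executed depth times
encoder = Encoder(
    dim = 512,
    depth = 6,
    heads = 8,
    weight_tie_layers = True
)

x = torch.randn(1, 1024, 512)
out = encoder(x)   # (1, 1024, 512)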
@@ -1057,6 +1059,15 @@ class AttentionLayers(nn.Module):
             attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
             ff_kwargs = {**ff_kwargs, 'zero_init_output': True}

+        # setup weight tying, which is a special case of `layers_execute_order`
+
+        assert not (weight_tie_layers and any([*map(exists, (custom_layers, par_ratio, sandwich_coef))]))
+
+        if weight_tie_layers:
+            assert not exists(layers_execute_order)
+            layers_execute_order = tuple(range(len(default_block))) * depth
+            depth = 1
+
         # calculate layer block order

         if exists(custom_layers):
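
The branch above is easiest to see with concrete values. A standalone sketch of what it computes, with assumed values (a plain attention + feedforward block, depth 3):

# default_block would be ('a', 'f') for a vanilla attention/feedforward stack
default_block = ('a', 'f')
depth = 3

# schedule the single block's layer indices depth times over...
layers_execute_order = tuple(range(len(default_block))) * depth
print(layers_execute_order)   # (0, 1, 0, 1, 0, 1)

# ...then collapse depth to 1, so only one copy of the weights is built
depth = 1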
@@ -1079,6 +1090,10 @@ class AttentionLayers(nn.Module):
             layer_types = default_block * depth

         self.layer_types = layer_types
+        self.layers_execute_order = default(layers_execute_order, tuple(range(len(layer_types))))
+
+        assert all([i < len(self.layer_types) for i in self.layers_execute_order])
+
         self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

         # stochastic depth
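
When nothing is passed and weight tying is off, the execute order defaults to the identity over the built layers, so existing configs behave exactly as before. A standalone sketch, with the library's `default(a, b)` helper (return `a` unless it is None) inlined:

layer_types = ('a', 'f', 'a', 'f')   # e.g. depth 2, untied
layers_execute_order = None          # user passed nothing

execute_order = layers_execute_order if layers_execute_order is not None else tuple(range(len(layer_types)))
print(execute_order)   # (0, 1, 2, 3) - same order the layers were built in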
@@ -1187,7 +1202,19 @@ class AttentionLayers(nn.Module):

         outer_residual = x * self.resi_dual_scale

-        for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(self.layer_types, self.layers, self.layer_dropouts)):
+        # get layers to be executed
+
+        layer_variables = (
+            self.layer_types,
+            self.layers,
+            self.layer_dropouts
+        )
+
+        layer_variables = tuple(tuple(layer_variable[i] for i in self.layers_execute_order) for layer_variable in layer_variables)
+
+        # go through the attention and feedforward layers
+
+        for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
             is_last = ind == (len(self.layers) - 1)

             if self.training and layer_dropout > 0. and random() < layer_dropout:
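
The forward pass now gathers the per-layer bookkeeping tuples by index before iterating, so a repeated index replays the same module and its weights. A toy standalone sketch of that gather, using stand-in values rather than the real modules:

layer_types    = ('a', 'f')
layers         = ('attn_block_0', 'ff_block_0')   # stand-ins for nn.ModuleList entries
layer_dropouts = (0.0, 0.0)

layers_execute_order = (0, 1, 0, 1)   # weight-tied block executed twice

layer_variables = (layer_types, layers, layer_dropouts)
layer_variables = tuple(
    tuple(layer_variable[i] for i in layers_execute_order)
    for layer_variable in layer_variables
)

for layer_type, block, dropout in zip(*layer_variables):
    print(layer_type, block, dropout)
# a attn_block_0 0.0
# f ff_block_0 0.0
# a attn_block_0 0.0
# f ff_block_0 0.0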
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.21.3
+Version: 1.21.4
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -3,10 +3,10 @@ x_transformers/attend.py,sha256=_LepMOwph_o3jio6tur8LEUBPM-2YIn7NpuBmqhU47E,1238
 x_transformers/autoregressive_wrapper.py,sha256=djA4nfE6_92SAzX1JI0KaC7krdLz1mvnZlaVOaerHDg,5372
 x_transformers/continuous_autoregressive_wrapper.py,sha256=pTiDqu6JRUlnQJQp_xHATYHy0lgSd6ERLqyiFO3pC-4,1575
 x_transformers/nonautoregressive_wrapper.py,sha256=AQLE4rA_Kh8VNoe9OzpwyeWson34sRkhks4dn4seNjI,10414
-x_transformers/x_transformers.py,sha256=NIp-rnED6zAgOUqeOdDh4GzbCz4RbebYXtBVemT7EMc,57031
+x_transformers/x_transformers.py,sha256=vss6ISABCV74wgEapIt8nPK50j9-QU54hugRiaBh-sw,58088
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
-x_transformers-1.21.3.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.21.3.dist-info/METADATA,sha256=MX9qPfk5TEOp1xIuVqNkb3fjw-naoR_LEFt6pqQ8J88,661
-x_transformers-1.21.3.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
-x_transformers-1.21.3.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.21.3.dist-info/RECORD,,
+x_transformers-1.21.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.21.4.dist-info/METADATA,sha256=HzzezgisQhEH2H6D2tI-JDqNghpuUi6pDlHg0AI976U,661
+x_transformers-1.21.4.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
+x_transformers-1.21.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.21.4.dist-info/RECORD,,