x-transformers 1.28.5__py3-none-any.whl → 1.29.1__py3-none-any.whl

This diff compares the published contents of the two package versions as they appear in their public registry, and is provided for informational purposes only.
--- x_transformers/x_transformers.py
+++ x_transformers/x_transformers.py
@@ -1001,7 +1001,7 @@ class AttentionLayers(Module):
     def __init__(
         self,
         dim,
-        depth,
+        depth = None,
         heads = 8,
         causal = False,
         cross_attend = False,
@@ -1054,10 +1054,11 @@ class AttentionLayers(Module):
         attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)
         cross_attn_kwargs, kwargs = groupby_prefix_and_trim('cross_attn_', kwargs)
 
+        assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}'
+
         dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
 
         self.dim = dim
-        self.depth = depth
         self.causal = causal
         self.layers = ModuleList([])
 
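The new assert makes keyword validation strict: any option left over after the prefix grouping now fails loudly instead of being silently dropped. A minimal sketch of the effect (`use_flash` below is a deliberate misspelling of `attn_flash`, chosen only for illustration):

    from x_transformers import Encoder

    # a typo'd keyword used to be silently ignored; as of 1.29.1 the
    # new assert fires during construction
    try:
        Encoder(dim = 512, depth = 6, use_flash = True)
    except AssertionError as err:
        print(err)  # unrecognized kwargs passed in dict_keys(['use_flash'])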
@@ -1138,9 +1139,12 @@ class AttentionLayers(Module):
 
         # setup weight tying, which is a special case of `layer_execute_order`
 
+        assert not (exists(layers_execute_order) and exists(custom_layers) and exists(depth)), 'depth should not be passed in if using custom layers and custom layer execution order'
+
         assert not (weight_tie_layers and any([*map(exists, (custom_layers, par_ratio, sandwich_coef))]))
 
         if weight_tie_layers:
+            assert exists(depth), 'depth must be passed in with `weight_tie_layers` = True'
             assert not exists(layers_execute_order)
             layers_execute_order = tuple(range(len(default_block))) * depth
             depth = 1
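Weight tying re-executes one shared block `depth` times, so `depth` stays mandatory on that path even though it is now optional in the signature. A quick sketch of what the added assert enforces:

    from x_transformers import Encoder

    # weight tying needs an explicit depth to know how many times to
    # re-execute the shared block
    try:
        Encoder(dim = 512, weight_tie_layers = True)  # depth omitted
    except AssertionError as err:
        print(err)  # depth must be passed in with `weight_tie_layers` = True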
@@ -1164,6 +1168,7 @@ class AttentionLayers(Module):
             assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
             layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
         else:
+            assert exists(depth), '`depth` must be passed in for `Decoder` or `Encoder`'
             layer_types = default_block * depth
 
         self.layer_types = layer_types
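The plain `Encoder`/`Decoder` path likewise still requires `depth`, since with no custom layer spec there is nothing to infer it from; omitting it now raises immediately rather than failing later. A minimal sketch:

    from x_transformers import Decoder

    try:
        Decoder(dim = 512)  # depth omitted, no custom layers given
    except AssertionError as err:
        print(err)  # `depth` must be passed in for `Decoder` or `Encoder`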
@@ -1173,6 +1178,13 @@ class AttentionLayers(Module):
 
         self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
 
+        # validate and set the depth
+
+        depth = default(depth, len(self.layers_execute_order))
+        assert depth == len(self.layers_execute_order)
+
+        self.depth = depth
+
         # stochastic depth
 
         self.layer_dropouts = cast_tuple(layer_dropout, len(layer_types))
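The upshot of the release: when a custom execution order carries its own length, `depth` can now be omitted and is inferred as `len(layers_execute_order)`. A minimal sketch, assuming the `Encoder` wrapper forwards these kwargs to `AttentionLayers` unchanged:

    from x_transformers import Encoder

    # two shared blocks executed three times over; per the new validation,
    # depth is inferred as len(layers_execute_order) rather than passed in
    enc = Encoder(
        dim = 512,
        custom_layers = ('a', 'f'),
        layers_execute_order = (0, 1, 0, 1, 0, 1)
    )
    # enc.depth is set to 6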
--- x_transformers-1.28.5.dist-info/METADATA
+++ x_transformers-1.29.1.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.28.5
+Version: 1.29.1
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- x_transformers-1.28.5.dist-info/RECORD
+++ x_transformers-1.29.1.dist-info/RECORD
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=gYKIN5Rm8dMYSTX5yHpg9sPYyZf9rsRT
 x_transformers/continuous.py,sha256=dpHK4NSMDQAJQ_N3Uj9rip0fYGXyu0QCCO_OfEdbRGs,6192
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=w_S0zOCKJtAO2M5ZKdE7gqSUWzkqECkA87ah-vkqx0Y,64656
+x_transformers/x_transformers.py,sha256=jj87ALpQpHGgvG1oHn4Z6UDmc1pqkoO6dY7YtY038w8,65269
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=EN3hxxleTRGYeAz6i4x3U_PrOm9TjxMF3eDhMKGx59E,8575
-x_transformers-1.28.5.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.28.5.dist-info/METADATA,sha256=jLcNekd2_ccREKevcTAtHNAnjwqnxaRmAvq90_eSdQI,661
-x_transformers-1.28.5.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-x_transformers-1.28.5.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.28.5.dist-info/RECORD,,
+x_transformers-1.29.1.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.29.1.dist-info/METADATA,sha256=4Nnxc5THUI-d21Szj2mPLTlZYF0A9xVjHN4laFiLCIE,661
+x_transformers-1.29.1.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+x_transformers-1.29.1.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.29.1.dist-info/RECORD,,