x-transformers 1.32.8__py3-none-any.whl → 1.32.10__py3-none-any.whl

This diff covers publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
x_transformers/x_transformers.py
@@ -1677,8 +1677,6 @@ class AttentionLayers(Module):
 
         layer_variables = tuple(tuple(layer_variable[i] for i in layers_execute_order) for layer_variable in layer_variables)
 
-        first_skip = None
-
         # go through the attention and feedforward layers
 
         for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
@@ -1687,9 +1685,6 @@ class AttentionLayers(Module):
             if self.training and layer_dropout > 0. and random() < layer_dropout:
                 continue
 
-            if ind == 1:
-                first_skip = x.clone()
-
             if layer_type == 'a':
                 if return_hiddens:
                     hiddens.append(x)
@@ -1925,7 +1920,8 @@ class TransformerWrapper(Module):
         l2norm_embed = False,
         emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
         attn_z_loss_weight = 1e-4,
-        average_pool_embed = False
+        average_pool_embed = False,
+        use_cls_token = False,
     ):
         super().__init__()
 
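For context, a hedged usage sketch of the new flag against the public TransformerWrapper / Encoder API. The (batch, dim) output shape under return_embeddings=True is an inference from the unpack call later in this diff, not something the diff states directly:

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    use_cls_token = True,       # new in 1.32.10; mutually exclusive with average_pool_embed
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

tokens = torch.randint(0, 20000, (2, 1024))
pooled = model(tokens, return_embeddings = True)   # expected shape: (2, 512)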
@@ -1971,6 +1967,16 @@ class TransformerWrapper(Module):
 
         assert num_output_heads > 0
 
+        assert at_most_one_of(average_pool_embed, use_cls_token)
+
+        # classic cls token from the bert days
+
+        self.cls_token = None
+
+        if use_cls_token:
+            self.cls_token = nn.Parameter(torch.zeros(dim))
+            nn.init.normal_(self.cls_token, std = 0.02)
+
         # whether to average pool the embed (`global average pool`)
 
         self.average_pool_embed = average_pool_embed
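The hunk above registers the CLS token as a learnable parameter only when the flag is set, and guards against enabling both pooling modes at once. A minimal standalone sketch of the same pattern; the at_most_one_of helper below is a plausible stand-in for the library's, not its actual implementation, and PooledEmbed is a hypothetical module used only for illustration:

import torch
from torch import nn

def at_most_one_of(*flags):
    # stand-in helper: at most one of the given booleans may be True
    return sum(map(bool, flags)) <= 1

class PooledEmbed(nn.Module):   # hypothetical module, for illustration only
    def __init__(self, dim, average_pool_embed = False, use_cls_token = False):
        super().__init__()
        assert at_most_one_of(average_pool_embed, use_cls_token)

        # BERT-style cls token: a single learnable vector of size `dim`,
        # zero-initialized then given small Gaussian noise (std 0.02)
        self.cls_token = None
        if use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(dim))
            nn.init.normal_(self.cls_token, std = 0.02)

m = PooledEmbed(dim = 512, use_cls_token = True)
assert m.cls_token.shape == (512,)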
@@ -2097,7 +2103,19 @@ class TransformerWrapper(Module):
 
         x = self.project_emb(x)
 
+        # maybe cls token
+
+        if exists(self.cls_token):
+            cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
+            x, cls_packed_shape = pack([cls_tokens, x], 'b * d')
+
+            if exists(mask):
+                mask = F.pad(mask, (1, 0), value = True)
+
+        # maybe memory / register tokens
+
         if has_memory_tokens:
+            mem_seq = x.shape[-2]
             mem_every = self.memory_tokens_interspersed_every
 
             if exists(mem_every):
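The forward-pass hunk above broadcasts the CLS vector across the batch, packs it onto the front of the sequence, and widens the padding mask by one always-True position. The same steps in isolation (illustrative only, not library code):

import torch
import torch.nn.functional as F
from einops import repeat, pack

b, n, d = 2, 5, 16                                # arbitrary batch, length, dim
x = torch.randn(b, n, d)                          # token embeddings after project_emb
mask = torch.ones(b, n, dtype = torch.bool)       # padding mask
cls_token = torch.zeros(d)                        # stands in for self.cls_token

cls_tokens = repeat(cls_token, 'd -> b d', b = b)        # one copy per batch element
x, cls_packed_shape = pack([cls_tokens, x], 'b * d')     # prepend: (b, 1 + n, d)
mask = F.pad(mask, (1, 0), value = True)                 # CLS position is never masked

assert x.shape == (b, n + 1, d) and mask.shape == (b, n + 1)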
@@ -2137,13 +2155,16 @@ class TransformerWrapper(Module):
             if exists(mem_every):
                 x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
 
-            x = x[:, :n]
+            x = x[:, :mem_seq]
 
         # global average pool
 
         if self.average_pool_embed:
             x = masked_mean(x, mask = orig_mask, dim = 1)
 
+        if exists(self.cls_token):
+            x, _ = unpack(x, cls_packed_shape, 'b * d')
+
         # projecting to logits
 
         if not return_embeddings:
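After the attention stack, the saved packed shape lets einops.unpack split the CLS position back out as the pooled embedding, which is what the new unpack call above does. A small sketch of that round trip, independent of the library:

import torch
from einops import pack, unpack

b, n, d = 2, 5, 16                              # arbitrary batch, length, dim
cls_tokens = torch.randn(b, d)                  # one CLS vector per batch element
tokens = torch.randn(b, n, d)                   # the rest of the sequence

x, cls_packed_shape = pack([cls_tokens, tokens], 'b * d')   # (b, 1 + n, d)

# ... the attention layers would transform x here ...

cls_embed, rest = unpack(x, cls_packed_shape, 'b * d')      # split CLS back out
assert cls_embed.shape == (b, d) and rest.shape == (b, n, d)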
x_transformers-1.32.8.dist-info/METADATA → x_transformers-1.32.10.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.32.8
+Version: 1.32.10
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.32.8.dist-info/RECORD → x_transformers-1.32.10.dist-info/RECORD
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=cq364zjUVvGEeFxdu703yI2tp1VhpxTIpLTgMshHpzI,77392
+x_transformers/x_transformers.py,sha256=Ao3yHjEdl-qovGo9WW8q277wBHMgFxYRfcYRf1W_hKg,78076
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.32.8.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.32.8.dist-info/METADATA,sha256=tEHQVjqqXKQ2eSd-j5yrmfubkjuQZga_sT_5XgHanQo,661
-x_transformers-1.32.8.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-x_transformers-1.32.8.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.32.8.dist-info/RECORD,,
+x_transformers-1.32.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.32.10.dist-info/METADATA,sha256=DMtabf-G60PL6axX1zSsTcWcHtzvHtKQNTxHuzOFJ4A,662
+x_transformers-1.32.10.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+x_transformers-1.32.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.32.10.dist-info/RECORD,,