x-transformers 1.32.9__py3-none-any.whl → 1.32.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
x_transformers/x_transformers.py

@@ -1920,7 +1920,8 @@ class TransformerWrapper(Module):
         l2norm_embed = False,
         emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
         attn_z_loss_weight = 1e-4,
-        average_pool_embed = False
+        average_pool_embed = False,
+        use_cls_token = False,
     ):
         super().__init__()
 
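
The only API change in this hunk is the new use_cls_token flag. A minimal usage sketch follows; the constructor arguments and sizes are illustrative (taken from the library's usual TransformerWrapper/Encoder pattern), and only use_cls_token is new in this release:

    import torch
    from x_transformers import TransformerWrapper, Encoder

    model = TransformerWrapper(
        num_tokens = 20000,          # hypothetical vocab size
        max_seq_len = 1024,
        use_cls_token = True,        # new in 1.32.10; mutually exclusive with average_pool_embed
        attn_layers = Encoder(
            dim = 512,
            depth = 6,
            heads = 8
        )
    )

    x = torch.randint(0, 20000, (2, 1024))

    # as read from the hunks below, the cls token embedding replaces the
    # per-token hidden states before the logits projection, so this yields
    # one pooled embedding per sequence
    pooled = model(x, return_embeddings = True)   # shape (2, 512)
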
@@ -1966,6 +1967,16 @@ class TransformerWrapper(Module):
 
         assert num_output_heads > 0
 
+        assert at_most_one_of(average_pool_embed, use_cls_token)
+
+        # classic cls token from the bert days
+
+        self.cls_token = None
+
+        if use_cls_token:
+            self.cls_token = nn.Parameter(torch.zeros(dim))
+            nn.init.normal_(self.cls_token, std = 0.02)
+
         # whether to average pool the embed (`global average pool`)
 
         self.average_pool_embed = average_pool_embed
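
Design note: the cls token is allocated as a zero vector of size dim and then re-initialized from a normal distribution with std 0.02, the same initializer range BERT used for its embeddings. The at_most_one_of assertion keeps the two pooling modes mutually exclusive, since average_pool_embed and use_cls_token both reduce the sequence to a single embedding.
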
@@ -2092,7 +2103,19 @@ class TransformerWrapper(Module):
 
         x = self.project_emb(x)
 
+        # maybe cls token
+
+        if exists(self.cls_token):
+            cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
+            x, cls_packed_shape = pack([cls_tokens, x], 'b * d')
+
+            if exists(mask):
+                mask = F.pad(mask, (1, 0), value = True)
+
+        # maybe memory / register tokens
+
         if has_memory_tokens:
+            mem_seq = x.shape[-2]
             mem_every = self.memory_tokens_interspersed_every
 
             if exists(mem_every):
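
Here the cls token is attached with einops pack, which records the packed shape so the token can be split back out after the attention layers run; the mask is left-padded with True so the new position is never masked out. A standalone sketch of the pack/unpack pattern, with hypothetical shapes:

    import torch
    from einops import repeat, pack, unpack

    b, n, d = 2, 5, 8                                      # hypothetical batch, length, dim
    cls_token = torch.zeros(d)
    x = torch.randn(b, n, d)

    cls_tokens = repeat(cls_token, 'd -> b d', b = b)      # (2, 8): one copy per batch element
    x, packed_shape = pack([cls_tokens, x], 'b * d')       # (2, 6, 8): cls token at position 0

    # ... attention layers would run over the packed sequence here ...

    cls_out, rest = unpack(x, packed_shape, 'b * d')       # (2, 8) and (2, 5, 8)
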
@@ -2132,13 +2155,16 @@ class TransformerWrapper(Module):
             if exists(mem_every):
                 x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
 
-            x = x[:, :n]
+            x = x[:, :mem_seq]
 
         # global average pool
 
         if self.average_pool_embed:
             x = masked_mean(x, mask = orig_mask, dim = 1)
 
+        if exists(self.cls_token):
+            x, _ = unpack(x, cls_packed_shape, 'b * d')
+
         # projecting to logits
 
         if not return_embeddings:
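
The change from x[:, :n] to x[:, :mem_seq] keeps the memory-token cleanup correct once a cls token can be prepended: n is the token-sequence length before anything is attached, so with use_cls_token = True the sequence reaching the memory-token logic is one longer (e.g. 1024 tokens plus the cls token gives mem_seq = 1025), and slicing back to n would silently drop the last position. Capturing mem_seq = x.shape[-2] at the point the memory tokens are inserted makes the slice independent of earlier prepends. The added unpack then retrieves the cls token embedding as the pooled output, mirroring the pack in the previous hunk.
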
x_transformers-1.32.9.dist-info/METADATA → x_transformers-1.32.10.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.32.9
+Version: 1.32.10
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
x_transformers-1.32.9.dist-info/RECORD → x_transformers-1.32.10.dist-info/RECORD

@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=8558TPHcDxWUvJYz01EdeyZl0lkHB14bzlsEMwSMPyw,77300
+x_transformers/x_transformers.py,sha256=Ao3yHjEdl-qovGo9WW8q277wBHMgFxYRfcYRf1W_hKg,78076
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.32.9.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.32.9.dist-info/METADATA,sha256=-GidCdPhcKpZ49ElbeuJUPko5LZZP_vyEodaN_P3g48,661
-x_transformers-1.32.9.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-x_transformers-1.32.9.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.32.9.dist-info/RECORD,,
+x_transformers-1.32.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.32.10.dist-info/METADATA,sha256=DMtabf-G60PL6axX1zSsTcWcHtzvHtKQNTxHuzOFJ4A,662
+x_transformers-1.32.10.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+x_transformers-1.32.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.32.10.dist-info/RECORD,,