x-transformers 1.40.9 → 1.40.11 (py3-none-any.whl)

This diff compares two publicly released versions of the package as they appear in their public registry and is provided for informational purposes only.
--- x_transformers/x_transformers.py
+++ x_transformers/x_transformers.py
@@ -1019,7 +1019,7 @@ class Attention(Module):
         self.qk_norm_q_scale = self.qk_norm_k_scale = 1
         if qk_norm and qk_norm_dim_scale:
             self.qk_norm_q_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
-            self.qk_norm_k_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
+            self.qk_norm_k_scale = nn.Parameter(torch.ones(kv_heads, 1, dim_head))

         assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
         assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'
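A note on the Attention change above: when the module is configured with fewer key/value heads than query heads (the kv_heads option used for grouped-query or multi-query attention), the keys only carry kv_heads heads, so a qk-norm scale allocated with heads rows cannot broadcast against them. A minimal shape sketch under that assumption; the tensor names and sizes here are illustrative, not taken from the library:

    import torch

    heads, kv_heads, dim_head, seq = 8, 2, 64, 16

    # keys after head-splitting in a grouped-query setup: one set per kv head
    k = torch.randn(1, kv_heads, seq, dim_head)

    old_scale = torch.ones(heads, 1, dim_head)     # 8 head rows, mismatched with k
    new_scale = torch.ones(kv_heads, 1, dim_head)  # 2 head rows, matches k

    print((k * new_scale).shape)  # torch.Size([1, 2, 16, 64])
    # (k * old_scale) would raise a RuntimeError: 8 vs 2 in the head dimension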
@@ -2104,6 +2104,7 @@ class TransformerWrapper(Module):
         attn_z_loss_weight = 1e-4,
         average_pool_embed = False,
         use_cls_token = False,
+        num_cls_tokens = 1,
         squeeze_out_last_dim = False,
         token_emb: TokenEmbedding | None = None,
         mixture_of_softmax = False,
@@ -2116,6 +2117,7 @@ class TransformerWrapper(Module):
         emb_dim = default(emb_dim, dim)
         self.emb_dim = emb_dim
         self.num_tokens = num_tokens
+        self.num_cls_tokens = num_cls_tokens

         self.max_seq_len = max_seq_len
         self.max_mem_len = max_mem_len
@@ -2172,7 +2174,7 @@ class TransformerWrapper(Module):
         self.cls_token = None

         if use_cls_token:
-            self.cls_token = nn.Parameter(torch.zeros(dim))
+            self.cls_token = nn.Parameter(torch.zeros(num_cls_tokens, dim))
             nn.init.normal_(self.cls_token, std = 0.02)

         # whether to average pool the embed (`global average pool`)
@@ -2329,11 +2331,11 @@ class TransformerWrapper(Module):
         # maybe cls token

         if exists(self.cls_token):
-            cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
+            cls_tokens = repeat(self.cls_token, '... -> b ...', b = b)
             x, cls_packed_shape = pack([cls_tokens, x], 'b * d')

             if exists(mask):
-                mask = F.pad(mask, (1, 0), value = True)
+                mask = F.pad(mask, (self.num_cls_tokens, 0), value = True)

         # maybe memory / register tokens

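A note on the mask change above: the attention mask is left-padded with True so the prepended cls tokens are always attendable, and the pad width now follows num_cls_tokens instead of being hard-coded to 1. A small sketch of that padding with made-up values:

    import torch
    import torch.nn.functional as F

    num_cls_tokens = 2
    mask = torch.tensor([[True, True, False]])  # (batch, seq) key padding mask

    # left-pad with True, one slot per prepended cls token
    mask = F.pad(mask, (num_cls_tokens, 0), value = True)
    print(mask)  # tensor([[ True,  True,  True,  True, False]])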
@@ -2415,6 +2417,7 @@ class TransformerWrapper(Module):

         if exists(self.cls_token):
             x, _ = unpack(x, cls_packed_shape, 'b * d')
+            x = x.squeeze(1) # Remove sequence dimension if num_cls_tokens=1 to keep previous behavior

         # handle expansion to mixture if needed (for mixture of softmax)

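Taken together, the TransformerWrapper changes expose a num_cls_tokens option so that more than one learned cls token can be prepended and pooled. A hedged usage sketch: use_cls_token and num_cls_tokens come from this diff, the remaining arguments are standard x-transformers options, and the exact output shape should be checked against the installed version:

    import torch
    from x_transformers import TransformerWrapper, Encoder

    model = TransformerWrapper(
        num_tokens = 256,
        max_seq_len = 128,
        use_cls_token = True,
        num_cls_tokens = 2,   # new keyword in this release
        attn_layers = Encoder(
            dim = 64,
            depth = 2,
            heads = 4
        )
    )

    tokens = torch.randint(0, 256, (1, 128))
    out = model(tokens)
    print(out.shape)  # one output row per cls token; with num_cls_tokens = 1 the extra dim is squeezed away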
--- x_transformers-1.40.9.dist-info/METADATA
+++ x_transformers-1.40.11.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.40.9
+Version: 1.40.11
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- x_transformers-1.40.9.dist-info/RECORD
+++ x_transformers-1.40.11.dist-info/RECORD
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=NoWBoiz1t8_QytYo1T2YBFk-7H9s38k2t-EksxqUkMU,88072
+x_transformers/x_transformers.py,sha256=RfpihlGygZz4ICq4IGOgGNOipInXUiYWYNs1tej2Orw,88290
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.40.9.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.40.9.dist-info/METADATA,sha256=xSxqFkhGfr5dU2xI0xo3UzlPMSuaaR4Rd2TrDpEyxcE,661
-x_transformers-1.40.9.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
-x_transformers-1.40.9.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.40.9.dist-info/RECORD,,
+x_transformers-1.40.11.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.11.dist-info/METADATA,sha256=D97orsPC5EYEtJN6EN75bLOfOY-FBmodr2eaFIovwu8,662
+x_transformers-1.40.11.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.11.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.11.dist-info/RECORD,,