x-transformers 1.40.9__py3-none-any.whl → 1.40.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +7 -4
- {x_transformers-1.40.9.dist-info → x_transformers-1.40.11.dist-info}/METADATA +1 -1
- {x_transformers-1.40.9.dist-info → x_transformers-1.40.11.dist-info}/RECORD +6 -6
- {x_transformers-1.40.9.dist-info → x_transformers-1.40.11.dist-info}/LICENSE +0 -0
- {x_transformers-1.40.9.dist-info → x_transformers-1.40.11.dist-info}/WHEEL +0 -0
- {x_transformers-1.40.9.dist-info → x_transformers-1.40.11.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1019,7 +1019,7 @@ class Attention(Module):
        self.qk_norm_q_scale = self.qk_norm_k_scale = 1
        if qk_norm and qk_norm_dim_scale:
            self.qk_norm_q_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
-           self.qk_norm_k_scale = nn.Parameter(torch.ones(heads, 1, dim_head))
+           self.qk_norm_k_scale = nn.Parameter(torch.ones(kv_heads, 1, dim_head))

        assert (not qk_norm) or divisible_by(dim_head, qk_norm_groups), 'dimension per attention head must be divisible by the qk norm groups'
        assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'
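For context: the hunk above fixes the shape of the learned key scale when grouped-query attention (kv_heads < heads) is combined with qk_norm and qk_norm_dim_scale, so the key scale now matches the number of key/value heads rather than the number of query heads. A minimal sketch of exercising that code path; the attn_-prefixed kwargs and their routing from Decoder down to Attention are taken from the library's existing conventions, not from this diff:

import torch
from x_transformers import TransformerWrapper, Decoder

# grouped-query attention (fewer kv heads than query heads) together with
# qk rmsnorm and learned per-dimension scales -- the code path touched above.
# the attn_-prefixed kwargs are assumed to route to Attention per the
# library's usual kwarg convention (not shown in this diff).
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 2,
        heads = 8,
        attn_kv_heads = 2,              # grouped-query attention
        attn_qk_norm = True,            # qk rmsnorm
        attn_qk_norm_dim_scale = True   # learned per-dim q/k scales (the fixed parameter)
    )
)

logits = model(torch.randint(0, 256, (1, 128)))   # expected shape (1, 128, 256)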
@@ -2104,6 +2104,7 @@ class TransformerWrapper(Module):
        attn_z_loss_weight = 1e-4,
        average_pool_embed = False,
        use_cls_token = False,
+       num_cls_tokens = 1,
        squeeze_out_last_dim = False,
        token_emb: TokenEmbedding | None = None,
        mixture_of_softmax = False,
@@ -2116,6 +2117,7 @@ class TransformerWrapper(Module):
        emb_dim = default(emb_dim, dim)
        self.emb_dim = emb_dim
        self.num_tokens = num_tokens
+       self.num_cls_tokens = num_cls_tokens

        self.max_seq_len = max_seq_len
        self.max_mem_len = max_mem_len
@@ -2172,7 +2174,7 @@ class TransformerWrapper(Module):
        self.cls_token = None

        if use_cls_token:
-           self.cls_token = nn.Parameter(torch.zeros(dim))
+           self.cls_token = nn.Parameter(torch.zeros(num_cls_tokens, dim))
            nn.init.normal_(self.cls_token, std = 0.02)

        # whether to average pool the embed (`global average pool`)
@@ -2329,11 +2331,11 @@ class TransformerWrapper(Module):
        # maybe cls token

        if exists(self.cls_token):
-           cls_tokens = repeat(self.cls_token, '
+           cls_tokens = repeat(self.cls_token, '... -> b ...', b = b)
            x, cls_packed_shape = pack([cls_tokens, x], 'b * d')

            if exists(mask):
-               mask = F.pad(mask, (1, 0), value = True)
+               mask = F.pad(mask, (self.num_cls_tokens, 0), value = True)

        # maybe memory / register tokens

@@ -2415,6 +2417,7 @@ class TransformerWrapper(Module):

        if exists(self.cls_token):
            x, _ = unpack(x, cls_packed_shape, 'b * d')
+           x = x.squeeze(1) # Remove sequence dimension if num_cls_tokens=1 to keep previous behavior

        # handle expansion to mixture if needed (for mixture of softmax)
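The TransformerWrapper hunks above add a num_cls_tokens option so that use_cls_token can prepend several learned CLS tokens instead of one: the attention mask is padded by self.num_cls_tokens, and the singleton dimension is squeezed away after unpacking so num_cls_tokens = 1 keeps the previous output shape. A rough usage sketch under those assumptions; Encoder and the return_embeddings flag come from the library's existing API rather than from this diff:

import torch
from x_transformers import TransformerWrapper, Encoder

# Encoder / return_embeddings are assumed from the existing x-transformers API;
# only use_cls_token / num_cls_tokens behavior is taken from the hunks above.
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    use_cls_token = True,
    num_cls_tokens = 2,          # new in this release; 1 keeps the old single-CLS behavior
    attn_layers = Encoder(
        dim = 512,
        depth = 2,
        heads = 8
    )
)

x = torch.randint(0, 256, (1, 128))
mask = torch.ones(1, 128, dtype = torch.bool)
embed = model(x, mask = mask, return_embeddings = True)
# embed should hold the CLS embeddings: roughly (1, 2, 512) here,
# and (1, 512) when num_cls_tokens = 1, since the singleton dim is squeezed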
{x_transformers-1.40.9.dist-info → x_transformers-1.40.11.dist-info}/RECORD
CHANGED
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=RfpihlGygZz4ICq4IGOgGNOipInXUiYWYNs1tej2Orw,88290
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
-x_transformers-1.40.
+x_transformers-1.40.11.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.40.11.dist-info/METADATA,sha256=D97orsPC5EYEtJN6EN75bLOfOY-FBmodr2eaFIovwu8,662
+x_transformers-1.40.11.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
+x_transformers-1.40.11.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.40.11.dist-info/RECORD,,
{x_transformers-1.40.9.dist-info → x_transformers-1.40.11.dist-info}/LICENSE
File without changes
{x_transformers-1.40.9.dist-info → x_transformers-1.40.11.dist-info}/WHEEL
File without changes
{x_transformers-1.40.9.dist-info → x_transformers-1.40.11.dist-info}/top_level.txt
File without changes