x-transformers 1.32.9__py3-none-any.whl → 1.32.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +28 -2
- {x_transformers-1.32.9.dist-info → x_transformers-1.32.10.dist-info}/METADATA +1 -1
- {x_transformers-1.32.9.dist-info → x_transformers-1.32.10.dist-info}/RECORD +6 -6
- {x_transformers-1.32.9.dist-info → x_transformers-1.32.10.dist-info}/LICENSE +0 -0
- {x_transformers-1.32.9.dist-info → x_transformers-1.32.10.dist-info}/WHEEL +0 -0
- {x_transformers-1.32.9.dist-info → x_transformers-1.32.10.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1920,7 +1920,8 @@ class TransformerWrapper(Module):
         l2norm_embed = False,
         emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
         attn_z_loss_weight = 1e-4,
-        average_pool_embed = False
+        average_pool_embed = False,
+        use_cls_token = False,
     ):
         super().__init__()

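For reference, the emb_frac_gradient kwarg visible in this hunk implements the fractional-gradient trick credited to GLM-130B and Cogview: the embedding's forward value is unchanged, but only a fraction of the gradient flows back into the embedding table. A minimal sketch of the idea (illustrative, not the library's exact code):

import torch

def frac_gradient(t: torch.Tensor, frac: float = 0.1) -> torch.Tensor:
    # forward value equals t; in backward, the gradient is scaled by frac
    return t * frac + t.detach() * (1 - frac)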
@@ -1966,6 +1967,16 @@ class TransformerWrapper(Module):

         assert num_output_heads > 0

+        assert at_most_one_of(average_pool_embed, use_cls_token)
+
+        # classic cls token from the bert days
+
+        self.cls_token = None
+
+        if use_cls_token:
+            self.cls_token = nn.Parameter(torch.zeros(dim))
+            nn.init.normal_(self.cls_token, std = 0.02)
+
         # whether to average pool the embed (`global average pool`)

         self.average_pool_embed = average_pool_embed
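The new assert uses the library's at_most_one_of helper to forbid enabling CLS-token pooling and average pooling at the same time. Presumably it reduces to something like this one-liner (a sketch, not verified against the source):

def at_most_one_of(*bools) -> bool:
    # true when zero or one of the given flags is set
    return sum(map(bool, bools)) <= 1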
@@ -2092,7 +2103,19 @@ class TransformerWrapper(Module):

         x = self.project_emb(x)

+        # maybe cls token
+
+        if exists(self.cls_token):
+            cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
+            x, cls_packed_shape = pack([cls_tokens, x], 'b * d')
+
+            if exists(mask):
+                mask = F.pad(mask, (1, 0), value = True)
+
+        # maybe memory / register tokens
+
         if has_memory_tokens:
+            mem_seq = x.shape[-2]
             mem_every = self.memory_tokens_interspersed_every

             if exists(mem_every):
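The added code leans on einops' pack and unpack to prepend the CLS token here and to split it back out in the next hunk. A standalone sketch of that round trip (shapes are illustrative, not from the diff):

import torch
from einops import pack, unpack

b, n, d = 2, 10, 512
cls_tokens = torch.zeros(b, d)    # (b, d), one learned token per batch row
x = torch.randn(b, n, d)          # (b, n, d) token embeddings

# pack concatenates along the starred axis; the (b, d) tensor occupies one slot
x, cls_packed_shape = pack([cls_tokens, x], 'b * d')    # (b, 1 + n, d)

# unpack reverses the split using the recorded shapes
cls_out, rest = unpack(x, cls_packed_shape, 'b * d')    # (b, d), (b, n, d)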
@@ -2132,13 +2155,16 @@ class TransformerWrapper(Module):
             if exists(mem_every):
                 x = rearrange(x, '(b n) m d -> b (n m) d', b = b)

-            x = x[:, :
+            x = x[:, :mem_seq]

         # global average pool

         if self.average_pool_embed:
             x = masked_mean(x, mask = orig_mask, dim = 1)

+        if exists(self.cls_token):
+            x, _ = unpack(x, cls_packed_shape, 'b * d')
+
         # projecting to logits

         if not return_embeddings:
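Taken together, the release adds an optional BERT-style CLS token to TransformerWrapper: when enabled, a learned token is prepended to the sequence (with the attention mask padded to cover it), and after the attention layers the CLS position is unpacked as the pooled output. A minimal usage sketch of the new flag (hyperparameters are illustrative, not from the diff):

import torch
from x_transformers import TransformerWrapper, Encoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    use_cls_token = True,    # new in 1.32.10; cannot be combined with average_pool_embed = True
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

tokens = torch.randint(0, 20000, (2, 1024))
mask = torch.ones(2, 1024).bool()

pooled = model(tokens, mask = mask, return_embeddings = True)    # (2, 512) CLS embedding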
{x_transformers-1.32.9.dist-info → x_transformers-1.32.10.dist-info}/RECORD
CHANGED
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=Ao3yHjEdl-qovGo9WW8q277wBHMgFxYRfcYRf1W_hKg,78076
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.32.
-x_transformers-1.32.
-x_transformers-1.32.
-x_transformers-1.32.
-x_transformers-1.32.
+x_transformers-1.32.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.32.10.dist-info/METADATA,sha256=DMtabf-G60PL6axX1zSsTcWcHtzvHtKQNTxHuzOFJ4A,662
+x_transformers-1.32.10.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+x_transformers-1.32.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.32.10.dist-info/RECORD,,
{x_transformers-1.32.9.dist-info → x_transformers-1.32.10.dist-info}/LICENSE
File without changes
{x_transformers-1.32.9.dist-info → x_transformers-1.32.10.dist-info}/WHEEL
File without changes
{x_transformers-1.32.9.dist-info → x_transformers-1.32.10.dist-info}/top_level.txt
File without changes