x-transformers 1.32.9__py3-none-any.whl → 1.32.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +48 -6
- {x_transformers-1.32.9.dist-info → x_transformers-1.32.11.dist-info}/METADATA +1 -1
- {x_transformers-1.32.9.dist-info → x_transformers-1.32.11.dist-info}/RECORD +6 -6
- {x_transformers-1.32.9.dist-info → x_transformers-1.32.11.dist-info}/LICENSE +0 -0
- {x_transformers-1.32.9.dist-info → x_transformers-1.32.11.dist-info}/WHEEL +0 -0
- {x_transformers-1.32.9.dist-info → x_transformers-1.32.11.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -45,15 +45,18 @@ def default(val, d):
|
|
45
45
|
return val
|
46
46
|
return d() if callable(d) else d
|
47
47
|
|
48
|
-
def
|
48
|
+
def first(it):
|
49
|
+
return it[0]
|
50
|
+
|
51
|
+
def is_empty(x):
|
52
|
+
return len(x) == 0
|
53
|
+
|
54
|
+
def cast_tuple(val, depth = 1):
|
49
55
|
return val if isinstance(val, tuple) else (val,) * depth
|
50
56
|
|
51
57
|
def divisible_by(num, den):
|
52
58
|
return (num % den) == 0
|
53
59
|
|
54
|
-
def is_empty(x):
|
55
|
-
return len(x) == 0
|
56
|
-
|
57
60
|
def maybe(fn):
|
58
61
|
@wraps(fn)
|
59
62
|
def inner(x, *args, **kwargs):
|
@@ -1920,7 +1923,9 @@ class TransformerWrapper(Module):
|
|
1920
1923
|
l2norm_embed = False,
|
1921
1924
|
emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1
|
1922
1925
|
attn_z_loss_weight = 1e-4,
|
1923
|
-
average_pool_embed = False
|
1926
|
+
average_pool_embed = False,
|
1927
|
+
use_cls_token = False,
|
1928
|
+
squeeze_out_last_dim = False
|
1924
1929
|
):
|
1925
1930
|
super().__init__()
|
1926
1931
|
|
@@ -1966,6 +1971,16 @@ class TransformerWrapper(Module):
|
|
1966
1971
|
|
1967
1972
|
assert num_output_heads > 0
|
1968
1973
|
|
1974
|
+
assert at_most_one_of(average_pool_embed, use_cls_token)
|
1975
|
+
|
1976
|
+
# classic cls token from the bert days
|
1977
|
+
|
1978
|
+
self.cls_token = None
|
1979
|
+
|
1980
|
+
if use_cls_token:
|
1981
|
+
self.cls_token = nn.Parameter(torch.zeros(dim))
|
1982
|
+
nn.init.normal_(self.cls_token, std = 0.02)
|
1983
|
+
|
1969
1984
|
# whether to average pool the embed (`global average pool`)
|
1970
1985
|
|
1971
1986
|
self.average_pool_embed = average_pool_embed
|
@@ -1995,6 +2010,10 @@ class TransformerWrapper(Module):
|
|
1995
2010
|
|
1996
2011
|
self.memory_tokens_interspersed_every = memory_tokens_interspersed_every
|
1997
2012
|
|
2013
|
+
# squeeze out last dimension if possible
|
2014
|
+
|
2015
|
+
self.squeeze_out_last_dim = squeeze_out_last_dim
|
2016
|
+
|
1998
2017
|
# whether can do cached kv decoding
|
1999
2018
|
|
2000
2019
|
self.can_cache_kv = self.num_memory_tokens == 0
|
@@ -2092,7 +2111,19 @@ class TransformerWrapper(Module):
|
|
2092
2111
|
|
2093
2112
|
x = self.project_emb(x)
|
2094
2113
|
|
2114
|
+
# maybe cls token
|
2115
|
+
|
2116
|
+
if exists(self.cls_token):
|
2117
|
+
cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
|
2118
|
+
x, cls_packed_shape = pack([cls_tokens, x], 'b * d')
|
2119
|
+
|
2120
|
+
if exists(mask):
|
2121
|
+
mask = F.pad(mask, (1, 0), value = True)
|
2122
|
+
|
2123
|
+
# maybe memory / register tokens
|
2124
|
+
|
2095
2125
|
if has_memory_tokens:
|
2126
|
+
mem_seq = x.shape[-2]
|
2096
2127
|
mem_every = self.memory_tokens_interspersed_every
|
2097
2128
|
|
2098
2129
|
if exists(mem_every):
|
@@ -2132,13 +2163,16 @@ class TransformerWrapper(Module):
|
|
2132
2163
|
if exists(mem_every):
|
2133
2164
|
x = rearrange(x, '(b n) m d -> b (n m) d', b = b)
|
2134
2165
|
|
2135
|
-
x = x[:, :
|
2166
|
+
x = x[:, :mem_seq]
|
2136
2167
|
|
2137
2168
|
# global average pool
|
2138
2169
|
|
2139
2170
|
if self.average_pool_embed:
|
2140
2171
|
x = masked_mean(x, mask = orig_mask, dim = 1)
|
2141
2172
|
|
2173
|
+
if exists(self.cls_token):
|
2174
|
+
x, _ = unpack(x, cls_packed_shape, 'b * d')
|
2175
|
+
|
2142
2176
|
# projecting to logits
|
2143
2177
|
|
2144
2178
|
if not return_embeddings:
|
@@ -2147,6 +2181,14 @@ class TransformerWrapper(Module):
|
|
2147
2181
|
else:
|
2148
2182
|
logits = self.to_logits(x)
|
2149
2183
|
|
2184
|
+
# maybe squeeze out last dimension of logits
|
2185
|
+
|
2186
|
+
if self.squeeze_out_last_dim:
|
2187
|
+
logits = tuple(rearrange(t, '... 1 -> ...') for t in cast_tuple(logits))
|
2188
|
+
|
2189
|
+
if not self.has_multiple_heads:
|
2190
|
+
logits = first(logits)
|
2191
|
+
|
2150
2192
|
# different returns
|
2151
2193
|
|
2152
2194
|
if return_logits_and_embeddings:
|
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
|
|
5
5
|
x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
6
6
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
7
7
|
x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
|
8
|
-
x_transformers/x_transformers.py,sha256=
|
8
|
+
x_transformers/x_transformers.py,sha256=Z04p-xySEkTgHSaY_060M0RlF6LnkK8ko5yTLunIYf8,78520
|
9
9
|
x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
|
10
10
|
x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
|
11
|
-
x_transformers-1.32.
|
12
|
-
x_transformers-1.32.
|
13
|
-
x_transformers-1.32.
|
14
|
-
x_transformers-1.32.
|
15
|
-
x_transformers-1.32.
|
11
|
+
x_transformers-1.32.11.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
12
|
+
x_transformers-1.32.11.dist-info/METADATA,sha256=dVUc_T7ALnVvaSKHSiPKOR4Y4zAo19l0QedFQPOEbN8,662
|
13
|
+
x_transformers-1.32.11.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
14
|
+
x_transformers-1.32.11.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
15
|
+
x_transformers-1.32.11.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|