x-transformers 1.32.10__py3-none-any.whl → 1.32.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +20 -4
- {x_transformers-1.32.10.dist-info → x_transformers-1.32.12.dist-info}/METADATA +1 -1
- {x_transformers-1.32.10.dist-info → x_transformers-1.32.12.dist-info}/RECORD +6 -6
- {x_transformers-1.32.10.dist-info → x_transformers-1.32.12.dist-info}/LICENSE +0 -0
- {x_transformers-1.32.10.dist-info → x_transformers-1.32.12.dist-info}/WHEEL +0 -0
- {x_transformers-1.32.10.dist-info → x_transformers-1.32.12.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -45,15 +45,18 @@ def default(val, d):
|
|
45
45
|
return val
|
46
46
|
return d() if callable(d) else d
|
47
47
|
|
48
|
-
def cast_tuple(val, depth = 1):
|
48
|
+
def first(it):
|
49
|
+
return it[0]
|
50
|
+
|
51
|
+
def is_empty(x):
|
52
|
+
return len(x) == 0
|
53
|
+
|
54
|
+
def cast_tuple(val, depth = 1):
|
49
55
|
return val if isinstance(val, tuple) else (val,) * depth
|
50
56
|
|
51
57
|
def divisible_by(num, den):
|
52
58
|
return (num % den) == 0
|
53
59
|
|
54
|
-
def is_empty(x):
|
55
|
-
return len(x) == 0
|
56
|
-
|
57
60
|
def maybe(fn):
|
58
61
|
@wraps(fn)
|
59
62
|
def inner(x, *args, **kwargs):
|
@@ -1922,6 +1925,7 @@ class TransformerWrapper(Module):
|
|
1922
1925
|
attn_z_loss_weight = 1e-4,
|
1923
1926
|
average_pool_embed = False,
|
1924
1927
|
use_cls_token = False,
|
1928
|
+
squeeze_out_last_dim = False
|
1925
1929
|
):
|
1926
1930
|
super().__init__()
|
1927
1931
|
|
@@ -2006,6 +2010,10 @@ class TransformerWrapper(Module):
|
|
2006
2010
|
|
2007
2011
|
self.memory_tokens_interspersed_every = memory_tokens_interspersed_every
|
2008
2012
|
|
2013
|
+
# squeeze out last dimension if possible
|
2014
|
+
|
2015
|
+
self.squeeze_out_last_dim = squeeze_out_last_dim
|
2016
|
+
|
2009
2017
|
# whether can do cached kv decoding
|
2010
2018
|
|
2011
2019
|
self.can_cache_kv = self.num_memory_tokens == 0
|
@@ -2173,6 +2181,14 @@ class TransformerWrapper(Module):
|
|
2173
2181
|
else:
|
2174
2182
|
logits = self.to_logits(x)
|
2175
2183
|
|
2184
|
+
# maybe squeeze out last dimension of logits
|
2185
|
+
|
2186
|
+
if self.squeeze_out_last_dim:
|
2187
|
+
logits = tuple((rearrange(t, '... 1 -> ...') if t.shape[-1] == 1 else t) for t in cast_tuple(logits))
|
2188
|
+
|
2189
|
+
if not self.has_multiple_heads:
|
2190
|
+
logits = first(logits)
|
2191
|
+
|
2176
2192
|
# different returns
|
2177
2193
|
|
2178
2194
|
if return_logits_and_embeddings:
|
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
|
|
5
5
|
x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
6
6
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
7
7
|
x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
|
8
|
-
x_transformers/x_transformers.py,sha256=
|
8
|
+
x_transformers/x_transformers.py,sha256=nsuYDfF4GY4kTImXEFqygnpw5mO8DOqaD_PJaeOxFS4,78549
|
9
9
|
x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
|
10
10
|
x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
|
11
|
-
x_transformers-1.32.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
12
|
-
x_transformers-1.32.
|
13
|
-
x_transformers-1.32.10.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
14
|
-
x_transformers-1.32.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
15
|
-
x_transformers-1.32.10.dist-info/RECORD,,
|
11
|
+
x_transformers-1.32.12.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
12
|
+
x_transformers-1.32.12.dist-info/METADATA,sha256=oOwIIjHp8Bl1ClFKTGaiNAX3RNK46C6jmriZEbyWYvM,662
|
13
|
+
x_transformers-1.32.12.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
14
|
+
x_transformers-1.32.12.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
15
|
+
x_transformers-1.32.12.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|