x-transformers 1.32.10-py3-none-any.whl → 1.32.12-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- x_transformers/x_transformers.py (1.32.10)
+++ x_transformers/x_transformers.py (1.32.12)
@@ -45,15 +45,18 @@ def default(val, d):
         return val
     return d() if callable(d) else d
 
-def cast_tuple(val, depth):
+def first(it):
+    return it[0]
+
+def is_empty(x):
+    return len(x) == 0
+
+def cast_tuple(val, depth = 1):
     return val if isinstance(val, tuple) else (val,) * depth
 
 def divisible_by(num, den):
     return (num % den) == 0
 
-def is_empty(x):
-    return len(x) == 0
-
 def maybe(fn):
     @wraps(fn)
     def inner(x, *args, **kwargs):
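
Note on the helper changes above: first and the default depth = 1 on cast_tuple are new, while is_empty is merely relocated to the top of the helper block. A minimal sketch of the resulting behavior (helpers copied from the diff so the snippet is self-contained):

def first(it):
    return it[0]

def is_empty(x):
    return len(x) == 0

def cast_tuple(val, depth = 1):
    return val if isinstance(val, tuple) else (val,) * depth

assert cast_tuple(3) == (3,)       # depth now defaults to 1
assert cast_tuple(3, 2) == (3, 3)
assert first((3, 4)) == 3
assert is_empty(())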
@@ -1922,6 +1925,7 @@ class TransformerWrapper(Module):
         attn_z_loss_weight = 1e-4,
         average_pool_embed = False,
         use_cls_token = False,
+        squeeze_out_last_dim = False
     ):
         super().__init__()
 
@@ -2006,6 +2010,10 @@ class TransformerWrapper(Module):
 
         self.memory_tokens_interspersed_every = memory_tokens_interspersed_every
 
+        # squeeze out last dimension if possible
+
+        self.squeeze_out_last_dim = squeeze_out_last_dim
+
         # whether can do cached kv decoding
 
         self.can_cache_kv = self.num_memory_tokens == 0
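
The stored flag pairs with the forward-pass change in the next hunk. A minimal usage sketch, assuming the pre-existing logits_dim kwarg is used to build a scalar-per-position head (only squeeze_out_last_dim is introduced by this diff; the other arguments are illustrative):

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    logits_dim = 1,               # assumed pre-existing kwarg: one output unit per position
    squeeze_out_last_dim = True,  # new flag from this release
    attn_layers = Decoder(dim = 128, depth = 2, heads = 4)
)

x = torch.randint(0, 256, (1, 1024))
out = model(x)
print(out.shape)  # expected torch.Size([1, 1024]) rather than (1, 1024, 1)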
@@ -2173,6 +2181,14 @@ class TransformerWrapper(Module):
         else:
             logits = self.to_logits(x)
 
+        # maybe squeeze out last dimension of logits
+
+        if self.squeeze_out_last_dim:
+            logits = tuple((rearrange(t, '... 1 -> ...') if t.shape[-1] == 1 else t) for t in cast_tuple(logits))
+
+            if not self.has_multiple_heads:
+                logits = first(logits)
+
         # different returns
 
         if return_logits_and_embeddings:
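
The squeeze itself is an einops rearrange guarded by a shape check, applied per output head via cast_tuple and unwrapped with first when there is a single head. A standalone sketch of just the shape logic (squeeze_last_if_singleton is a hypothetical name for illustration):

import torch
from einops import rearrange

def squeeze_last_if_singleton(t):
    # drop a trailing dimension of size 1; leave other shapes untouched
    return rearrange(t, '... 1 -> ...') if t.shape[-1] == 1 else t

print(squeeze_last_if_singleton(torch.randn(2, 1024, 1)).shape)  # torch.Size([2, 1024])
print(squeeze_last_if_singleton(torch.randn(2, 1024, 5)).shape)  # torch.Size([2, 1024, 5])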
--- x_transformers-1.32.10.dist-info/METADATA
+++ x_transformers-1.32.12.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.32.10
+Version: 1.32.12
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
--- x_transformers-1.32.10.dist-info/RECORD
+++ x_transformers-1.32.12.dist-info/RECORD
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=Ao3yHjEdl-qovGo9WW8q277wBHMgFxYRfcYRf1W_hKg,78076
+x_transformers/x_transformers.py,sha256=nsuYDfF4GY4kTImXEFqygnpw5mO8DOqaD_PJaeOxFS4,78549
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.32.10.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.32.10.dist-info/METADATA,sha256=DMtabf-G60PL6axX1zSsTcWcHtzvHtKQNTxHuzOFJ4A,662
-x_transformers-1.32.10.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-x_transformers-1.32.10.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.32.10.dist-info/RECORD,,
+x_transformers-1.32.12.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.32.12.dist-info/METADATA,sha256=oOwIIjHp8Bl1ClFKTGaiNAX3RNK46C6jmriZEbyWYvM,662
+x_transformers-1.32.12.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+x_transformers-1.32.12.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.32.12.dist-info/RECORD,,