x-transformers 1.32.11__py3-none-any.whl → 1.32.14__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
@@ -1925,7 +1925,8 @@ class TransformerWrapper(Module):
1925
1925
  attn_z_loss_weight = 1e-4,
1926
1926
  average_pool_embed = False,
1927
1927
  use_cls_token = False,
1928
- squeeze_out_last_dim = False
1928
+ squeeze_out_last_dim = False,
1929
+ token_emb: TokenEmbedding | None = None,
1929
1930
  ):
1930
1931
  super().__init__()
1931
1932
 
@@ -1939,7 +1940,11 @@ class TransformerWrapper(Module):
1939
1940
  self.shift_mem_down = shift_mem_down
1940
1941
 
1941
1942
  self.l2norm_embed = l2norm_embed
1942
- self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
1943
+
1944
+ if not exists(token_emb):
1945
+ token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
1946
+
1947
+ self.token_emb = token_emb
1943
1948
 
1944
1949
  no_abs_pos_emb = max_seq_len == 0 or not (use_abs_pos_emb and not attn_layers.disable_abs_pos_emb)
1945
1950
 
@@ -2184,7 +2189,7 @@ class TransformerWrapper(Module):
2184
2189
  # maybe squeeze out last dimension of logits
2185
2190
 
2186
2191
  if self.squeeze_out_last_dim:
2187
- logits = tuple(rearrange(t, '... 1 -> ...') for t in cast_tuple(logits))
2192
+ logits = tuple((rearrange(t, '... 1 -> ...') if t.shape[-1] == 1 else t) for t in cast_tuple(logits))
2188
2193
 
2189
2194
  if not self.has_multiple_heads:
2190
2195
  logits = first(logits)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: x-transformers
3
- Version: 1.32.11
3
+ Version: 1.32.14
4
4
  Summary: X-Transformers - Pytorch
5
5
  Home-page: https://github.com/lucidrains/x-transformers
6
6
  Author: Phil Wang
@@ -5,11 +5,11 @@ x_transformers/continuous.py,sha256=cIVEdhfei258__ziV7kQBrJMxCel54bExBTDrO9rfCI,
5
5
  x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
6
6
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
7
7
  x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
8
- x_transformers/x_transformers.py,sha256=Z04p-xySEkTgHSaY_060M0RlF6LnkK8ko5yTLunIYf8,78520
8
+ x_transformers/x_transformers.py,sha256=pyRQ6lb1Sx1CbjOH882tAv9UhAzsLwIeXDPBOsiRipg,78669
9
9
  x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
10
10
  x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
11
- x_transformers-1.32.11.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
12
- x_transformers-1.32.11.dist-info/METADATA,sha256=dVUc_T7ALnVvaSKHSiPKOR4Y4zAo19l0QedFQPOEbN8,662
13
- x_transformers-1.32.11.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
14
- x_transformers-1.32.11.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
15
- x_transformers-1.32.11.dist-info/RECORD,,
11
+ x_transformers-1.32.14.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
12
+ x_transformers-1.32.14.dist-info/METADATA,sha256=ChbSwpAxqxzvSjyOlJb9yO1GnoNZk2c0ioq5F-NuHI0,662
13
+ x_transformers-1.32.14.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
14
+ x_transformers-1.32.14.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
15
+ x_transformers-1.32.14.dist-info/RECORD,,