x-transformers 1.32.2__py3-none-any.whl → 1.32.4__py3-none-any.whl

This diff shows the changes between two package versions that have been publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
@@ -1899,6 +1899,7 @@ class TransformerWrapper(Module):
         memory_tokens_interspersed_every = None,
         tie_embedding = False,
         logits_dim = None,
+        return_only_embed = False,
         num_output_heads = 1,
         use_abs_pos_emb = True,
         scaled_sinu_pos_emb = False,
@@ -1948,13 +1949,17 @@ class TransformerWrapper(Module):
 
         self.init_()
 
+        assert num_output_heads > 0
+
         # output head, usually to logits of num_tokens
 
         logits_dim = default(logits_dim, num_tokens)
 
         self.has_multiple_heads = False
 
-        if tie_embedding:
+        if return_only_embed:
+            self.to_logits = None
+        elif tie_embedding:
             self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
         elif num_output_heads > 1:
             self.has_multiple_heads = True
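
The new `return_only_embed` flag turns the wrapper into a pure encoder at construction time: when set, no output head is built and `self.to_logits` stays `None`, and the forward-pass hunk below then forces `return_embeddings` so callers receive hidden states rather than logits. A minimal usage sketch, assuming the standard `x_transformers` imports and arbitrary illustrative hyperparameters:

```python
import torch
from x_transformers import TransformerWrapper, Encoder

# hyperparameters below are arbitrary, for illustration only
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    return_only_embed = True,   # new flag: no logits head is created
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randint(0, 20000, (1, 1024))

# since to_logits is None, the forward pass returns embeddings
# of shape (1, 1024, 512) instead of logits over the vocabulary
embeddings = model(x)
```

Before this change, getting raw embeddings required passing `return_embeddings = True` on every forward call while an unused `to_logits` projection was still constructed; the flag makes the intent explicit and skips that allocation.
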
@@ -2008,7 +2013,9 @@ class TransformerWrapper(Module):
         **kwargs
     ):
         b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
+
         return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
+        return_embeddings = return_embeddings | (not exists(self.to_logits))
 
         # absolute positional embedding
 
@@ -2018,6 +2025,8 @@ class TransformerWrapper(Module):
 
         # add additional embeddings
 
+        assert not (exists(self.embeds) ^ (len(embed_ids) > 0)), '`embed_num_tokens` must be defined on `TransformerWrapper`'
+
         if exists(self.embeds):
             assert len(embed_ids) == len(self.embeds)
 
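
The added assertion makes the pairing between constructor-time `embed_num_tokens` and forward-time `embed_ids` explicit: passing `embed_ids` without having configured `embed_num_tokens` (or vice versa) now fails loudly instead of being silently ignored. A hedged sketch of the intended pairing, assuming `embed_num_tokens` maps embedding names to vocabulary sizes (as the assertion message suggests); the names and sizes below are made up for illustration:

```python
import torch
from x_transformers import TransformerWrapper, Encoder

# 'speaker' and 'language' are hypothetical embedding names
model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    embed_num_tokens = dict(speaker = 4, language = 2),  # creates self.embeds
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
)

x = torch.randint(0, 20000, (1, 1024))

# embed_ids must be supplied whenever embed_num_tokens was configured,
# and omitted otherwise - the new assert enforces this pairing
logits = model(
    x,
    embed_ids = dict(
        speaker = torch.randint(0, 4, (1, 1024)),
        language = torch.randint(0, 2, (1, 1024))
    )
)
```
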
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: x-transformers
-Version: 1.32.2
+Version: 1.32.4
 Summary: X-Transformers - Pytorch
 Home-page: https://github.com/lucidrains/x-transformers
 Author: Phil Wang
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
 x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
 x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
 x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
-x_transformers/x_transformers.py,sha256=1QG7zUe89h1R5VDMoKEAkvdRRDkzQ7h6npkqblxxR6g,76312
+x_transformers/x_transformers.py,sha256=o7At5F35Paih1a_-rxqcIP1n4B-ARVJ_ZL2QkOnTnSQ,76655
 x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
 x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
-x_transformers-1.32.2.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
-x_transformers-1.32.2.dist-info/METADATA,sha256=U0Kh4e7UiL-0hLDZb0P3McdvTnzTeFyVwtoXFffzQ-M,661
-x_transformers-1.32.2.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-x_transformers-1.32.2.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
-x_transformers-1.32.2.dist-info/RECORD,,
+x_transformers-1.32.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-1.32.4.dist-info/METADATA,sha256=rAX2cTBnI50T0Tsa00KgBTz6lzV36ACppE6H2WPLZI4,661
+x_transformers-1.32.4.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+x_transformers-1.32.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
+x_transformers-1.32.4.dist-info/RECORD,,