x-transformers 1.32.2__py3-none-any.whl → 1.32.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +10 -1
- {x_transformers-1.32.2.dist-info → x_transformers-1.32.4.dist-info}/METADATA +1 -1
- {x_transformers-1.32.2.dist-info → x_transformers-1.32.4.dist-info}/RECORD +6 -6
- {x_transformers-1.32.2.dist-info → x_transformers-1.32.4.dist-info}/LICENSE +0 -0
- {x_transformers-1.32.2.dist-info → x_transformers-1.32.4.dist-info}/WHEEL +0 -0
- {x_transformers-1.32.2.dist-info → x_transformers-1.32.4.dist-info}/top_level.txt +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -1899,6 +1899,7 @@ class TransformerWrapper(Module):
|
|
1899
1899
|
memory_tokens_interspersed_every = None,
|
1900
1900
|
tie_embedding = False,
|
1901
1901
|
logits_dim = None,
|
1902
|
+
return_only_embed = False,
|
1902
1903
|
num_output_heads = 1,
|
1903
1904
|
use_abs_pos_emb = True,
|
1904
1905
|
scaled_sinu_pos_emb = False,
|
@@ -1948,13 +1949,17 @@ class TransformerWrapper(Module):
|
|
1948
1949
|
|
1949
1950
|
self.init_()
|
1950
1951
|
|
1952
|
+
assert num_output_heads > 0
|
1953
|
+
|
1951
1954
|
# output head, usually to logits of num_tokens
|
1952
1955
|
|
1953
1956
|
logits_dim = default(logits_dim, num_tokens)
|
1954
1957
|
|
1955
1958
|
self.has_multiple_heads = False
|
1956
1959
|
|
1957
|
-
if
|
1960
|
+
if return_only_embed:
|
1961
|
+
self.to_logits = None
|
1962
|
+
elif tie_embedding:
|
1958
1963
|
self.to_logits = lambda t: t @ self.token_emb.emb.weight.t()
|
1959
1964
|
elif num_output_heads > 1:
|
1960
1965
|
self.has_multiple_heads = True
|
@@ -2008,7 +2013,9 @@ class TransformerWrapper(Module):
|
|
2008
2013
|
**kwargs
|
2009
2014
|
):
|
2010
2015
|
b, n, device, num_mems, has_memory_tokens, emb_frac_gradient = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient
|
2016
|
+
|
2011
2017
|
return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
|
2018
|
+
return_embeddings = return_embeddings | (not exists(self.to_logits))
|
2012
2019
|
|
2013
2020
|
# absolute positional embedding
|
2014
2021
|
|
@@ -2018,6 +2025,8 @@ class TransformerWrapper(Module):
|
|
2018
2025
|
|
2019
2026
|
# add additional embeddings
|
2020
2027
|
|
2028
|
+
assert not (exists(self.embeds) ^ (len(embed_ids) > 0)), '`embed_num_tokens` must be defined on `TransformerWrapper`'
|
2029
|
+
|
2021
2030
|
if exists(self.embeds):
|
2022
2031
|
assert len(embed_ids) == len(self.embeds)
|
2023
2032
|
|
@@ -4,11 +4,11 @@ x_transformers/autoregressive_wrapper.py,sha256=uX8Mb0zLsQrZECt_9UGt35g7tC05Rk3n
|
|
4
4
|
x_transformers/continuous.py,sha256=WO52n9lFAXv5-SGadi2cApGF8dkouN8QSTEOuC7erj8,6180
|
5
5
|
x_transformers/dpo.py,sha256=LjvWgCkqTl-UuehrzQ8nkX5guLr4whYwsmm7SKSwdls,3450
|
6
6
|
x_transformers/nonautoregressive_wrapper.py,sha256=ys_p8obc7lTeeodCqvkRKxOXQ1C9T3j5Jwr-JbVgnXk,10432
|
7
|
-
x_transformers/x_transformers.py,sha256=
|
7
|
+
x_transformers/x_transformers.py,sha256=o7At5F35Paih1a_-rxqcIP1n4B-ARVJ_ZL2QkOnTnSQ,76655
|
8
8
|
x_transformers/xl_autoregressive_wrapper.py,sha256=DCx4n0_c1tFai4nOqaWVnqx2p9eutsZsDMiMP1ckxNU,4117
|
9
9
|
x_transformers/xval.py,sha256=QE1ltYZTR_eGgIHPP2BrMWVWVLqMW-OpDZh87BSmQEg,8563
|
10
|
-
x_transformers-1.32.
|
11
|
-
x_transformers-1.32.
|
12
|
-
x_transformers-1.32.
|
13
|
-
x_transformers-1.32.
|
14
|
-
x_transformers-1.32.
|
10
|
+
x_transformers-1.32.4.dist-info/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
11
|
+
x_transformers-1.32.4.dist-info/METADATA,sha256=rAX2cTBnI50T0Tsa00KgBTz6lzV36ACppE6H2WPLZI4,661
|
12
|
+
x_transformers-1.32.4.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
13
|
+
x_transformers-1.32.4.dist-info/top_level.txt,sha256=hO6KGpFuGucRNEtRfme4A_rGcM53AKwGP7RVlRIxS5Q,15
|
14
|
+
x_transformers-1.32.4.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|