x-transformers 2.1.36__py3-none-any.whl → 2.1.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +5 -2
- {x_transformers-2.1.36.dist-info → x_transformers-2.1.37.dist-info}/METADATA +1 -1
- {x_transformers-2.1.36.dist-info → x_transformers-2.1.37.dist-info}/RECORD +5 -5
- {x_transformers-2.1.36.dist-info → x_transformers-2.1.37.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.36.dist-info → x_transformers-2.1.37.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -2909,6 +2909,7 @@ class TransformerWrapper(Module):
         return_embeddings = False,
         return_logits_and_embeddings = False,
         return_intermediates = False,
+        return_embeddings_and_intermediates = False,
         return_logit_entropies = False,
         mask = None,
         return_mems = False,
@@ -2940,8 +2941,8 @@ class TransformerWrapper(Module):
 
         b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask
 
-        return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
-        return_embeddings = return_embeddings | (not exists(self.to_logits))
+        return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss | return_embeddings_and_intermediates
+        return_embeddings = return_embeddings | (not exists(self.to_logits)) | return_embeddings_and_intermediates
 
         # absolute positional embedding
 
@@ -3131,6 +3132,8 @@ class TransformerWrapper(Module):
 
         if return_logits_and_embeddings:
             out = (logits, x)
+        elif return_embeddings_and_intermediates:
+            out = (x, intermediates)
         elif return_embeddings:
             out = x
         else:
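Taken together, the change adds a single new flag, return_embeddings_and_intermediates, which returns both the final hidden states (skipping the to_logits projection) and the intermediates collected during the forward pass. A minimal usage sketch, not part of the diff: the model construction follows the standard x-transformers API, and the hyperparameters are illustrative.

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

tokens = torch.randint(0, 20000, (1, 1024))

# new in 2.1.37: bypass to_logits and also receive the intermediates
# gathered during the pass, per the `out = (x, intermediates)` branch above
embeddings, intermediates = model(tokens, return_embeddings_and_intermediates = True)

# embeddings has shape (1, 1024, 512): final hidden states, pre-projection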
{x_transformers-2.1.36.dist-info → x_transformers-2.1.37.dist-info}/RECORD
CHANGED
@@ -7,10 +7,10 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
 x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
 x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
-x_transformers/x_transformers.py,sha256=
+x_transformers/x_transformers.py,sha256=twoqq2kfVWxntitHKLs2sxFMhK1CPLxGGBDAmkiHXcM,111812
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
-x_transformers-2.1.
+x_transformers-2.1.37.dist-info/METADATA,sha256=uaCIy-GAGH4OPrYa0mxjJJ-FDtMlMuiIbg1sQPb3BRw,88161
+x_transformers-2.1.37.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.1.37.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.1.37.dist-info/RECORD,,
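For context, each RECORD line has the form path,sha256=<digest>,size, where the digest is an unpadded urlsafe-base64 SHA-256 of the file contents per the wheel spec (PEP 427/376). A small sketch, not part of the diff, of how an entry above could be re-derived from an unpacked wheel:

import base64, hashlib

def record_entry(path):
    # rebuild a "path,sha256=<digest>,size" line as found in RECORD
    with open(path, 'rb') as f:
        data = f.read()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b'=').decode()
    return f'{path},sha256={digest},{len(data)}'

# expected to reproduce the new entry, e.g.
# x_transformers/x_transformers.py,sha256=twoqq2kfVWxntitHKLs2sxFMhK1CPLxGGBDAmkiHXcM,111812
print(record_entry('x_transformers/x_transformers.py'))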
{x_transformers-2.1.36.dist-info → x_transformers-2.1.37.dist-info}/WHEEL
File without changes
{x_transformers-2.1.36.dist-info → x_transformers-2.1.37.dist-info}/licenses/LICENSE
File without changes