x-transformers 2.1.36__py3-none-any.whl → 2.1.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2909,6 +2909,7 @@ class TransformerWrapper(Module):
2909
2909
  return_embeddings = False,
2910
2910
  return_logits_and_embeddings = False,
2911
2911
  return_intermediates = False,
2912
+ return_embeddings_and_intermediates = False,
2912
2913
  return_logit_entropies = False,
2913
2914
  mask = None,
2914
2915
  return_mems = False,
@@ -2940,8 +2941,8 @@ class TransformerWrapper(Module):
2940
2941
 
2941
2942
  b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask
2942
2943
 
2943
- return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
2944
- return_embeddings = return_embeddings | (not exists(self.to_logits))
2944
+ return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss | return_embeddings_and_intermediates
2945
+ return_embeddings = return_embeddings | (not exists(self.to_logits)) | return_embeddings_and_intermediates
2945
2946
 
2946
2947
  # absolute positional embedding
2947
2948
 
@@ -3131,6 +3132,8 @@ class TransformerWrapper(Module):
3131
3132
 
3132
3133
  if return_logits_and_embeddings:
3133
3134
  out = (logits, x)
3135
+ elif return_embeddings_and_intermediates:
3136
+ out = (x, intermediates)
3134
3137
  elif return_embeddings:
3135
3138
  out = x
3136
3139
  else:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.1.36
3
+ Version: 2.1.37
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -7,10 +7,10 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
7
7
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
8
8
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
9
9
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
10
- x_transformers/x_transformers.py,sha256=voN-uEBEKxpUu9K4MVcneSTrzdgJWnZGuQ1QRZQw4Q4,111596
10
+ x_transformers/x_transformers.py,sha256=twoqq2kfVWxntitHKLs2sxFMhK1CPLxGGBDAmkiHXcM,111812
11
11
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
12
12
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
13
- x_transformers-2.1.36.dist-info/METADATA,sha256=D0qdMRucK3PWwEi8WwdiJdZ8X_hGTm1r3_7bJzYiWSM,88161
14
- x_transformers-2.1.36.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
- x_transformers-2.1.36.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
16
- x_transformers-2.1.36.dist-info/RECORD,,
13
+ x_transformers-2.1.37.dist-info/METADATA,sha256=uaCIy-GAGH4OPrYa0mxjJJ-FDtMlMuiIbg1sQPb3BRw,88161
14
+ x_transformers-2.1.37.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
+ x_transformers-2.1.37.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
16
+ x_transformers-2.1.37.dist-info/RECORD,,