x-transformers 2.1.35__py3-none-any.whl → 2.1.37__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -864,14 +864,15 @@ class DynamicTanh(Module):
864
864
  self.gamma = nn.Parameter(torch.ones(dim))
865
865
  self.beta = nn.Parameter(torch.zeros(dim))
866
866
 
867
- self.unit_offset = int(unit_offset)
867
+ self.pre_tanh_scale_offset = init_alpha if unit_offset else 0.
868
+ self.gamma_offset = float(unit_offset)
868
869
 
869
- nn.init.constant_(self.pre_tanh_scale, 1. - float(unit_offset))
870
+ nn.init.constant_(self.pre_tanh_scale, 0 if unit_offset else init_alpha)
870
871
  nn.init.constant_(self.gamma, 1. - float(unit_offset))
871
872
 
872
873
  def forward(self, x):
873
- pre_tanh_scale = self.pre_tanh_scale + self.unit_offset
874
- gamma = self.gamma + self.unit_offset
874
+ pre_tanh_scale = self.pre_tanh_scale + self.pre_tanh_scale_offset
875
+ gamma = self.gamma + self.gamma_offset
875
876
  return (x * pre_tanh_scale).tanh() * gamma + self.beta
876
877
 
877
878
  # residual and residual gates
@@ -2908,6 +2909,7 @@ class TransformerWrapper(Module):
2908
2909
  return_embeddings = False,
2909
2910
  return_logits_and_embeddings = False,
2910
2911
  return_intermediates = False,
2912
+ return_embeddings_and_intermediates = False,
2911
2913
  return_logit_entropies = False,
2912
2914
  mask = None,
2913
2915
  return_mems = False,
@@ -2939,8 +2941,8 @@ class TransformerWrapper(Module):
2939
2941
 
2940
2942
  b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask
2941
2943
 
2942
- return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
2943
- return_embeddings = return_embeddings | (not exists(self.to_logits))
2944
+ return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss | return_embeddings_and_intermediates
2945
+ return_embeddings = return_embeddings | (not exists(self.to_logits)) | return_embeddings_and_intermediates
2944
2946
 
2945
2947
  # absolute positional embedding
2946
2948
 
@@ -3130,6 +3132,8 @@ class TransformerWrapper(Module):
3130
3132
 
3131
3133
  if return_logits_and_embeddings:
3132
3134
  out = (logits, x)
3135
+ elif return_embeddings_and_intermediates:
3136
+ out = (x, intermediates)
3133
3137
  elif return_embeddings:
3134
3138
  out = x
3135
3139
  else:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.1.35
3
+ Version: 2.1.37
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -7,10 +7,10 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
7
7
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
8
8
  x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
9
9
  x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
10
- x_transformers/x_transformers.py,sha256=oyYk31qhDyt6cCuDeaHBl4XhUo5bfIwmYE_z1E1OpXU,111502
10
+ x_transformers/x_transformers.py,sha256=twoqq2kfVWxntitHKLs2sxFMhK1CPLxGGBDAmkiHXcM,111812
11
11
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
12
12
  x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
13
- x_transformers-2.1.35.dist-info/METADATA,sha256=tLbl-c1QtaOphTa1DpdNfh4dXzFwTt9Fvdh94tnwdTs,88161
14
- x_transformers-2.1.35.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
- x_transformers-2.1.35.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
16
- x_transformers-2.1.35.dist-info/RECORD,,
13
+ x_transformers-2.1.37.dist-info/METADATA,sha256=uaCIy-GAGH4OPrYa0mxjJJ-FDtMlMuiIbg1sQPb3BRw,88161
14
+ x_transformers-2.1.37.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
15
+ x_transformers-2.1.37.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
16
+ x_transformers-2.1.37.dist-info/RECORD,,