x-transformers 2.1.35__py3-none-any.whl → 2.1.37__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between the package versions as they appear in their respective public registries.
- x_transformers/x_transformers.py +10 -6
- {x_transformers-2.1.35.dist-info → x_transformers-2.1.37.dist-info}/METADATA +1 -1
- {x_transformers-2.1.35.dist-info → x_transformers-2.1.37.dist-info}/RECORD +5 -5
- {x_transformers-2.1.35.dist-info → x_transformers-2.1.37.dist-info}/WHEEL +0 -0
- {x_transformers-2.1.35.dist-info → x_transformers-2.1.37.dist-info}/licenses/LICENSE +0 -0
x_transformers/x_transformers.py
CHANGED
@@ -864,14 +864,15 @@ class DynamicTanh(Module):
|
|
864
864
|
self.gamma = nn.Parameter(torch.ones(dim))
|
865
865
|
self.beta = nn.Parameter(torch.zeros(dim))
|
866
866
|
|
867
|
-
self.
|
867
|
+
self.pre_tanh_scale_offset = init_alpha if unit_offset else 0.
|
868
|
+
self.gamma_offset = float(unit_offset)
|
868
869
|
|
869
|
-
nn.init.constant_(self.pre_tanh_scale,
|
870
|
+
nn.init.constant_(self.pre_tanh_scale, 0 if unit_offset else init_alpha)
|
870
871
|
nn.init.constant_(self.gamma, 1. - float(unit_offset))
|
871
872
|
|
872
873
|
def forward(self, x):
|
873
|
-
pre_tanh_scale = self.pre_tanh_scale + self.
|
874
|
-
gamma = self.gamma + self.
|
874
|
+
pre_tanh_scale = self.pre_tanh_scale + self.pre_tanh_scale_offset
|
875
|
+
gamma = self.gamma + self.gamma_offset
|
875
876
|
return (x * pre_tanh_scale).tanh() * gamma + self.beta
|
876
877
|
|
877
878
|
# residual and residual gates
|
@@ -2908,6 +2909,7 @@ class TransformerWrapper(Module):
|
|
2908
2909
|
return_embeddings = False,
|
2909
2910
|
return_logits_and_embeddings = False,
|
2910
2911
|
return_intermediates = False,
|
2912
|
+
return_embeddings_and_intermediates = False,
|
2911
2913
|
return_logit_entropies = False,
|
2912
2914
|
mask = None,
|
2913
2915
|
return_mems = False,
|
@@ -2939,8 +2941,8 @@ class TransformerWrapper(Module):
|
|
2939
2941
|
|
2940
2942
|
b, n, device, num_mems, has_memory_tokens, emb_frac_gradient, orig_mask = x.shape[0], x.shape[1], x.device, self.num_memory_tokens, self.num_memory_tokens > 0, self.emb_frac_gradient, mask
|
2941
2943
|
|
2942
|
-
return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss
|
2943
|
-
return_embeddings = return_embeddings | (not exists(self.to_logits))
|
2944
|
+
return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss | return_embeddings_and_intermediates
|
2945
|
+
return_embeddings = return_embeddings | (not exists(self.to_logits)) | return_embeddings_and_intermediates
|
2944
2946
|
|
2945
2947
|
# absolute positional embedding
|
2946
2948
|
|
@@ -3130,6 +3132,8 @@ class TransformerWrapper(Module):
|
|
3130
3132
|
|
3131
3133
|
if return_logits_and_embeddings:
|
3132
3134
|
out = (logits, x)
|
3135
|
+
elif return_embeddings_and_intermediates:
|
3136
|
+
out = (x, intermediates)
|
3133
3137
|
elif return_embeddings:
|
3134
3138
|
out = x
|
3135
3139
|
else:
|
@@ -7,10 +7,10 @@ x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
|
7
7
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
8
8
|
x_transformers/neo_mlp.py,sha256=XCNnnop9WLarcxap1kGuYc1x8GHvwkZiDRnXOxSl3Po,3452
|
9
9
|
x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dCN7fjlzd3K1rLUY,10510
|
10
|
-
x_transformers/x_transformers.py,sha256=
|
10
|
+
x_transformers/x_transformers.py,sha256=twoqq2kfVWxntitHKLs2sxFMhK1CPLxGGBDAmkiHXcM,111812
|
11
11
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
12
12
|
x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
|
13
|
-
x_transformers-2.1.
|
14
|
-
x_transformers-2.1.
|
15
|
-
x_transformers-2.1.
|
16
|
-
x_transformers-2.1.
|
13
|
+
x_transformers-2.1.37.dist-info/METADATA,sha256=uaCIy-GAGH4OPrYa0mxjJJ-FDtMlMuiIbg1sQPb3BRw,88161
|
14
|
+
x_transformers-2.1.37.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
15
|
+
x_transformers-2.1.37.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
16
|
+
x_transformers-2.1.37.dist-info/RECORD,,
|
File without changes
|
File without changes
|