x-transformers 2.3.0__py3-none-any.whl → 2.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/continuous.py +2 -1
- {x_transformers-2.3.0.dist-info → x_transformers-2.3.1.dist-info}/METADATA +1 -1
- {x_transformers-2.3.0.dist-info → x_transformers-2.3.1.dist-info}/RECORD +5 -5
- {x_transformers-2.3.0.dist-info → x_transformers-2.3.1.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.0.dist-info → x_transformers-2.3.1.dist-info}/licenses/LICENSE +0 -0
x_transformers/continuous.py
CHANGED
@@ -222,6 +222,7 @@ class ContinuousAutoregressiveWrapper(Module):
 self,
 start_tokens,
 seq_len,
+temperature = 1.,
 **kwargs
 ):
 device = start_tokens.device
@@ -245,7 +246,7 @@ class ContinuousAutoregressiveWrapper(Module):
 
 
 if self.probabilistic:
 mean, var = last_output
-last_output = torch.normal(mean, var)
+last_output = torch.normal(mean, var * temperature)
 
 out = cat((out, last_output), dim = -2)
 
{x_transformers-2.3.0.dist-info → x_transformers-2.3.1.dist-info}/RECORD
CHANGED
@@ -2,7 +2,7 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
 x_transformers/attend.py,sha256=-5BWWhFsp7tvZTdN91Ay5SqOjyj9uOs-122vFvoO6b4,17253
 x_transformers/autoregressive_wrapper.py,sha256=reLCno9Z9pchVU79tBF8OMo21LwSZ67KAeB83jqkyAc,10505
 x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
-x_transformers/continuous.py,sha256=
+x_transformers/continuous.py,sha256=F5XPQU5Y798R1_JoepX4Mg44_j3Whw8SHaTsavq1YZs,8256
 x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
 x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
 x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
 x_transformers/x_transformers.py,sha256=MF91aJGr2DOjIGe57uqwgyNxCExBg_tI9z7usAJMxOM,112401
 x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
 x_transformers/xval.py,sha256=7S00kCuab4tWQa-vf-z-XfzADjVj48MoFIr7VSIvttg,8575
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
-x_transformers-2.3.
+x_transformers-2.3.1.dist-info/METADATA,sha256=-y3iEikqisIdIx8eBfP41qVZj2Nqzpm88usIUek6Pwg,88686
+x_transformers-2.3.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+x_transformers-2.3.1.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
+x_transformers-2.3.1.dist-info/RECORD,,
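{x_transformers-2.3.0.dist-info → x_transformers-2.3.1.dist-info}/WHEEL
File without changes
{x_transformers-2.3.0.dist-info → x_transformers-2.3.1.dist-info}/licenses/LICENSE
File without changes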