x-transformers 2.3.17__py3-none-any.whl → 2.3.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- x_transformers/continuous.py +8 -3
- {x_transformers-2.3.17.dist-info → x_transformers-2.3.18.dist-info}/METADATA +1 -1
- {x_transformers-2.3.17.dist-info → x_transformers-2.3.18.dist-info}/RECORD +5 -5
- {x_transformers-2.3.17.dist-info → x_transformers-2.3.18.dist-info}/WHEEL +0 -0
- {x_transformers-2.3.17.dist-info → x_transformers-2.3.18.dist-info}/licenses/LICENSE +0 -0
x_transformers/continuous.py
CHANGED
@@ -298,7 +298,6 @@ class ContinuousAutoregressiveWrapper(Module):
|
|
298
298
|
**kwargs
|
299
299
|
):
|
300
300
|
assert rollout_steps > 1
|
301
|
-
assert not self.probabilistic, 'probabilistic not supported yet'
|
302
301
|
|
303
302
|
steps = rollout_steps
|
304
303
|
|
@@ -369,8 +368,14 @@ class ContinuousAutoregressiveWrapper(Module):
|
|
369
368
|
**kwargs
|
370
369
|
)
|
371
370
|
|
372
|
-
last_pred = out[
|
373
|
-
|
371
|
+
last_pred = out[..., -1:, :]
|
372
|
+
|
373
|
+
if self.probabilistic:
|
374
|
+
mean, var = last_pred
|
375
|
+
std = var.clamp(min = 1e-5).sqrt()
|
376
|
+
inp = torch.normal(mean, std)
|
377
|
+
else:
|
378
|
+
inp = last_pred
|
374
379
|
|
375
380
|
preds.append(last_pred)
|
376
381
|
|
@@ -2,7 +2,7 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
|
|
2
2
|
x_transformers/attend.py,sha256=xFsBtl7h7_qebPh7kE81BpmCWAjCgFpB9i_IHu_91es,17288
|
3
3
|
x_transformers/autoregressive_wrapper.py,sha256=LW1gr3cFONDEPA_HHhaTE7mk-JWbaINuB1fc_DfbCqw,10791
|
4
4
|
x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
|
5
|
-
x_transformers/continuous.py,sha256=
|
5
|
+
x_transformers/continuous.py,sha256=uV2hLQOckeRsybqJy-0F8RhAyMPJlkVHmA7QqUJHG4g,12433
|
6
6
|
x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
|
7
7
|
x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
|
8
8
|
x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
|
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
|
|
11
11
|
x_transformers/x_transformers.py,sha256=ZfOXrZSiy2jlZ8wVmDdMTLW4hAY_qfmPQHW9t2ABxbo,114097
|
12
12
|
x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
|
13
13
|
x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
|
14
|
-
x_transformers-2.3.
|
15
|
-
x_transformers-2.3.
|
16
|
-
x_transformers-2.3.
|
17
|
-
x_transformers-2.3.
|
14
|
+
x_transformers-2.3.18.dist-info/METADATA,sha256=RKXNlO50fifu1Nas38iZRn6IJVDkv4Cen94XYVJlWg0,89897
|
15
|
+
x_transformers-2.3.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
16
|
+
x_transformers-2.3.18.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
|
17
|
+
x_transformers-2.3.18.dist-info/RECORD,,
|
File without changes
|
File without changes
|