x-transformers 2.3.17__py3-none-any.whl → 2.3.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -298,7 +298,6 @@ class ContinuousAutoregressiveWrapper(Module):
298
298
  **kwargs
299
299
  ):
300
300
  assert rollout_steps > 1
301
- assert not self.probabilistic, 'probabilistic not supported yet'
302
301
 
303
302
  steps = rollout_steps
304
303
 
@@ -369,8 +368,14 @@ class ContinuousAutoregressiveWrapper(Module):
369
368
  **kwargs
370
369
  )
371
370
 
372
- last_pred = out[:, -1:]
373
- inp = last_pred
371
+ last_pred = out[..., -1:, :]
372
+
373
+ if self.probabilistic:
374
+ mean, var = last_pred
375
+ std = var.clamp(min = 1e-5).sqrt()
376
+ inp = torch.normal(mean, std)
377
+ else:
378
+ inp = last_pred
374
379
 
375
380
  preds.append(last_pred)
376
381
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: x-transformers
3
- Version: 2.3.17
3
+ Version: 2.3.18
4
4
  Summary: X-Transformers
5
5
  Project-URL: Homepage, https://pypi.org/project/x-transformers/
6
6
  Project-URL: Repository, https://github.com/lucidrains/x-transformers
@@ -2,7 +2,7 @@ x_transformers/__init__.py,sha256=h3I2ejobgEdy8H7NgV-rP8UaBCnd16-MysvDXH9GMEA,98
2
2
  x_transformers/attend.py,sha256=xFsBtl7h7_qebPh7kE81BpmCWAjCgFpB9i_IHu_91es,17288
3
3
  x_transformers/autoregressive_wrapper.py,sha256=LW1gr3cFONDEPA_HHhaTE7mk-JWbaINuB1fc_DfbCqw,10791
4
4
  x_transformers/belief_state_wrapper.py,sha256=YLUMk6t2MhFBEw5lHDDHJHcoCxTIkHvxTNY__GGZEKU,13374
5
- x_transformers/continuous.py,sha256=QidhbSgBhYDpAp7FaryxrNJxWXp0-pJkeygguTpVp4k,12308
5
+ x_transformers/continuous.py,sha256=uV2hLQOckeRsybqJy-0F8RhAyMPJlkVHmA7QqUJHG4g,12433
6
6
  x_transformers/dpo.py,sha256=xt4OuOWhU8pN3OKN2LZAaC2NC8iiEnchqqcrPWVqf0o,3521
7
7
  x_transformers/entropy_based_tokenizer.py,sha256=F2lO8-v3aLIcVDVNhu7RR-UtRdlmaaYJzBK9m7OnLE8,5018
8
8
  x_transformers/multi_input.py,sha256=tCh-fTJDj2ib4SMGtsa-AM8MxKzJAQSwqAXOu3HU2mg,9252
@@ -11,7 +11,7 @@ x_transformers/nonautoregressive_wrapper.py,sha256=2NU58hYMgn-4Jzg3mie-mXb0XH_dC
11
11
  x_transformers/x_transformers.py,sha256=ZfOXrZSiy2jlZ8wVmDdMTLW4hAY_qfmPQHW9t2ABxbo,114097
12
12
  x_transformers/xl_autoregressive_wrapper.py,sha256=CvZMJ6A6PA-Y_bQAhnORwjJBSl6Vjq2IdW5KTdk8NI8,4195
13
13
  x_transformers/xval.py,sha256=AwwYUm8yDAtKQyKJDIhYMsiLTJ_skh3scUFMjp5sda8,8597
14
- x_transformers-2.3.17.dist-info/METADATA,sha256=u_DVXsX7KsVhnfMhpV-3KV6KGNHddmzW_SNG9om557s,89897
15
- x_transformers-2.3.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
- x_transformers-2.3.17.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
17
- x_transformers-2.3.17.dist-info/RECORD,,
14
+ x_transformers-2.3.18.dist-info/METADATA,sha256=RKXNlO50fifu1Nas38iZRn6IJVDkv4Cen94XYVJlWg0,89897
15
+ x_transformers-2.3.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
+ x_transformers-2.3.18.dist-info/licenses/LICENSE,sha256=As9u198X-U-vph5noInuUfqsAG2zX_oXPHDmdjwlPPY,1066
17
+ x_transformers-2.3.18.dist-info/RECORD,,