CUQIpy 1.3.0.post0.dev237__tar.gz → 1.3.0.post0.dev277__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of CUQIpy might be problematic; see the package registry page for more details.

Files changed (127)
  1. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/CUQIpy.egg-info/PKG-INFO +1 -1
  2. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/CUQIpy.egg-info/SOURCES.txt +1 -1
  3. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/PKG-INFO +1 -1
  4. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/_version.py +3 -3
  5. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/__init__.py +1 -0
  6. cuqipy-1.3.0.post0.dev277/cuqi/experimental/_recommender.py +200 -0
  7. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/mcmc/__init__.py +0 -1
  8. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/mcmc/_gibbs.py +28 -2
  9. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_implicit_priors.py +3 -3
  10. cuqipy-1.3.0.post0.dev237/cuqi/experimental/mcmc/_utilities.py +0 -17
  11. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/CUQIpy.egg-info/dependency_links.txt +0 -0
  12. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/CUQIpy.egg-info/requires.txt +0 -0
  13. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/CUQIpy.egg-info/top_level.txt +0 -0
  14. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/LICENSE +0 -0
  15. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/README.md +0 -0
  16. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/__init__.py +0 -0
  17. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/_messages.py +0 -0
  18. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/array/__init__.py +0 -0
  19. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/array/_array.py +0 -0
  20. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/config.py +0 -0
  21. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/data/__init__.py +0 -0
  22. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/data/_data.py +0 -0
  23. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/data/astronaut.npz +0 -0
  24. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/data/camera.npz +0 -0
  25. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/data/cat.npz +0 -0
  26. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/data/cookie.png +0 -0
  27. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/data/satellite.mat +0 -0
  28. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/density/__init__.py +0 -0
  29. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/density/_density.py +0 -0
  30. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/diagnostics.py +0 -0
  31. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/__init__.py +0 -0
  32. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_beta.py +0 -0
  33. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_cauchy.py +0 -0
  34. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_cmrf.py +0 -0
  35. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_custom.py +0 -0
  36. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_distribution.py +0 -0
  37. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_gamma.py +0 -0
  38. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_gaussian.py +0 -0
  39. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_gmrf.py +0 -0
  40. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_inverse_gamma.py +0 -0
  41. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_joint_distribution.py +0 -0
  42. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_laplace.py +0 -0
  43. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_lmrf.py +0 -0
  44. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_lognormal.py +0 -0
  45. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_modifiedhalfnormal.py +0 -0
  46. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_normal.py +0 -0
  47. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_posterior.py +0 -0
  48. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_smoothed_laplace.py +0 -0
  49. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_truncated_normal.py +0 -0
  50. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/distribution/_uniform.py +0 -0
  51. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/algebra/__init__.py +0 -0
  52. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/algebra/_ast.py +0 -0
  53. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/algebra/_orderedset.py +0 -0
  54. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/algebra/_randomvariable.py +0 -0
  55. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/geometry/__init__.py +0 -0
  56. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/geometry/_productgeometry.py +0 -0
  57. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/mcmc/_conjugate.py +0 -0
  58. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/mcmc/_conjugate_approx.py +0 -0
  59. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/mcmc/_cwmh.py +0 -0
  60. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/mcmc/_direct.py +0 -0
  61. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/mcmc/_hmc.py +0 -0
  62. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/mcmc/_langevin_algorithm.py +0 -0
  63. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/mcmc/_laplace_approximation.py +0 -0
  64. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/mcmc/_mh.py +0 -0
  65. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/mcmc/_pcn.py +0 -0
  66. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/mcmc/_rto.py +0 -0
  67. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/experimental/mcmc/_sampler.py +0 -0
  68. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/geometry/__init__.py +0 -0
  69. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/geometry/_geometry.py +0 -0
  70. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/implicitprior/__init__.py +0 -0
  71. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/implicitprior/_regularizedGMRF.py +0 -0
  72. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/implicitprior/_regularizedGaussian.py +0 -0
  73. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/implicitprior/_regularizedUnboundedUniform.py +0 -0
  74. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/implicitprior/_restorator.py +0 -0
  75. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/likelihood/__init__.py +0 -0
  76. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/likelihood/_likelihood.py +0 -0
  77. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/model/__init__.py +0 -0
  78. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/model/_model.py +0 -0
  79. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/operator/__init__.py +0 -0
  80. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/operator/_operator.py +0 -0
  81. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/pde/__init__.py +0 -0
  82. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/pde/_pde.py +0 -0
  83. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/problem/__init__.py +0 -0
  84. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/problem/_problem.py +0 -0
  85. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/sampler/__init__.py +0 -0
  86. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/sampler/_conjugate.py +0 -0
  87. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/sampler/_conjugate_approx.py +0 -0
  88. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/sampler/_cwmh.py +0 -0
  89. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/sampler/_gibbs.py +0 -0
  90. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/sampler/_hmc.py +0 -0
  91. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/sampler/_langevin_algorithm.py +0 -0
  92. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/sampler/_laplace_approximation.py +0 -0
  93. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/sampler/_mh.py +0 -0
  94. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/sampler/_pcn.py +0 -0
  95. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/sampler/_rto.py +0 -0
  96. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/sampler/_sampler.py +0 -0
  97. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/samples/__init__.py +0 -0
  98. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/samples/_samples.py +0 -0
  99. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/solver/__init__.py +0 -0
  100. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/solver/_solver.py +0 -0
  101. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/testproblem/__init__.py +0 -0
  102. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/testproblem/_testproblem.py +0 -0
  103. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/utilities/__init__.py +0 -0
  104. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/utilities/_get_python_variable_name.py +0 -0
  105. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/cuqi/utilities/_utilities.py +0 -0
  106. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/pyproject.toml +0 -0
  107. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/requirements.txt +0 -0
  108. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/setup.cfg +0 -0
  109. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/setup.py +0 -0
  110. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_MRFs.py +0 -0
  111. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_abstract_distribution_density.py +0 -0
  112. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_bayesian_inversion.py +0 -0
  113. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_density.py +0 -0
  114. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_distribution.py +0 -0
  115. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_distributions_shape.py +0 -0
  116. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_geometry.py +0 -0
  117. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_joint_distribution.py +0 -0
  118. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_likelihood.py +0 -0
  119. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_model.py +0 -0
  120. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_pde.py +0 -0
  121. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_posterior.py +0 -0
  122. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_problem.py +0 -0
  123. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_sampler.py +0 -0
  124. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_samples.py +0 -0
  125. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_solver.py +0 -0
  126. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_testproblem.py +0 -0
  127. {cuqipy-1.3.0.post0.dev237 → cuqipy-1.3.0.post0.dev277}/tests/test_utilities.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: CUQIpy
3
- Version: 1.3.0.post0.dev237
3
+ Version: 1.3.0.post0.dev277
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -45,6 +45,7 @@ cuqi/distribution/_smoothed_laplace.py
45
45
  cuqi/distribution/_truncated_normal.py
46
46
  cuqi/distribution/_uniform.py
47
47
  cuqi/experimental/__init__.py
48
+ cuqi/experimental/_recommender.py
48
49
  cuqi/experimental/algebra/__init__.py
49
50
  cuqi/experimental/algebra/_ast.py
50
51
  cuqi/experimental/algebra/_orderedset.py
@@ -64,7 +65,6 @@ cuqi/experimental/mcmc/_mh.py
64
65
  cuqi/experimental/mcmc/_pcn.py
65
66
  cuqi/experimental/mcmc/_rto.py
66
67
  cuqi/experimental/mcmc/_sampler.py
67
- cuqi/experimental/mcmc/_utilities.py
68
68
  cuqi/geometry/__init__.py
69
69
  cuqi/geometry/_geometry.py
70
70
  cuqi/implicitprior/__init__.py
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: CUQIpy
3
- Version: 1.3.0.post0.dev237
3
+ Version: 1.3.0.post0.dev277
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2025-05-15T09:06:20+0300",
11
+ "date": "2025-06-26T10:10:00+0300",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "b9a5dd37abfe2ee51d5594630de54dd3469bfeca",
15
- "version": "1.3.0.post0.dev237"
14
+ "full-revisionid": "52c8f8ffc38956c4aee969ccc70cd3db24ee774e",
15
+ "version": "1.3.0.post0.dev277"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -2,3 +2,4 @@
2
2
  from . import mcmc
3
3
  from . import algebra
4
4
  from . import geometry
5
+ from ._recommender import SamplerRecommender
@@ -0,0 +1,200 @@
1
+ import cuqi
2
+ import inspect
3
+ import numpy as np
4
+
5
+ # This import makes suggest_sampler easier to read
6
+ import cuqi.experimental.mcmc as samplers
7
+
8
+
9
class SamplerRecommender(object):
    """
    This class can be used to automatically choose a sampler.

    Samplers are tried in a fixed order of preference (see :attr:`ordering`);
    the first one that accepts the target, satisfies its extra condition and
    is not listed in ``exceptions`` is recommended.

    Parameters
    ----------
    target: Density or JointDistribution
        Distribution to get sampler recommendations for.

    exceptions: list[cuqi.experimental.mcmc.Sampler], *optional*
        Sampler classes not to be recommended. The list is copied, so
        modifying it afterwards does not affect the recommender.
    """

    def __init__(self, target:cuqi.density.Density, exceptions = None):
        # Copy so neither the caller's list nor a shared default can be
        # mutated behind the recommender's back. Must be set before the
        # target, because building the ordering reads the exceptions.
        self._exceptions = [] if exceptions is None else list(exceptions)
        # Assign through the property so the target is validated and the
        # ordering is built for it.
        self.target = target

    @property
    def target(self) -> cuqi.density.Density:
        """ Return the target Distribution. """
        return self._target

    @target.setter
    def target(self, value:cuqi.density.Density):
        """
        Set the target Distribution.

        Validates the target and rebuilds the recommendation ordering, whose
        sampler parameters depend on the target.

        Raises
        ------
        ValueError
            If ``value`` is None.
        """
        if value is None:
            raise ValueError("Target needs to be of type cuqi.density.Density.")
        self._target = value
        # The ordering embeds target-dependent data (the total dimension for
        # CWMH, the per-conditional strategy for HybridGibbs). Rebuilding it
        # here keeps it in sync with the current target.
        self._create_ordering()

    def _create_ordering(self):
        """
        Build the ordered list of recommendation rules for the current target.

        Every element in the ordering consists of a tuple:
        (
            Sampler: Class
            boolean: additional conditions on the target
            parameters: additional parameters to be passed to the sampler once initialized
        )
        """
        number_of_components = np.sum(self._target.dim)

        self._ordering = [
            # Direct and Conjugate samplers
            (samplers.Direct, True, {}),
            (samplers.Conjugate, True, {}),
            (samplers.ConjugateApprox, True, {}),
            # Specialized samplers
            (samplers.LinearRTO, True, {}),
            (samplers.RegularizedLinearRTO, True, {}),
            (samplers.UGLA, True, {}),
            # Gradient-based samplers (Hamiltonian and Langevin)
            (samplers.NUTS, True, {}),
            (samplers.MALA, True, {}),
            (samplers.ULA, True, {}),
            # Gibbs and Componentwise samplers.
            # The strategy is None when the target is not a JointDistribution
            # or when some conditional has no admissible sampler.
            (samplers.HybridGibbs, True, {"sampling_strategy" : self.recommend_HybridGibbs_sampling_strategy(as_string = False)}),
            (samplers.CWMH, number_of_components <= 100, {"scale" : 0.05*np.ones(number_of_components),
                                                          "initial_point" : 0.5*np.ones(number_of_components)}),
            # Proposal based samplers
            (samplers.PCN, True, {"scale" : 0.02}),
            (samplers.MH, number_of_components <= 1000, {}),
        ]

    @property
    def ordering(self):
        """ Returns the ordered list of recommendation rules used by the recommender. """
        return self._ordering

    def _conditionals(self):
        """
        Yield ``(par_name, conditional)`` for every parameter of a joint target.

        Each conditional is obtained by fixing all other parameters to
        placeholder vectors of ones. Only call this for a JointDistribution
        target.
        """
        par_names = self.target.get_parameter_names()
        # NOTE(review): assumes self.target.dim lists the dimensions in the
        # same order as get_parameter_names() — confirm in JointDistribution.
        dims = self.target.dim
        for par_name in par_names:
            conditional_params = {other: np.ones(dims[i]) for i, other in enumerate(par_names) if other != par_name}
            yield par_name, self.target(**conditional_params)

    def valid_samplers(self, as_string = True):
        """
        Finds all possible samplers that can be used for sampling from the target distribution.

        A sampler is considered valid if it can be constructed with the target
        as its only argument. HybridGibbs is valid if every conditional of a
        joint target has at least one valid sampler.

        Parameters
        ----------

        as_string : boolean
            Whether to return the name of the sampler as a string instead of the sampler class. *Optional*

        Returns
        -------
        list
            Sampler names (``as_string=True``) or sampler classes.
        """

        all_samplers = [(name, cls) for name, cls in inspect.getmembers(cuqi.experimental.mcmc, inspect.isclass) if issubclass(cls, cuqi.experimental.mcmc.Sampler)]
        valid_samplers = []

        for name, sampler in all_samplers:
            try:
                sampler(self.target)
            except Exception:
                # Samplers signal an unsupported target through a variety of
                # exception types; any failure means "not valid". Exception
                # (not a bare except) lets KeyboardInterrupt/SystemExit through.
                continue
            valid_samplers.append(name if as_string else sampler)

        # HybridGibbs needs a sampling strategy, so it cannot be probed with
        # the target alone and is handled separately.
        if self.valid_HybridGibbs_sampling_strategy() is not None:
            valid_samplers.append(cuqi.experimental.mcmc.HybridGibbs.__name__ if as_string else cuqi.experimental.mcmc.HybridGibbs)

        return valid_samplers

    def valid_HybridGibbs_sampling_strategy(self, as_string = True):
        """
        Find all possible sampling strategies to be used with the HybridGibbs sampler.
        Returns None if the target is not a JointDistribution or if no sampler
        could be found for at least one conditional distribution.

        Parameters
        ----------

        as_string : boolean
            Whether to return the name of the samplers in the sampling strategy as a string instead of the sampler classes. *Optional*

        Returns
        -------
        dict or None
            Maps each parameter name to the list of valid samplers for its conditional.
        """

        if not isinstance(self.target, cuqi.distribution.JointDistribution):
            return None

        valid_samplers = dict()
        for par_name, conditional in self._conditionals():
            conditional_samplers = SamplerRecommender(conditional).valid_samplers(as_string)
            if len(conditional_samplers) == 0:
                return None
            valid_samplers[par_name] = conditional_samplers

        return valid_samplers

    def recommend(self, as_string = False):
        """
        Suggests a possible sampler that can be used for sampling from the target distribution.

        Parameters
        ----------

        as_string : boolean
            Whether to return the name of the sampler as a string instead of instantiating a sampler. *Optional*

        Returns
        -------
        str or Sampler
            The sampler name, or a sampler instance initialized with the target
            and the rule's additional parameters.

        Raises
        ------
        ValueError
            If no sampler can be suggested.
        """

        valid_samplers = self.valid_samplers(as_string = False)

        for suggestion, flag, values in self._ordering:
            if not flag or suggestion not in valid_samplers or suggestion in self._exceptions:
                continue
            # HybridGibbs may be valid in principle, but its recommended
            # strategy can be None (e.g. every sampler for some conditional is
            # excluded). It cannot be constructed without a strategy, so skip it.
            if "sampling_strategy" in values and values["sampling_strategy"] is None:
                continue
            # Sampler found
            if as_string:
                return suggestion.__name__
            return suggestion(self.target, **values)

        # No sampler can be suggested
        raise ValueError("Cannot suggest any sampler. Either the provided distribution is incorrectly defined or there are too many exceptions provided.")

    def recommend_HybridGibbs_sampling_strategy(self, as_string = False):
        """
        Suggests a possible sampling strategy to be used with the HybridGibbs sampler.
        Returns None if the target is not a JointDistribution or if no sampler
        could be suggested for at least one conditional distribution.

        The exceptions of this recommender also apply to every conditional.

        Parameters
        ----------

        as_string : boolean
            Whether to return the name of the samplers in the sampling strategy as a string instead of instantiating samplers. *Optional*

        Returns
        -------
        dict or None
            Maps each parameter name to the recommended sampler (name or instance).
        """

        if not isinstance(self.target, cuqi.distribution.JointDistribution):
            return None

        suggested_samplers = dict()
        for par_name, conditional in self._conditionals():
            # The constructor copies the exceptions list.
            recommender = SamplerRecommender(conditional, exceptions = self._exceptions)
            try:
                sampler = recommender.recommend(as_string = as_string)
            except ValueError:
                # recommend() raises when nothing can be suggested; the
                # contract here is to report that as None.
                return None
            suggested_samplers[par_name] = sampler

        return suggested_samplers
@@ -120,4 +120,3 @@ from ._gibbs import HybridGibbs
120
120
  from ._conjugate import Conjugate
121
121
  from ._conjugate_approx import ConjugateApprox
122
122
  from ._direct import Direct
123
- from ._utilities import find_valid_samplers
@@ -42,6 +42,10 @@ class HybridGibbs:
42
42
  fully stateful at this point. This means samplers like NUTS will lose
43
43
  their internal state between Gibbs steps.
44
44
 
45
+ The order in which the conditionals are sampled is the order of the
46
+ variables in the sampling strategy, unless a different sampling order
47
+ is specified by the parameter `scan_order`
48
+
45
49
  Parameters
46
50
  ----------
47
51
  target : cuqi.distribution.JointDistribution
@@ -58,6 +62,11 @@ class HybridGibbs:
58
62
  will call its step method in each Gibbs step.
59
63
  Default is 1 for all variables.
60
64
 
65
+ scan_order : list or str, *optional*
66
+ Order in which the conditional distributions are sampled.
67
+ If set to "random", use a random ordering at each step.
68
+ If not specified, it will be the order in the sampling_strategy.
69
+
61
70
  callback : callable, optional
62
71
  A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
63
72
  The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
@@ -107,7 +116,7 @@ class HybridGibbs:
107
116
 
108
117
  """
109
118
 
110
- def __init__(self, target: JointDistribution, sampling_strategy: Dict[str, Sampler], num_sampling_steps: Dict[str, int] = None, callback=None):
119
+ def __init__(self, target: JointDistribution, sampling_strategy: Dict[str, Sampler], num_sampling_steps: Dict[str, int] = None, scan_order = None, callback=None):
111
120
 
112
121
  # Store target and allow conditioning to reduce to a single density
113
122
  self.target = target() # Create a copy of target distribution (to avoid modifying the original)
@@ -121,6 +130,13 @@ class HybridGibbs:
121
130
  # Store parameter names
122
131
  self.par_names = self.target.get_parameter_names()
123
132
 
133
+ # Store the scan order
134
+ self._scan_order = scan_order
135
+
136
+ # Check that the parameters of the target align with the sampling_strategy and scan_order
137
+ if set(self.par_names) != set(self.scan_order):
138
+ raise ValueError("Parameter names in JointDistribution do not equal the names in the scan order.")
139
+
124
140
  # Initialize sampler (after target is set)
125
141
  self._initialize()
126
142
 
@@ -148,6 +164,16 @@ class HybridGibbs:
148
164
  # Validate all targets for samplers.
149
165
  self.validate_targets()
150
166
 
167
@property
def scan_order(self):
    """
    Order in which the conditional distributions are sampled in a Gibbs step.

    A new list is returned on every access, so callers may modify the
    result without affecting the sampler:

    - if no scan order was given, the keys of ``self.samplers`` (presumably
      in sampling-strategy order);
    - if the scan order is ``"random"``, those keys in a freshly shuffled
      order (drawn from NumPy's global random state on each access);
    - otherwise, a copy of the user-supplied scan order.
    """
    if self._scan_order is None:
        return list(self.samplers.keys())
    # The isinstance check avoids an elementwise (and therefore ambiguous)
    # comparison when the scan order was supplied as a NumPy array.
    if isinstance(self._scan_order, str) and self._scan_order == "random":
        order = list(self.samplers.keys())
        np.random.shuffle(order)  # shuffles in place
        return order
    # Copy so callers cannot mutate the stored scan order through the result.
    return list(self._scan_order)
176
+
151
177
  # ------------ Public methods ------------
152
178
  def validate_targets(self):
153
179
  """ Validate each of the conditional targets used in the Gibbs steps """
@@ -217,7 +243,7 @@ class HybridGibbs:
217
243
  """ Sequentially go through all parameters and sample them conditionally on each other """
218
244
 
219
245
  # Sample from each conditional distribution
220
- for par_name in self.par_names:
246
+ for par_name in self.scan_order:
221
247
 
222
248
  # Set target for current parameter
223
249
  self._set_target(par_name)
@@ -210,8 +210,8 @@ def test_regression_increasing():
210
210
  posterior = joint(y=y_obs)
211
211
 
212
212
  sampling_strategy = {
213
+ 'd': cuqi.experimental.mcmc.Conjugate(),
213
214
  'x': cuqi.experimental.mcmc.RegularizedLinearRTO(maxit=50, penalty_parameter=20, adaptive = False),
214
- 'd': cuqi.experimental.mcmc.Conjugate()
215
215
  }
216
216
  sampler = cuqi.experimental.mcmc.HybridGibbs(posterior, sampling_strategy)
217
217
 
@@ -241,8 +241,8 @@ def test_regression_convex():
241
241
  posterior = joint(y=y_obs)
242
242
 
243
243
  sampling_strategy = {
244
- 'x': cuqi.experimental.mcmc.RegularizedLinearRTO(maxit=50, penalty_parameter=20, adaptive = False),
245
- 'd': cuqi.experimental.mcmc.Conjugate()
244
+ 'd': cuqi.experimental.mcmc.Conjugate(),
245
+ 'x': cuqi.experimental.mcmc.RegularizedLinearRTO(maxit=50, penalty_parameter=20, adaptive = False)
246
246
  }
247
247
  sampler = cuqi.experimental.mcmc.HybridGibbs(posterior, sampling_strategy)
248
248
 
@@ -1,17 +0,0 @@
1
- import cuqi
2
- import inspect
3
-
4
- def find_valid_samplers(target):
5
- """ Finds all samplers in the cuqi.experimental.mcmc module that accept the provided target. """
6
-
7
- all_samplers = [(name, cls) for name, cls in inspect.getmembers(cuqi.experimental.mcmc, inspect.isclass) if issubclass(cls, cuqi.experimental.mcmc.Sampler)]
8
- valid_samplers = []
9
-
10
- for name, sampler in all_samplers:
11
- try:
12
- sampler(target)
13
- valid_samplers += [name]
14
- except:
15
- pass
16
-
17
- return valid_samplers