CUQIpy 1.2.0.post0.dev342__tar.gz → 1.2.0.post0.dev371__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of CUQIpy might be problematic; see the package's registry page for more details.

  Files changed (125)
  1. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/CUQIpy.egg-info/PKG-INFO +1 -1
  2. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/CUQIpy.egg-info/SOURCES.txt +2 -0
  3. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/PKG-INFO +1 -1
  4. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/_version.py +3 -3
  5. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_joint_distribution.py +15 -3
  6. cuqipy-1.2.0.post0.dev371/cuqi/experimental/algebra/__init__.py +2 -0
  7. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/experimental/algebra/_ast.py +14 -0
  8. cuqipy-1.2.0.post0.dev371/cuqi/experimental/algebra/_orderedset.py +59 -0
  9. cuqipy-1.2.0.post0.dev371/cuqi/experimental/algebra/_randomvariable.py +360 -0
  10. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/implicitprior/_restorator.py +18 -6
  11. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/model/_model.py +25 -0
  12. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/utilities/_get_python_variable_name.py +2 -2
  13. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_implicit_priors.py +27 -1
  14. cuqipy-1.2.0.post0.dev342/cuqi/experimental/algebra/__init__.py +0 -1
  15. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/CUQIpy.egg-info/dependency_links.txt +0 -0
  16. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/CUQIpy.egg-info/requires.txt +0 -0
  17. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/CUQIpy.egg-info/top_level.txt +0 -0
  18. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/LICENSE +0 -0
  19. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/README.md +0 -0
  20. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/__init__.py +0 -0
  21. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/_messages.py +0 -0
  22. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/array/__init__.py +0 -0
  23. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/array/_array.py +0 -0
  24. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/config.py +0 -0
  25. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/data/__init__.py +0 -0
  26. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/data/_data.py +0 -0
  27. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/data/astronaut.npz +0 -0
  28. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/data/camera.npz +0 -0
  29. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/data/cat.npz +0 -0
  30. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/data/cookie.png +0 -0
  31. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/data/satellite.mat +0 -0
  32. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/density/__init__.py +0 -0
  33. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/density/_density.py +0 -0
  34. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/diagnostics.py +0 -0
  35. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/__init__.py +0 -0
  36. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_beta.py +0 -0
  37. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_cauchy.py +0 -0
  38. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_cmrf.py +0 -0
  39. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_custom.py +0 -0
  40. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_distribution.py +0 -0
  41. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_gamma.py +0 -0
  42. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_gaussian.py +0 -0
  43. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_gmrf.py +0 -0
  44. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_inverse_gamma.py +0 -0
  45. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_laplace.py +0 -0
  46. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_lmrf.py +0 -0
  47. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_lognormal.py +0 -0
  48. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_modifiedhalfnormal.py +0 -0
  49. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_normal.py +0 -0
  50. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_posterior.py +0 -0
  51. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_smoothed_laplace.py +0 -0
  52. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_truncated_normal.py +0 -0
  53. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/distribution/_uniform.py +0 -0
  54. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/experimental/__init__.py +0 -0
  55. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/experimental/mcmc/__init__.py +0 -0
  56. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/experimental/mcmc/_conjugate.py +0 -0
  57. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/experimental/mcmc/_conjugate_approx.py +0 -0
  58. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/experimental/mcmc/_cwmh.py +0 -0
  59. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/experimental/mcmc/_direct.py +0 -0
  60. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/experimental/mcmc/_gibbs.py +0 -0
  61. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/experimental/mcmc/_hmc.py +0 -0
  62. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/experimental/mcmc/_langevin_algorithm.py +0 -0
  63. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/experimental/mcmc/_laplace_approximation.py +0 -0
  64. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/experimental/mcmc/_mh.py +0 -0
  65. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/experimental/mcmc/_pcn.py +0 -0
  66. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/experimental/mcmc/_rto.py +0 -0
  67. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/experimental/mcmc/_sampler.py +0 -0
  68. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/experimental/mcmc/_utilities.py +0 -0
  69. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/geometry/__init__.py +0 -0
  70. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/geometry/_geometry.py +0 -0
  71. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/implicitprior/__init__.py +0 -0
  72. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/implicitprior/_regularizedGMRF.py +0 -0
  73. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/implicitprior/_regularizedGaussian.py +0 -0
  74. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/implicitprior/_regularizedUnboundedUniform.py +0 -0
  75. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/likelihood/__init__.py +0 -0
  76. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/likelihood/_likelihood.py +0 -0
  77. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/model/__init__.py +0 -0
  78. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/operator/__init__.py +0 -0
  79. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/operator/_operator.py +0 -0
  80. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/pde/__init__.py +0 -0
  81. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/pde/_pde.py +0 -0
  82. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/problem/__init__.py +0 -0
  83. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/problem/_problem.py +0 -0
  84. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/sampler/__init__.py +0 -0
  85. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/sampler/_conjugate.py +0 -0
  86. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/sampler/_conjugate_approx.py +0 -0
  87. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/sampler/_cwmh.py +0 -0
  88. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/sampler/_gibbs.py +0 -0
  89. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/sampler/_hmc.py +0 -0
  90. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/sampler/_langevin_algorithm.py +0 -0
  91. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/sampler/_laplace_approximation.py +0 -0
  92. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/sampler/_mh.py +0 -0
  93. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/sampler/_pcn.py +0 -0
  94. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/sampler/_rto.py +0 -0
  95. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/sampler/_sampler.py +0 -0
  96. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/samples/__init__.py +0 -0
  97. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/samples/_samples.py +0 -0
  98. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/solver/__init__.py +0 -0
  99. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/solver/_solver.py +0 -0
  100. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/testproblem/__init__.py +0 -0
  101. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/testproblem/_testproblem.py +0 -0
  102. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/utilities/__init__.py +0 -0
  103. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/cuqi/utilities/_utilities.py +0 -0
  104. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/pyproject.toml +0 -0
  105. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/requirements.txt +0 -0
  106. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/setup.cfg +0 -0
  107. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/setup.py +0 -0
  108. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_MRFs.py +0 -0
  109. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_abstract_distribution_density.py +0 -0
  110. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_bayesian_inversion.py +0 -0
  111. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_density.py +0 -0
  112. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_distribution.py +0 -0
  113. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_distributions_shape.py +0 -0
  114. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_geometry.py +0 -0
  115. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_joint_distribution.py +0 -0
  116. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_likelihood.py +0 -0
  117. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_model.py +0 -0
  118. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_pde.py +0 -0
  119. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_posterior.py +0 -0
  120. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_problem.py +0 -0
  121. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_sampler.py +0 -0
  122. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_samples.py +0 -0
  123. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_solver.py +0 -0
  124. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_testproblem.py +0 -0
  125. {cuqipy-1.2.0.post0.dev342 → cuqipy-1.2.0.post0.dev371}/tests/test_utilities.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: CUQIpy
3
- Version: 1.2.0.post0.dev342
3
+ Version: 1.2.0.post0.dev371
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -47,6 +47,8 @@ cuqi/distribution/_uniform.py
47
47
  cuqi/experimental/__init__.py
48
48
  cuqi/experimental/algebra/__init__.py
49
49
  cuqi/experimental/algebra/_ast.py
50
+ cuqi/experimental/algebra/_orderedset.py
51
+ cuqi/experimental/algebra/_randomvariable.py
50
52
  cuqi/experimental/mcmc/__init__.py
51
53
  cuqi/experimental/mcmc/_conjugate.py
52
54
  cuqi/experimental/mcmc/_conjugate_approx.py
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: CUQIpy
3
- Version: 1.2.0.post0.dev342
3
+ Version: 1.2.0.post0.dev371
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2024-11-25T09:18:51+0100",
11
+ "date": "2024-12-10T23:09:59+0100",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "63395d14f6c3f964633b20200ada13b8a213da20",
15
- "version": "1.2.0.post0.dev342"
14
+ "full-revisionid": "6325eb7e648f3d6f801195a1d06a2d67928479f5",
15
+ "version": "1.2.0.post0.dev371"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -5,6 +5,7 @@ from cuqi.density import Density, EvaluatedDensity
5
5
  from cuqi.distribution import Distribution, Posterior
6
6
  from cuqi.likelihood import Likelihood
7
7
  from cuqi.geometry import Geometry, _DefaultGeometry1D
8
+ import cuqi
8
9
  import numpy as np # for splitting array. Can avoid.
9
10
 
10
11
  class JointDistribution:
@@ -13,9 +14,11 @@ class JointDistribution:
13
14
 
14
15
  Parameters
15
16
  ----------
16
- densities : Density
17
+ densities : RandomVariable or Density
17
18
  The densities to include in the joint distribution.
18
- Each density is passed as comma-separated arguments.
19
+ Each density is passed as comma-separated arguments,
20
+ and can be either a :class:'Density' such as :class:'Distribution'
21
+ or :class:`RandomVariable`.
19
22
 
20
23
  Notes
21
24
  -----
@@ -59,7 +62,16 @@ class JointDistribution:
59
62
  posterior = joint(y=y_obs)
60
63
 
61
64
  """
62
- def __init__(self, *densities: Density):
65
+ def __init__(self, *densities: [Density, cuqi.experimental.algebra.RandomVariable]):
66
+ """ Create a joint distribution from the given densities. """
67
+
68
+ # Check if all RandomVariables are simple (not-transformed)
69
+ for density in densities:
70
+ if isinstance(density, cuqi.experimental.algebra.RandomVariable) and density.is_transformed:
71
+ raise ValueError(f"To be used in {self.__class__.__name__}, all RandomVariables must be untransformed.")
72
+
73
+ # Convert potential random variables to their underlying distribution
74
+ densities = [density.distribution if isinstance(density, cuqi.experimental.algebra.RandomVariable) else density for density in densities]
63
75
 
64
76
  # Ensure all densities have unique names
65
77
  names = [density.name for density in densities]
@@ -0,0 +1,2 @@
1
+ from ._ast import VariableNode, Node
2
+ from ._randomvariable import RandomVariable
@@ -56,6 +56,20 @@ class Node(ABC):
56
56
  """String representation of the node. Used for printing the AST."""
57
57
  pass
58
58
 
59
def get_variables(self, variables=None):
    """Collect the names of all variables in the sub-tree rooted at this node.

    Parameters
    ----------
    variables : set, optional
        Accumulator that variable names are added to. A fresh set is created
        when omitted (avoids a shared mutable default).

    Returns
    -------
    set
        The accumulator, now holding the ``name`` of every ``VariableNode``
        reachable from this node.
    """
    found = set() if variables is None else variables
    if isinstance(self, VariableNode):
        found.add(self.name)
    # Recurse into whichever child slots this node type has, in the order
    # unary child -> left operand -> right operand.
    for slot in ("child", "left", "right"):
        if hasattr(self, slot):
            getattr(self, slot).get_variables(found)
    return found
72
+
59
73
  def __add__(self, other):
60
74
  return AddNode(self, convert_to_node(other))
61
75
 
@@ -0,0 +1,59 @@
1
class _OrderedSet:
    """A set (i.e. unique elements) that keeps its elements in the order they were added.

    This is a minimal implementation of an ordered set, using a dictionary for storage:
    dict keys are unique and preserve insertion order (Python 3.7+), so elements must be
    hashable.
    """

    def __init__(self, iterable=None):
        """Initialize the _OrderedSet.

        Parameters
        ----------
        iterable : iterable, optional
            Elements to add, in order. A repeated element keeps its first position.
        """
        # Compare against None explicitly: testing the truthiness of the input would
        # raise for array-like inputs (e.g. NumPy arrays with more than one element).
        self.dict = dict.fromkeys(iterable if iterable is not None else ())

    def add(self, item):
        """Add an item to the set.

        If the item is already in the set, nothing changes (it keeps its position).
        Otherwise, the item is stored as a key in the dictionary, with None as its value.
        """
        self.dict[item] = None

    def __contains__(self, item):
        """Check if an item is in the set (a key lookup in the backing dictionary)."""
        return item in self.dict

    def __iter__(self):
        """Return an iterator over the items, in insertion order."""
        return iter(self.dict)

    def __len__(self):
        """Return the number of items in the set."""
        return len(self.dict)

    def __repr__(self):
        """Return a debug representation listing the items in order."""
        return f"{type(self).__name__}({list(self.dict)!r})"

    def __copy__(self):
        """Return a shallow copy with its own storage.

        Without this hook, ``copy.copy`` would copy ``__dict__`` by reference, so the
        copy's ``.dict`` would be the very same dictionary as the original and adding to
        the copy would silently add to the original as well.
        """
        return _OrderedSet(self.dict)

    def extend(self, other):
        """Extend the set in place with the items of another _OrderedSet.

        Raises
        ------
        TypeError
            If `other` is not an _OrderedSet.
        """
        if not isinstance(other, _OrderedSet):
            raise TypeError("unsupported operand type(s) for extend: '_OrderedSet' and '{}'".format(type(other).__name__))
        for item in other:
            self.add(item)

    def __or__(self, other):
        """Return a new set that is the union of this set and another set.

        Items of this set come first, followed by new items of `other` in their order.

        Raises
        ------
        TypeError
            If `other` is not an _OrderedSet.
        """
        if not isinstance(other, _OrderedSet):
            raise TypeError("unsupported operand type(s) for |: '_OrderedSet' and '{}'".format(type(other).__name__))
        new_set = _OrderedSet(self.dict.keys())
        new_set.extend(other)
        return new_set
@@ -0,0 +1,360 @@
1
+ from __future__ import annotations
2
+ from typing import List, Any, Union
3
+ from ._ast import VariableNode, Node
4
+ from ._orderedset import _OrderedSet
5
+ import operator
6
+ import cuqi
7
+ from cuqi.distribution import Distribution
8
+ from copy import copy
9
+
10
+
11
class RandomVariable:
    """ Random variable defined by a distribution with the option to apply algebraic operations on it.

    Random variables allow for the definition of Bayesian Problems in a natural way. In the context
    of code, the random variable can be viewed as a lazily evaluated variable/array. It records
    operations applied to it and acts as a function that, when called, evaluates the operations
    and returns the result.

    In CUQIpy, random variables can be in two forms: (1) a 'primal' random variable that is directly
    defined by a distribution, e.g. x ~ N(0, 1), or (2) a 'transformed' random variable that is defined by
    applying algebraic operations on one or more random variables, e.g. y = x + 1.

    This distinction is purely for the purpose of the implementation in CUQIpy, as mathematically both
    x ~ N(0, 1) and y = x + 1 ~ N(1, 1) are random variables. The distinction is useful for the
    code implementation. In the future some operations like the above may allow primal random variables
    that are transformed if the distribution can be analytically described.

    Parameters
    ----------
    distributions : Distribution or list of Distributions
        The distribution from which the random variable originates. If multiple distributions are
        provided, the random variable is defined by the passed abstract syntax `tree` representing the
        algebraic operations applied to one or more random variables.

    tree : Node, optional
        The tree, represented by the syntax tree nodes, that contains the algebraic operations applied
        to the random variable. Specifically, the root of the tree should be provided.

    name : str, optional
        Name of the random variable. If not provided, the name is extracted from either the distribution provided
        or from the variable name in the code. The name provided must match the parameter name of the distribution.

    Raises
    ------
    ValueError
        If `distributions` is neither a Distribution nor a list of Distributions, if the variables in
        `tree` do not match the parameter names of the distributions, or if `name` conflicts with the
        parameter name of a single distribution.

    Examples
    --------

    Basic usage:

    .. code-block:: python

        from cuqi.distribution import Gaussian

        x = RandomVariable(Gaussian(0, 1))

    Defining Bayesian problem using random variables:

    .. code-block:: python

        from cuqi.testproblem import Deconvolution1D
        from cuqi.distribution import Gaussian, Gamma, GMRF
        from cuqi.experimental.algebra import RandomVariable
        from cuqi.problem import BayesianProblem

        import numpy as np
        A, y_obs, info = Deconvolution1D().get_components()

        # Bayesian problem
        d = RandomVariable(Gamma(1, 1e-4))
        s = RandomVariable(Gamma(1, 1e-4))
        x = RandomVariable(GMRF(np.zeros(A.domain_dim), d))
        y = RandomVariable(Gaussian(A @ x, 1/s))

        BP = BayesianProblem(y, x, s, d)
        BP.set_data(y=y_obs)
        BP.UQ()

    Defining random variable from multiple distributions:

    .. code-block:: python

        from cuqi.distribution import Gaussian, Gamma
        from cuqi.experimental.algebra import RandomVariable, VariableNode

        # Define the variables
        x = VariableNode('x')
        y = VariableNode('y')

        # Define the distributions (names must match variables)
        dist_x = Gaussian(0, 1, name='x')
        dist_y = Gamma(1, 1e-4, name='y')

        # Define the tree (this is the algebra that defines the random variable along with the distributions)
        tree = x + y

        # Define random variable from 2 distributions with relation x+y
        rv = RandomVariable([dist_x, dist_y], tree)

    """

    def __init__(self, distributions: Union['Distribution', List['Distribution']], tree: 'Node' = None, name: str = None):
        """ Create random variable from distribution(s) and an optional expression tree. """

        if isinstance(distributions, Distribution):
            distributions = [distributions]

        if not isinstance(distributions, (list, _OrderedSet)):
            raise ValueError("Expected a distribution or a list of distributions")

        # Store the distributions in an _OrderedSet: duplicates are dropped while insertion
        # order is preserved, which keeps the order of parameter_names deterministic.
        if not isinstance(distributions, _OrderedSet):
            distributions = _OrderedSet(distributions)

        # If a tree is provided, check it is consistent with the given distributions.
        # Compare against None explicitly so validation never depends on node truthiness.
        if tree is not None:
            tree_var_names = tree.get_variables()
            dist_par_names = {dist._name for dist in distributions}

            if len(tree_var_names) != len(distributions):
                raise ValueError(
                    f"There are {len(tree_var_names)} variables in the tree, but {len(distributions)} distributions are provided. "
                    "This may be due to passing multiple distributions with the same parameter name. "
                    f"The tree variables are {tree_var_names} and the distribution parameter names are {dist_par_names}."
                )

            if not all(var_name in dist_par_names for var_name in tree_var_names):
                raise ValueError(
                    f"Variable names in the tree {tree_var_names} do not match the parameter names in the distributions {dist_par_names}. "
                    "Ensure the name is inferred from the variable or explicitly provide it using name='var_name' in the distribution."
                )

        # For a single distribution without a tree, the random variable name must agree with the
        # distribution's parameter name; an already-set parameter name takes precedence.
        if len(distributions) == 1 and tree is None:
            dist_par_name = next(iter(distributions))._name
            if dist_par_name is not None:
                if name is not None and dist_par_name != name:
                    raise ValueError(f"Parameter name '{dist_par_name}' of the distribution does not match the input name '{name}' for the random variable.")
                name = dist_par_name

        # The distribution(s) from which the random variable originates.
        self._distributions = distributions

        # Root of the expression tree; built lazily by the `tree` property when None.
        self._tree = tree

        # The original random variable if this object is a (conditioned) copy, else None.
        self._original_variable = None

        # Name of the random variable; inferred lazily by the `name` property when None.
        self._name = name

    def __call__(self, *args, **kwargs) -> Any:
        """ Evaluate random variable at a given parameter value.

        For example, for random variable `X`, `X(1)` gives `1` and `(X+1)(1)` gives `2`.
        Values may be passed either positionally (in `parameter_names` order) or as keywords,
        but not both.
        """
        if args and kwargs:
            raise ValueError("Cannot pass both positional and keyword arguments to RandomVariable")

        if args:
            kwargs = self._parse_args_add_to_kwargs(args, kwargs)

        # The supplied names must be exactly the parameter names: none missing, none extra.
        expected = self.parameter_names
        if set(kwargs) != set(expected):
            raise ValueError(f"Expected arguments {expected}, got arguments {kwargs}")

        return self.tree(**kwargs)

    @property
    def tree(self):
        """ Root node of the expression tree (a single VariableNode for primal random variables). """
        if self._tree is None:
            if len(self._distributions) > 1:
                raise ValueError("Tree for multiple distributions can not be created automatically and need to be passed as an argument to the {} initializer.".format(type(self).__name__))
            self._tree = VariableNode(self.name)
        return self._tree

    @property
    def name(self):
        """ Name of the random variable. If not provided, the name is extracted from the variable name in the code. """
        if self._is_copy:  # Copies always report the name of the original variable
            return self._original_variable.name
        if self._name is None:  # Infer the name from the calling code
            self._name = cuqi.utilities._get_python_variable_name(self)
            if self._name is not None:
                self._inject_name_into_distribution(self._name)
        return self._name

    @name.setter
    def name(self, name):
        if self._is_copy:
            raise ValueError("This random variable is derived from the conditional random variable named "+self._original_variable.name+". The name of the derived random variable cannot be set, but follows the name of the original random variable.")
        self._name = name

    @property
    def distribution(self) -> Distribution:
        """ Distribution from which the random variable originates (single-distribution case only). """
        if len(self._distributions) > 1:
            raise ValueError("Cannot get distribution from random variable defined by multiple distributions")
        self._inject_name_into_distribution()
        return next(iter(self._distributions))

    @property
    def distributions(self) -> _OrderedSet:
        """ Distributions from which the random variable originates, in insertion order. """
        self._inject_name_into_distribution()
        return self._distributions

    @property
    def parameter_names(self) -> List[str]:
        """ Names of the parameters the random variable can be evaluated at (one per distribution, in order). """
        self._inject_name_into_distribution()
        return [distribution.name for distribution in self.distributions]  # Consider renaming .name to .par_name for distributions

    @property
    def dim(self):
        """ Dimension of the underlying distribution (primal random variables only). """
        if self.is_transformed:
            raise NotImplementedError("Dimension not implemented for transformed random variables")
        return self.distribution.dim

    @property
    def geometry(self):
        """ Geometry of the underlying distribution (primal random variables only). """
        if self.is_transformed:
            raise NotImplementedError("Geometry not implemented for transformed random variables")
        return self.distribution.geometry

    @geometry.setter
    def geometry(self, geometry):
        if self.is_transformed:
            raise NotImplementedError("Geometry not implemented for transformed random variables")
        self.distribution.geometry = geometry

    @property
    def expression(self):
        """ Expression (formula) of the random variable. """
        return str(self.tree)

    @property
    def is_transformed(self):
        """ Returns True if the random variable is transformed (its tree is not a single VariableNode). """
        return not isinstance(self.tree, VariableNode)

    @property
    def _non_default_args(self) -> List[str]:
        """List of non-default arguments to distribution. This is used to return the correct
        arguments when evaluating the random variable.
        """
        return self.parameter_names

    def _inject_name_into_distribution(self, name=None):
        """ Give a single, unnamed distribution this random variable's name (or `name` if given). """
        if len(self._distributions) == 1:
            dist = next(iter(self._distributions))
            if dist._name is None:
                if name is None:
                    name = self.name
                dist._name = name

    def _parse_args_add_to_kwargs(self, args, kwargs) -> dict:
        """ Map positional `args` onto `parameter_names` (in order) and add them to `kwargs`. """
        names = self.parameter_names
        if len(args) != len(names):
            raise ValueError(f"Expected {len(names)} arguments, got {len(args)}. Parameters are: {names}")
        kwargs.update(zip(names, args))
        return kwargs

    def __repr__(self):
        # One "name ~ distribution" line per component distribution.
        pairs = "\n".join(
            f"{name} ~ {distribution}"
            for name, distribution in zip(self.parameter_names, self.distributions)
        )
        if not self.is_transformed:
            # Primal random variable: print the pair(s) without any tab characters.
            return pairs.replace("\t", "")
        indented = "\n".join("\t" + line for line in pairs.split("\n"))
        return (
            "Transformed Random Variable"
            "\n"
            f"Expression: {self.tree}\n"
            f"Components: \n{indented}"
        )

    @property
    def _is_copy(self):
        """ Returns True if this is a copy of another random variable, e.g. by conditioning. """
        return hasattr(self, '_original_variable') and self._original_variable is not None

    def _make_copy(self):
        """ Returns a shallow copy of the random variable keeping a pointer to the original.

        The distribution container is rebuilt (not shared) so that changes to the copy's
        set of distributions cannot leak back into the original.
        """
        new_variable = copy(self)
        new_variable._distributions = _OrderedSet(self.distributions)
        new_variable._tree = copy(self._tree)
        new_variable._original_variable = self
        return new_variable

    def _apply_operation(self, operation, other=None) -> 'RandomVariable':
        """ Return a new RandomVariable whose tree is `operation` applied to this one.

        Unary if `other` is None; binary with another RandomVariable (distributions are
        merged in order) or with any other object, which is treated as a constant.
        """
        if isinstance(other, cuqi.distribution.Distribution):
            raise ValueError("Cannot apply operation to distribution. Use .rv to create random variable first.")
        if other is None:  # unary operation
            return RandomVariable(self.distributions, operation(self.tree))
        if isinstance(other, RandomVariable):  # binary operation with another random variable
            return RandomVariable(self.distributions | other.distributions, operation(self.tree, other.tree))
        return RandomVariable(self.distributions, operation(self.tree, other))  # binary operation with a constant

    @staticmethod
    def _check_matmul_operand(other):
        """ Raise TypeError if `other` is a cuqi Model that is not linear (only linear models support @). """
        if isinstance(other, cuqi.model.Model) and not isinstance(other, cuqi.model.LinearModel):
            raise TypeError("Cannot apply matmul to non-linear models")

    def __add__(self, other) -> 'RandomVariable':
        return self._apply_operation(operator.add, other)

    def __radd__(self, other) -> 'RandomVariable':
        return self.__add__(other)

    def __sub__(self, other) -> 'RandomVariable':
        return self._apply_operation(operator.sub, other)

    def __rsub__(self, other) -> 'RandomVariable':
        return self._apply_operation(lambda x, y: operator.sub(y, x), other)

    def __mul__(self, other) -> 'RandomVariable':
        return self._apply_operation(operator.mul, other)

    def __rmul__(self, other) -> 'RandomVariable':
        return self.__mul__(other)

    def __truediv__(self, other) -> 'RandomVariable':
        return self._apply_operation(operator.truediv, other)

    def __rtruediv__(self, other) -> 'RandomVariable':
        return self._apply_operation(lambda x, y: operator.truediv(y, x), other)

    def __matmul__(self, other) -> 'RandomVariable':
        self._check_matmul_operand(other)
        return self._apply_operation(operator.matmul, other)

    def __rmatmul__(self, other) -> 'RandomVariable':
        self._check_matmul_operand(other)
        return self._apply_operation(lambda x, y: operator.matmul(y, x), other)

    def __neg__(self) -> 'RandomVariable':
        return self._apply_operation(operator.neg)

    def __abs__(self) -> 'RandomVariable':
        return self._apply_operation(abs)

    def __pow__(self, other) -> 'RandomVariable':
        return self._apply_operation(operator.pow, other)

    def __getitem__(self, other) -> 'RandomVariable':
        return self._apply_operation(operator.getitem, other)
@@ -15,8 +15,10 @@ class RestorationPrior(Distribution):
15
15
  Parameters
16
16
  ----------
17
17
  restorator : callable f(x, restoration_strength)
18
- Function f that accepts input x to be restored and returns the
19
- restored version of x and information about the restoration operation.
18
+ Function f that accepts input x to be restored and returns a two-element
19
+ tuple of the restored version of x and extra information about the
20
+ restoration operation. The second element can be of any type, including
21
+ `None` in case there is no information.
20
22
 
21
23
  restorator_kwargs : dictionary
22
24
  Dictionary containing information about the restorator.
@@ -41,8 +43,9 @@ class RestorationPrior(Distribution):
41
43
  super().__init__(**kwargs)
42
44
 
43
45
  def restore(self, x, restoration_strength):
44
- """This function allows us to restore the input x and returns the
45
- restored version of x.
46
+ """This function allows us to restore the input x with the user-supplied
47
+ restorator. Extra information about the restoration operation is stored
48
+ in the `RestorationPrior` info attribute.
46
49
 
47
50
  Parameters
48
51
  ----------
@@ -54,9 +57,18 @@ class RestorationPrior(Distribution):
54
57
  restorator is a denoiser, this parameter might correspond to the
55
58
  noise level.
56
59
  """
57
- solution, info = self.restorator(x, restoration_strength=restoration_strength,
60
+ restorator_return = self.restorator(x, restoration_strength=restoration_strength,
58
61
  **self.restorator_kwargs)
59
- self.info = info
62
+
63
+ if type(restorator_return) == tuple and len(restorator_return) == 2:
64
+ solution, self.info = restorator_return
65
+ else:
66
+ raise ValueError("Unsupported return type from the user-supplied restorator function. "+
67
+ "Please ensure that the restorator function returns a two-element tuple with the "+
68
+ "restored solution as the first element and additional information about the "+
69
+ "restoration as the second element. The second element can be of any type, "+
70
+ "including `None` in case there is no particular information.")
71
+
60
72
  return solution
61
73
 
62
74
  def logpdf(self, x):
@@ -351,6 +351,17 @@ class Model(object):
351
351
  new_model._non_default_args = [x.name] # Defaults to x if distribution had no name
352
352
  return new_model
353
353
 
354
+ # If input is a random variable, we handle it separately
355
+ if isinstance(x, cuqi.experimental.algebra.RandomVariable):
356
+ return self._handle_random_variable(x)
357
+
358
+ # If input is a Node from internal abstract syntax tree, we let the Node handle the operation
359
+ # We use NotImplemented to indicate that the operation is not supported from the Model class
360
+ # in case of operations such as "@" that can be interpreted as both __matmul__ and __rmatmul__
361
+ # the operation may be delegated to the Node class.
362
+ if isinstance(x, cuqi.experimental.algebra.Node):
363
+ return NotImplemented
364
+
354
365
  # Else we apply the forward operator
355
366
  return self._apply_func(self._forward_func,
356
367
  self.range_geometry,
@@ -463,6 +474,20 @@ class Model(object):
463
474
  not type(self.domain_geometry) in _get_identity_geometries():
464
475
  raise NotImplementedError("Gradient not implemented for model {} with domain geometry {}".format(self,self.domain_geometry))
465
476
 
477
+ def _handle_random_variable(self, x):
478
+ """ Private function that handles the case of the input being a random variable. """
479
+ # If random variable is not a leaf-type node (e.g. internal node) we return NotImplemented
480
+ if not isinstance(x.tree, cuqi.experimental.algebra.VariableNode):
481
+ return NotImplemented
482
+
483
+ # In leaf-type node case we simply change the parameter name of model to match the random variable name
484
+ dist = x.distribution
485
+ if dist.dim != self.domain_dim:
486
+ raise ValueError("Attempting to match parameter name of Model with given random variable, but random variable dimension does not match model domain dimension.")
487
+
488
+ new_model = copy(self)
489
+ new_model._non_default_args = [dist.name]
490
+ return new_model
466
491
 
467
492
  def __len__(self):
468
493
  return self.range_dim
@@ -9,7 +9,7 @@ import cuqi
9
9
  def _get_python_variable_name(var):
10
10
  """ Retrieve the Python variable name of an object. Takes the first variable name appearing on the stack that is not in the ignore list. """
11
11
 
12
- ignored_var_names = ["self", "cls", "obj", "var", "_"]
12
+ ignored_var_names = ["self", "cls", "obj", "var", "_", "result", "args", "kwargs", "par_name", "name", "distribution", "dist"]
13
13
 
14
14
  # First get the stack size and loop (in reverse) through the stack
15
15
  # It can be a bit slow to loop through stack size so we limit the levels
@@ -29,7 +29,7 @@ def _get_python_variable_name(var):
29
29
  if len(var_names) > 0:
30
30
  return var_names[0]
31
31
 
32
- warnings.warn("Could not automatically find variable name for object: {}. Use keyword `name` when defining distribution to specify a name. If code runs slowly and variable name is not needed set config.MAX_STACK_SEARCH_DEPTH to 0.".format(var))
32
+ warnings.warn("Could not automatically find variable name for object. Did you assign (=) the object to a python variable? Alternatively, use keyword `name` when defining distribution to specify a name. If code runs slowly and variable name is not needed set config.MAX_STACK_SEARCH_DEPTH to 0. These names are reserved {} and should not be used as object name.".format(ignored_var_names))
33
33
 
34
34
  return None
35
35
 
@@ -28,7 +28,7 @@ def test_RegularizedGaussian_guarding_statements():
28
28
 
29
29
  with pytest.raises(ValueError, match="Projector should take 1 argument"):
30
30
  cuqi.implicitprior.RegularizedGaussian(np.zeros(5), 1, projector=lambda s,z: s)
31
-
31
+
32
32
  def test_creating_restorator():
33
33
  """ Test creating the object from restorator class."""
34
34
 
@@ -38,6 +38,32 @@ def test_creating_restorator():
38
38
  assert np.allclose(restorator.restore(np.ones(4), 0.1), np.ones(4))
39
39
  assert restorator.info == True
40
40
 
41
+ def test_handling_invalid_restorator():
42
+ """ Test handling invalid restorator."""
43
+ # Invalid return type 1: None
44
+ def func_1(x, restoration_strength=0.1):
45
+ return
46
+ restore_prior_1 = cuqi.implicitprior.RestorationPrior(func_1)
47
+ with pytest.raises(ValueError, match=r"Unsupported return type .*"):
48
+ restore_prior_1.restore(np.ones(4), 0.1)
49
+ # Invalid return type 2: one parameter
50
+ def func_2(x, restoration_strength=0.1):
51
+ return x
52
+ restore_prior_2 = cuqi.implicitprior.RestorationPrior(func_2)
53
+ with pytest.raises(ValueError, match=r"Unsupported return type .*"):
54
+ restore_prior_2.restore(np.ones(4), 0.1)
55
+ # Invalid return type 3: tuple with 3 elements
56
+ def func_3(x, restoration_strength=0.1):
57
+ return x, None, False
58
+ restore_prior_3 = cuqi.implicitprior.RestorationPrior(func_3)
59
+ with pytest.raises(ValueError, match=r"Unsupported return type .*"):
60
+ restore_prior_3.restore(np.ones(4), 0.1)
61
+ # Invalid return type 4: list with 2 elements
62
+ def func_4(x, restoration_strength=0.1):
63
+ return [x, None]
64
+ restore_prior_4 = cuqi.implicitprior.RestorationPrior(func_4)
65
+ with pytest.raises(ValueError, match=r"Unsupported return type .*"):
66
+ restore_prior_4.restore(np.ones(4), 0.1)
41
67
 
42
68
  def test_creating_restorator_with_potential():
43
69
  """ Test creating the object from restorator class with a potential."""
@@ -1 +0,0 @@
1
- from ._ast import VariableNode