CUQIpy 1.2.0.post0.dev333__tar.gz → 1.2.0.post0.dev352__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of CUQIpy might be problematic. Click here for more details.

Files changed (122) hide show
  1. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/CUQIpy.egg-info/PKG-INFO +1 -1
  2. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/CUQIpy.egg-info/SOURCES.txt +2 -0
  3. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/PKG-INFO +1 -1
  4. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/_version.py +3 -3
  5. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/experimental/__init__.py +1 -0
  6. cuqipy-1.2.0.post0.dev352/cuqi/experimental/algebra/__init__.py +1 -0
  7. cuqipy-1.2.0.post0.dev352/cuqi/experimental/algebra/_ast.py +325 -0
  8. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/implicitprior/_restorator.py +18 -6
  9. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_implicit_priors.py +27 -1
  10. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/CUQIpy.egg-info/dependency_links.txt +0 -0
  11. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/CUQIpy.egg-info/requires.txt +0 -0
  12. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/CUQIpy.egg-info/top_level.txt +0 -0
  13. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/LICENSE +0 -0
  14. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/README.md +0 -0
  15. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/__init__.py +0 -0
  16. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/_messages.py +0 -0
  17. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/array/__init__.py +0 -0
  18. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/array/_array.py +0 -0
  19. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/config.py +0 -0
  20. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/data/__init__.py +0 -0
  21. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/data/_data.py +0 -0
  22. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/data/astronaut.npz +0 -0
  23. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/data/camera.npz +0 -0
  24. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/data/cat.npz +0 -0
  25. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/data/cookie.png +0 -0
  26. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/data/satellite.mat +0 -0
  27. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/density/__init__.py +0 -0
  28. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/density/_density.py +0 -0
  29. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/diagnostics.py +0 -0
  30. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/__init__.py +0 -0
  31. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_beta.py +0 -0
  32. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_cauchy.py +0 -0
  33. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_cmrf.py +0 -0
  34. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_custom.py +0 -0
  35. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_distribution.py +0 -0
  36. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_gamma.py +0 -0
  37. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_gaussian.py +0 -0
  38. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_gmrf.py +0 -0
  39. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_inverse_gamma.py +0 -0
  40. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_joint_distribution.py +0 -0
  41. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_laplace.py +0 -0
  42. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_lmrf.py +0 -0
  43. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_lognormal.py +0 -0
  44. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_modifiedhalfnormal.py +0 -0
  45. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_normal.py +0 -0
  46. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_posterior.py +0 -0
  47. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_smoothed_laplace.py +0 -0
  48. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_truncated_normal.py +0 -0
  49. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/distribution/_uniform.py +0 -0
  50. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/experimental/mcmc/__init__.py +0 -0
  51. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/experimental/mcmc/_conjugate.py +0 -0
  52. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/experimental/mcmc/_conjugate_approx.py +0 -0
  53. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/experimental/mcmc/_cwmh.py +0 -0
  54. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/experimental/mcmc/_direct.py +0 -0
  55. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/experimental/mcmc/_gibbs.py +0 -0
  56. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/experimental/mcmc/_hmc.py +0 -0
  57. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/experimental/mcmc/_langevin_algorithm.py +0 -0
  58. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/experimental/mcmc/_laplace_approximation.py +0 -0
  59. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/experimental/mcmc/_mh.py +0 -0
  60. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/experimental/mcmc/_pcn.py +0 -0
  61. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/experimental/mcmc/_rto.py +0 -0
  62. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/experimental/mcmc/_sampler.py +0 -0
  63. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/experimental/mcmc/_utilities.py +0 -0
  64. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/geometry/__init__.py +0 -0
  65. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/geometry/_geometry.py +0 -0
  66. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/implicitprior/__init__.py +0 -0
  67. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/implicitprior/_regularizedGMRF.py +0 -0
  68. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/implicitprior/_regularizedGaussian.py +0 -0
  69. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/implicitprior/_regularizedUnboundedUniform.py +0 -0
  70. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/likelihood/__init__.py +0 -0
  71. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/likelihood/_likelihood.py +0 -0
  72. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/model/__init__.py +0 -0
  73. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/model/_model.py +0 -0
  74. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/operator/__init__.py +0 -0
  75. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/operator/_operator.py +0 -0
  76. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/pde/__init__.py +0 -0
  77. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/pde/_pde.py +0 -0
  78. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/problem/__init__.py +0 -0
  79. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/problem/_problem.py +0 -0
  80. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/sampler/__init__.py +0 -0
  81. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/sampler/_conjugate.py +0 -0
  82. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/sampler/_conjugate_approx.py +0 -0
  83. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/sampler/_cwmh.py +0 -0
  84. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/sampler/_gibbs.py +0 -0
  85. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/sampler/_hmc.py +0 -0
  86. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/sampler/_langevin_algorithm.py +0 -0
  87. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/sampler/_laplace_approximation.py +0 -0
  88. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/sampler/_mh.py +0 -0
  89. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/sampler/_pcn.py +0 -0
  90. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/sampler/_rto.py +0 -0
  91. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/sampler/_sampler.py +0 -0
  92. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/samples/__init__.py +0 -0
  93. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/samples/_samples.py +0 -0
  94. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/solver/__init__.py +0 -0
  95. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/solver/_solver.py +0 -0
  96. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/testproblem/__init__.py +0 -0
  97. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/testproblem/_testproblem.py +0 -0
  98. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/utilities/__init__.py +0 -0
  99. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/utilities/_get_python_variable_name.py +0 -0
  100. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/cuqi/utilities/_utilities.py +0 -0
  101. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/pyproject.toml +0 -0
  102. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/requirements.txt +0 -0
  103. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/setup.cfg +0 -0
  104. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/setup.py +0 -0
  105. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_MRFs.py +0 -0
  106. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_abstract_distribution_density.py +0 -0
  107. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_bayesian_inversion.py +0 -0
  108. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_density.py +0 -0
  109. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_distribution.py +0 -0
  110. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_distributions_shape.py +0 -0
  111. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_geometry.py +0 -0
  112. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_joint_distribution.py +0 -0
  113. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_likelihood.py +0 -0
  114. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_model.py +0 -0
  115. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_pde.py +0 -0
  116. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_posterior.py +0 -0
  117. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_problem.py +0 -0
  118. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_sampler.py +0 -0
  119. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_samples.py +0 -0
  120. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_solver.py +0 -0
  121. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_testproblem.py +0 -0
  122. {cuqipy-1.2.0.post0.dev333 → cuqipy-1.2.0.post0.dev352}/tests/test_utilities.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: CUQIpy
3
- Version: 1.2.0.post0.dev333
3
+ Version: 1.2.0.post0.dev352
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -45,6 +45,8 @@ cuqi/distribution/_smoothed_laplace.py
45
45
  cuqi/distribution/_truncated_normal.py
46
46
  cuqi/distribution/_uniform.py
47
47
  cuqi/experimental/__init__.py
48
+ cuqi/experimental/algebra/__init__.py
49
+ cuqi/experimental/algebra/_ast.py
48
50
  cuqi/experimental/mcmc/__init__.py
49
51
  cuqi/experimental/mcmc/_conjugate.py
50
52
  cuqi/experimental/mcmc/_conjugate_approx.py
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: CUQIpy
3
- Version: 1.2.0.post0.dev333
3
+ Version: 1.2.0.post0.dev352
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2024-11-20T12:41:18+0100",
11
+ "date": "2024-11-28T12:43:32+0100",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "e495208f39ca278465cb4819799b22c942a1b863",
15
- "version": "1.2.0.post0.dev333"
14
+ "full-revisionid": "6f800787a63fef7d6ba2c0b88b96dbc209afdbad",
15
+ "version": "1.2.0.post0.dev352"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -1,2 +1,3 @@
1
1
  """ Experimental module for testing new features and ideas. """
2
2
  from . import mcmc
3
+ from . import algebra
@@ -0,0 +1 @@
1
+ from ._ast import VariableNode
@@ -0,0 +1,325 @@
1
+ """
2
+ CUQIpy specific implementation of an abstract syntax tree (AST) for algebra on variables.
3
+
4
+ The AST is used to record the operations applied to variables allowing a delayed evaluation
5
+ of said operations when needed by traversing the tree with the __call__ method.
6
+
7
+ For example, the following code
8
+
9
+ x = VariableNode('x')
10
+ y = VariableNode('y')
11
+ z = 2*x + 3*y
12
+
13
+ will create the following AST:
14
+
15
+ z = AddNode(
16
+ MultiplyNode(
17
+ ValueNode(2),
18
+ VariableNode('x')
19
+ ),
20
+ MultiplyNode(
21
+ ValueNode(3),
22
+ VariableNode('y')
23
+ )
24
+ )
25
+
26
+ which can be evaluated by calling the __call__ method:
27
+
28
+ z(x=1, y=2) # returns 8
29
+
30
+ """
31
+
32
+ from abc import ABC, abstractmethod
33
+
34
def convert_to_node(x):
    """Wrap ``x`` in a ValueNode unless it already is a Node.

    Parameters
    ----------
    x : object
        Either an existing Node, which is returned unchanged, or any other
        python object, which becomes a constant leaf of the tree.

    Returns
    -------
    Node
        ``x`` itself if it is a Node, otherwise ``ValueNode(x)``.
    """
    return x if isinstance(x, Node) else ValueNode(x)
36
+
37
+ # ====== Base classes for the nodes ======
38
+
39
+
40
class Node(ABC):
    """Base class for all nodes in the abstract syntax tree.

    Responsible for building the AST. Every arithmetic operator applied to a
    node returns a new node that records the operation instead of computing
    it. Non-Node operands are wrapped with ``convert_to_node``.

    Each subclass must implement the __call__ method that will evaluate the
    node given the input parameters (as keyword arguments), and __repr__.

    """

    # Opt out of NumPy's ufunc dispatch. Without this, ``A @ x`` with a NumPy
    # array ``A`` is handled by ``ndarray.__matmul__``, which treats the node
    # as a 0-d object array and typically raises. Likewise ``arr * x``
    # broadcasts elementwise into an object array of nodes. With
    # ``__array_ufunc__ = None``, NumPy's binary operators return
    # NotImplemented, so Python falls back to the reflected methods below
    # (``__rmatmul__``, ``__rmul__``, ...) and a single node is built.
    __array_ufunc__ = None

    @abstractmethod
    def __call__(self, **kwargs):
        """Evaluate node at a given parameter value. This will traverse the sub-tree originated at this node and evaluate it given the recorded operations."""
        pass

    @abstractmethod
    def __repr__(self):
        """String representation of the node. Used for printing the AST."""
        pass

    # --- Arithmetic operators: each records the operation as a new node ---

    def __add__(self, other):
        return AddNode(self, convert_to_node(other))

    def __radd__(self, other):
        return AddNode(convert_to_node(other), self)

    def __sub__(self, other):
        return SubtractNode(self, convert_to_node(other))

    def __rsub__(self, other):
        return SubtractNode(convert_to_node(other), self)

    def __mul__(self, other):
        return MultiplyNode(self, convert_to_node(other))

    def __rmul__(self, other):
        return MultiplyNode(convert_to_node(other), self)

    def __truediv__(self, other):
        return DivideNode(self, convert_to_node(other))

    def __rtruediv__(self, other):
        return DivideNode(convert_to_node(other), self)

    def __pow__(self, other):
        return PowerNode(self, convert_to_node(other))

    def __rpow__(self, other):
        return PowerNode(convert_to_node(other), self)

    def __neg__(self):
        return NegateNode(self)

    def __abs__(self):
        return AbsNode(self)

    def __getitem__(self, i):
        return GetItemNode(self, convert_to_node(i))

    def __matmul__(self, other):
        return MatMulNode(self, convert_to_node(other))

    def __rmatmul__(self, other):
        return MatMulNode(convert_to_node(other), self)
103
+
104
+
105
class UnaryNode(Node, ABC):
    """Abstract node that applies an operation to a single operand.

    Parameters
    ----------
    child : Node
        The operand sub-tree the unary operation acts on.

    """

    def __init__(self, child: Node):
        # Only the operand is stored; evaluation and printing are left to
        # the concrete subclasses.
        self.child = child
117
+
118
+
119
class BinaryNode(Node, ABC):
    """Abstract node that combines two operands with an infix operator.

    Subclasses provide ``op_symbol``, the operator text that ``__repr__``
    places between the two operands.

    Parameters
    ----------
    left : Node
        Operand on the left-hand side of the operator.

    right : Node
        Operand on the right-hand side of the operator.

    """

    def __init__(self, left: Node, right: Node):
        self.left, self.right = left, right

    @property
    @abstractmethod
    def op_symbol(self):
        """Operator text printed between the two operands by ``__repr__``."""

    def __repr__(self):
        return "{} {} {}".format(self.left, self.op_symbol, self.right)
146
+
147
+
148
class BinaryNodeWithParenthesis(BinaryNode, ABC):
    """Binary node whose printed form puts parentheses around any operand
    that is itself a binary node, with spaces around the operator."""

    def __repr__(self):
        def wrap(operand):
            text = str(operand)
            return f"({text})" if isinstance(operand, BinaryNode) else text

        return f"{wrap(self.left)} {self.op_symbol} {wrap(self.right)}"
157
+
158
class BinaryNodeWithParenthesisNoSpace(BinaryNode, ABC):
    """Base class for binary nodes printed with parenthesis but no space
    around the operator (e.g. ``x^2``).

    An operand is wrapped in parentheses when it is a binary node or a
    negation. Negations need grouping here because ``-x^2`` is
    conventionally read as ``-(x^2)``, while the tree being printed may be
    ``(-x)^2``.

    """

    def __repr__(self):
        def _fmt(operand):
            # NegateNode is defined later in this module. It is only looked
            # up when __repr__ runs, so the forward reference is safe.
            if isinstance(operand, (BinaryNode, NegateNode)):
                return f"({operand})"
            return str(operand)

        return f"{_fmt(self.left)}{self.op_symbol}{_fmt(self.right)}"
167
+
168
+
169
+ # ====== Specific implementations of the "leaf" nodes ======
170
+
171
+
172
class VariableNode(Node):
    """Node that represents a generic variable, e.g. "x" or "y".

    Parameters
    ----------
    name : str
        Name of the variable. Used for printing and as the keyword under
        which the variable's value is looked up when evaluating the tree.

    """

    def __init__(self, name):
        self.name = name

    def __call__(self, **kwargs):
        """Return the value passed for this variable.

        Raises
        ------
        KeyError
            If no keyword argument named ``self.name`` was given.
        """
        try:
            return kwargs[self.name]
        except KeyError:
            # Re-raise with a descriptive message; the bare lookup error
            # adds nothing, so its context is suppressed.
            raise KeyError(
                f"Variable '{self.name}' not found in the given input parameters. Unable to evaluate the expression."
            ) from None

    def __repr__(self):
        return self.name
196
+
197
+
198
class ValueNode(Node):
    """Leaf node that holds a constant. The constant may be any python
    object that is not itself a Node.

    Parameters
    ----------
    value : object
        The constant this leaf stands for.

    """

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return str(self.value)

    def __call__(self, **_ignored):
        """Return the stored constant; the input parameters are not used."""
        return self.value
217
+
218
+
219
+ # ====== Specific implementations of the "internal" nodes ======
220
+
221
+
222
class AddNode(BinaryNode):
    """Node recording ``left + right``; evaluating it adds the two
    evaluated operands with Python's ``+`` operator."""

    def __call__(self, **kwargs):
        lhs = self.left(**kwargs)
        rhs = self.right(**kwargs)
        return lhs + rhs

    @property
    def op_symbol(self):
        return "+"
231
+
232
+
233
class SubtractNode(BinaryNode):
    """Node that represents the subtraction operation ``left - right``."""

    @property
    def op_symbol(self):
        return "-"

    def __call__(self, **kwargs):
        return self.left(**kwargs) - self.right(**kwargs)

    def __repr__(self):
        # Subtraction is not associative: ``x - (y + z)`` must not print as
        # ``x - y + z``. A right operand that is itself an addition or
        # subtraction is parenthesized. The left operand needs no grouping
        # because ``-`` binds left-to-right.
        right = (
            f"({self.right})"
            if isinstance(self.right, (AddNode, SubtractNode))
            else str(self.right)
        )
        return f"{self.left} {self.op_symbol} {right}"
242
+
243
+
244
class MultiplyNode(BinaryNodeWithParenthesis):
    """Node recording ``left * right``; evaluating it multiplies the two
    evaluated operands with Python's ``*`` operator."""

    def __call__(self, **kwargs):
        factor_a = self.left(**kwargs)
        factor_b = self.right(**kwargs)
        return factor_a * factor_b

    @property
    def op_symbol(self):
        return "*"
253
+
254
+
255
class DivideNode(BinaryNodeWithParenthesis):
    """Node recording ``left / right``; evaluating it applies Python's true
    division to the two evaluated operands."""

    def __call__(self, **kwargs):
        numerator = self.left(**kwargs)
        denominator = self.right(**kwargs)
        return numerator / denominator

    @property
    def op_symbol(self):
        return "/"
264
+
265
+
266
class PowerNode(BinaryNodeWithParenthesisNoSpace):
    """Node recording ``left ** right``. It is printed with ``^`` as the
    operator symbol, and evaluation uses Python's ``**``."""

    def __call__(self, **kwargs):
        base = self.left(**kwargs)
        exponent = self.right(**kwargs)
        return base ** exponent

    @property
    def op_symbol(self):
        return "^"
275
+
276
+
277
class GetItemNode(BinaryNode):
    """Node recording the indexing ``left[right]``. The left child yields the
    container and the right child yields the index."""

    @property
    def op_symbol(self):
        # Indexing has no infix operator; __repr__ is overridden below.
        return None

    def __call__(self, **kwargs):
        container = self.left(**kwargs)
        index = self.right(**kwargs)
        return container[index]

    def __repr__(self):
        if isinstance(self.left, BinaryNode):
            return f"({self.left})[{self.right}]"
        return f"{self.left}[{self.right}]"
290
+
291
+
292
class NegateNode(UnaryNode):
    """Node recording the arithmetic negation ``-child``."""

    def __call__(self, **kwargs):
        value = self.child(**kwargs)
        return -value

    def __repr__(self):
        # Compound operands are grouped so the sign clearly applies to all of them.
        if isinstance(self.child, (BinaryNode, UnaryNode)):
            return f"-({self.child})"
        return "-" + str(self.child)
305
+
306
+
307
class AbsNode(UnaryNode):
    """Node recording ``abs(child)``, evaluated with the builtin ``abs``."""

    def __repr__(self):
        return "abs(" + str(self.child) + ")"

    def __call__(self, **kwargs):
        inner = self.child(**kwargs)
        return abs(inner)
315
+
316
+
317
class MatMulNode(BinaryNodeWithParenthesis):
    """Node recording the matrix product ``left @ right``."""

    def __call__(self, **kwargs):
        lhs = self.left(**kwargs)
        rhs = self.right(**kwargs)
        return lhs @ rhs

    @property
    def op_symbol(self):
        return "@"
@@ -15,8 +15,10 @@ class RestorationPrior(Distribution):
15
15
  Parameters
16
16
  ----------
17
17
  restorator : callable f(x, restoration_strength)
18
- Function f that accepts input x to be restored and returns the
19
- restored version of x and information about the restoration operation.
18
+ Function f that accepts input x to be restored and returns a two-element
19
+ tuple of the restored version of x and extra information about the
20
+ restoration operation. The second element can be of any type, including
21
+ `None` in case there is no information.
20
22
 
21
23
  restorator_kwargs : dictionary
22
24
  Dictionary containing information about the restorator.
@@ -41,8 +43,9 @@ class RestorationPrior(Distribution):
41
43
  super().__init__(**kwargs)
42
44
 
43
45
  def restore(self, x, restoration_strength):
44
- """This function allows us to restore the input x and returns the
45
- restored version of x.
46
+ """This function allows us to restore the input x with the user-supplied
47
+ restorator. Extra information about the restoration operation is stored
48
+ in the `RestorationPrior` info attribute.
46
49
 
47
50
  Parameters
48
51
  ----------
@@ -54,9 +57,18 @@ class RestorationPrior(Distribution):
54
57
  restorator is a denoiser, this parameter might correspond to the
55
58
  noise level.
56
59
  """
57
- solution, info = self.restorator(x, restoration_strength=restoration_strength,
60
+ restorator_return = self.restorator(x, restoration_strength=restoration_strength,
58
61
  **self.restorator_kwargs)
59
- self.info = info
62
+
63
+ if type(restorator_return) == tuple and len(restorator_return) == 2:
64
+ solution, self.info = restorator_return
65
+ else:
66
+ raise ValueError("Unsupported return type from the user-supplied restorator function. "+
67
+ "Please ensure that the restorator function returns a two-element tuple with the "+
68
+ "restored solution as the first element and additional information about the "+
69
+ "restoration as the second element. The second element can be of any type, "+
70
+ "including `None` in case there is no particular information.")
71
+
60
72
  return solution
61
73
 
62
74
  def logpdf(self, x):
@@ -28,7 +28,7 @@ def test_RegularizedGaussian_guarding_statements():
28
28
 
29
29
  with pytest.raises(ValueError, match="Projector should take 1 argument"):
30
30
  cuqi.implicitprior.RegularizedGaussian(np.zeros(5), 1, projector=lambda s,z: s)
31
-
31
+
32
32
  def test_creating_restorator():
33
33
  """ Test creating the object from restorator class."""
34
34
 
@@ -38,6 +38,32 @@ def test_creating_restorator():
38
38
  assert np.allclose(restorator.restore(np.ones(4), 0.1), np.ones(4))
39
39
  assert restorator.info == True
40
40
 
41
def test_handling_invalid_restorator():
    """ Test that restore() rejects restorators whose return value is not a
    two-element tuple."""

    # Invalid return type 1: None
    def returns_nothing(x, restoration_strength=0.1):
        return

    # Invalid return type 2: one parameter
    def returns_only_x(x, restoration_strength=0.1):
        return x

    # Invalid return type 3: tuple with 3 elements
    def returns_three_items(x, restoration_strength=0.1):
        return x, None, False

    # Invalid return type 4: list with 2 elements
    def returns_list(x, restoration_strength=0.1):
        return [x, None]

    invalid_restorators = (returns_nothing, returns_only_x,
                           returns_three_items, returns_list)
    for restorator in invalid_restorators:
        prior = cuqi.implicitprior.RestorationPrior(restorator)
        with pytest.raises(ValueError, match=r"Unsupported return type .*"):
            prior.restore(np.ones(4), 0.1)
41
67
 
42
68
  def test_creating_restorator_with_potential():
43
69
  """ Test creating the object from restorator class with a potential."""