CUQIpy 1.2.0.post0.dev314__tar.gz → 1.2.0.post0.dev342__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of CUQIpy might be problematic. Click here for more details.
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/CUQIpy.egg-info/PKG-INFO +1 -1
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/CUQIpy.egg-info/SOURCES.txt +2 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/PKG-INFO +1 -1
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/_version.py +3 -3
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/experimental/__init__.py +1 -0
- cuqipy-1.2.0.post0.dev342/cuqi/experimental/algebra/__init__.py +1 -0
- cuqipy-1.2.0.post0.dev342/cuqi/experimental/algebra/_ast.py +325 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/experimental/mcmc/_rto.py +32 -8
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/implicitprior/_regularizedGMRF.py +1 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/implicitprior/_regularizedGaussian.py +82 -25
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/solver/_solver.py +1 -1
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_implicit_priors.py +35 -1
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/CUQIpy.egg-info/dependency_links.txt +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/CUQIpy.egg-info/requires.txt +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/CUQIpy.egg-info/top_level.txt +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/LICENSE +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/README.md +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/__init__.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/_messages.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/array/__init__.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/array/_array.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/config.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/data/__init__.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/data/_data.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/data/astronaut.npz +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/data/camera.npz +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/data/cat.npz +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/data/cookie.png +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/data/satellite.mat +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/density/__init__.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/density/_density.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/diagnostics.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/__init__.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_beta.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_cauchy.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_cmrf.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_custom.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_distribution.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_gamma.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_gaussian.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_gmrf.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_inverse_gamma.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_joint_distribution.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_laplace.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_lmrf.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_lognormal.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_modifiedhalfnormal.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_normal.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_posterior.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_smoothed_laplace.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_truncated_normal.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_uniform.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/experimental/mcmc/__init__.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/experimental/mcmc/_conjugate.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/experimental/mcmc/_conjugate_approx.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/experimental/mcmc/_cwmh.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/experimental/mcmc/_direct.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/experimental/mcmc/_gibbs.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/experimental/mcmc/_hmc.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/experimental/mcmc/_langevin_algorithm.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/experimental/mcmc/_laplace_approximation.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/experimental/mcmc/_mh.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/experimental/mcmc/_pcn.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/experimental/mcmc/_sampler.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/experimental/mcmc/_utilities.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/geometry/__init__.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/geometry/_geometry.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/implicitprior/__init__.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/implicitprior/_regularizedUnboundedUniform.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/implicitprior/_restorator.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/likelihood/__init__.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/likelihood/_likelihood.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/model/__init__.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/model/_model.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/operator/__init__.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/operator/_operator.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/pde/__init__.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/pde/_pde.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/problem/__init__.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/problem/_problem.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/sampler/__init__.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/sampler/_conjugate.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/sampler/_conjugate_approx.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/sampler/_cwmh.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/sampler/_gibbs.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/sampler/_hmc.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/sampler/_langevin_algorithm.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/sampler/_laplace_approximation.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/sampler/_mh.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/sampler/_pcn.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/sampler/_rto.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/sampler/_sampler.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/samples/__init__.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/samples/_samples.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/solver/__init__.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/testproblem/__init__.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/testproblem/_testproblem.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/utilities/__init__.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/utilities/_get_python_variable_name.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/utilities/_utilities.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/pyproject.toml +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/requirements.txt +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/setup.cfg +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/setup.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_MRFs.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_abstract_distribution_density.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_bayesian_inversion.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_density.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_distribution.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_distributions_shape.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_geometry.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_joint_distribution.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_likelihood.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_model.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_pde.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_posterior.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_problem.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_sampler.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_samples.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_solver.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_testproblem.py +0 -0
- {cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_utilities.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: CUQIpy
|
|
3
|
-
Version: 1.2.0.post0.
|
|
3
|
+
Version: 1.2.0.post0.dev342
|
|
4
4
|
Summary: Computational Uncertainty Quantification for Inverse problems in Python
|
|
5
5
|
Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
|
|
6
6
|
License: Apache License
|
|
@@ -45,6 +45,8 @@ cuqi/distribution/_smoothed_laplace.py
|
|
|
45
45
|
cuqi/distribution/_truncated_normal.py
|
|
46
46
|
cuqi/distribution/_uniform.py
|
|
47
47
|
cuqi/experimental/__init__.py
|
|
48
|
+
cuqi/experimental/algebra/__init__.py
|
|
49
|
+
cuqi/experimental/algebra/_ast.py
|
|
48
50
|
cuqi/experimental/mcmc/__init__.py
|
|
49
51
|
cuqi/experimental/mcmc/_conjugate.py
|
|
50
52
|
cuqi/experimental/mcmc/_conjugate_approx.py
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: CUQIpy
|
|
3
|
-
Version: 1.2.0.post0.
|
|
3
|
+
Version: 1.2.0.post0.dev342
|
|
4
4
|
Summary: Computational Uncertainty Quantification for Inverse problems in Python
|
|
5
5
|
Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
|
|
6
6
|
License: Apache License
|
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "2024-11-
|
|
11
|
+
"date": "2024-11-25T09:18:51+0100",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "1.2.0.post0.
|
|
14
|
+
"full-revisionid": "63395d14f6c3f964633b20200ada13b8a213da20",
|
|
15
|
+
"version": "1.2.0.post0.dev342"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from ._ast import VariableNode
|
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CUQIpy specific implementation of an abstract syntax tree (AST) for algebra on variables.
|
|
3
|
+
|
|
4
|
+
The AST is used to record the operations applied to variables allowing a delayed evaluation
|
|
5
|
+
of said operations when needed by traversing the tree with the __call__ method.
|
|
6
|
+
|
|
7
|
+
For example, the following code
|
|
8
|
+
|
|
9
|
+
x = VariableNode('x')
|
|
10
|
+
y = VariableNode('y')
|
|
11
|
+
z = 2*x + 3*y
|
|
12
|
+
|
|
13
|
+
will create the following AST:
|
|
14
|
+
|
|
15
|
+
z = AddNode(
|
|
16
|
+
MultiplyNode(
|
|
17
|
+
ValueNode(2),
|
|
18
|
+
VariableNode('x')
|
|
19
|
+
),
|
|
20
|
+
MultiplyNode(
|
|
21
|
+
ValueNode(3),
|
|
22
|
+
VariableNode('y')
|
|
23
|
+
)
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
which can be evaluated by calling the __call__ method:
|
|
27
|
+
|
|
28
|
+
z(x=1, y=2) # returns 8
|
|
29
|
+
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
from abc import ABC, abstractmethod
|
|
33
|
+
|
|
34
|
+
def convert_to_node(x):
    """Convert any non-Node object to a ValueNode object.

    Used by the operator overloads of :class:`Node` so that plain python
    operands (numbers, arrays, indices, ...) can take part in the AST.

    Parameters
    ----------
    x : object
        Either an existing :class:`Node`, which is returned unchanged, or any
        other python object, which is wrapped in a :class:`ValueNode`.

    Returns
    -------
    Node
        ``x`` itself if it already is a Node, otherwise ``ValueNode(x)``.
    """
    return x if isinstance(x, Node) else ValueNode(x)
|
|
36
|
+
|
|
37
|
+
# ====== Base classes for the nodes ======
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class Node(ABC):
    """Abstract base for every node of the abstract syntax tree.

    The python operators overloaded below are what build the tree. Each one
    returns a new node that records the operation instead of computing a
    value. Operands that are not nodes are first wrapped via ``convert_to_node``.

    Concrete subclasses must implement ``__call__``, which evaluates the
    sub-tree given the variable values as keyword arguments, and ``__repr__``,
    which prints it.
    """

    @abstractmethod
    def __call__(self, **kwargs):
        """Evaluate the sub-tree rooted at this node by replaying the recorded operations for the given keyword arguments."""

    @abstractmethod
    def __repr__(self):
        """Printable form of the sub-tree rooted at this node (used to display the AST)."""

    # ---- helpers shared by the operator overloads ----

    def _binary(self, node_cls, other):
        """Record ``self <op> other``, i.e. build ``node_cls(self, other)``."""
        return node_cls(self, convert_to_node(other))

    def _reflected(self, node_cls, other):
        """Record ``other <op> self`` for the reflected (right-hand) operators."""
        return node_cls(convert_to_node(other), self)

    # ---- arithmetic ----

    def __add__(self, other):
        return self._binary(AddNode, other)

    def __radd__(self, other):
        return self._reflected(AddNode, other)

    def __sub__(self, other):
        return self._binary(SubtractNode, other)

    def __rsub__(self, other):
        return self._reflected(SubtractNode, other)

    def __mul__(self, other):
        return self._binary(MultiplyNode, other)

    def __rmul__(self, other):
        return self._reflected(MultiplyNode, other)

    def __truediv__(self, other):
        return self._binary(DivideNode, other)

    def __rtruediv__(self, other):
        return self._reflected(DivideNode, other)

    def __pow__(self, other):
        return self._binary(PowerNode, other)

    def __rpow__(self, other):
        return self._reflected(PowerNode, other)

    # ---- unary operators ----

    def __neg__(self):
        return NegateNode(self)

    def __abs__(self):
        return AbsNode(self)

    # ---- indexing and matrix product ----

    def __getitem__(self, i):
        # The index itself becomes a node so it can also be an expression.
        return self._binary(GetItemNode, i)

    def __matmul__(self, other):
        return self._binary(MatMulNode, other)

    def __rmatmul__(self, other):
        return self._reflected(MatMulNode, other)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class UnaryNode(Node, ABC):
    """Abstract base for nodes that apply an operation to a single operand.

    Parameters
    ----------
    child : Node
        The node whose evaluated result the unary operation acts on.
    """

    def __init__(self, child: Node):
        # Only the operand is stored; subclasses apply the operation in __call__.
        self.child = child
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class BinaryNode(Node, ABC):
    """Abstract base for nodes that combine two operands with an infix operator.

    Subclasses provide ``op_symbol``, the text placed between the operands
    when the node is printed by ``__repr__``.

    Parameters
    ----------
    left : Node
        Operand on the left-hand side of the operator.

    right : Node
        Operand on the right-hand side of the operator.
    """

    def __init__(self, left: Node, right: Node):
        self.left, self.right = left, right

    @property
    @abstractmethod
    def op_symbol(self):
        """Text representing the operator when printing this node."""

    def __repr__(self):
        # Plain "left op right" with no grouping; subclasses add parentheses where needed.
        return "{} {} {}".format(self.left, self.op_symbol, self.right)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class BinaryNodeWithParenthesis(BinaryNode, ABC):
    """Binary node whose operands are parenthesized when printed if they are binary nodes themselves."""

    def __repr__(self):
        def wrap(operand):
            # Nested binary expressions get explicit grouping so precedence is unambiguous.
            if isinstance(operand, BinaryNode):
                return f"({operand})"
            return str(operand)

        return f"{wrap(self.left)} {self.op_symbol} {wrap(self.right)}"
|
|
157
|
+
|
|
158
|
+
class BinaryNodeWithParenthesisNoSpace(BinaryNode, ABC):
    """Base class for binary nodes printed with parentheses but no spaces around the operator, e.g. ``x^2``.

    Operands that are binary nodes themselves are wrapped in parentheses.

    A left operand whose printed form starts with a minus sign is also wrapped.
    This covers a negated node or a negative constant. Without the parentheses,
    ``-x^2`` would conventionally read as ``-(x^2)``, although the tree
    evaluates ``(-x)**2``.
    """

    def __repr__(self):
        left = str(self.left)
        if isinstance(self.left, BinaryNode) or left.startswith("-"):
            left = f"({left})"
        right = (
            f"({self.right})" if isinstance(self.right, BinaryNode) else str(self.right)
        )
        return f"{left}{self.op_symbol}{right}"
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
# ====== Specific implementations of the "leaf" nodes ======
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class VariableNode(Node):
    """Leaf node standing for a named variable such as "x" or "y".

    Parameters
    ----------
    name : str
        Identifier of the variable. ``__repr__`` prints it, and ``__call__``
        uses it as the keyword under which the variable's value is looked up.
    """

    def __init__(self, name):
        self.name = name

    def __call__(self, **kwargs):
        """Return the value passed for this variable; raise KeyError if it was not passed."""
        key = self.name
        if key not in kwargs:
            raise KeyError(
                f"Variable '{key}' not found in the given input parameters. Unable to evaluate the expression."
            )
        return kwargs[key]

    def __repr__(self):
        return self.name
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class ValueNode(Node):
    """Leaf node holding a constant, which may be any python object other than a Node.

    Parameters
    ----------
    value : object
        The constant returned every time the node is evaluated.
    """

    def __init__(self, value):
        self.value = value

    def __call__(self, **kwargs):
        """Return the stored constant; any keyword arguments are ignored."""
        return self.value

    def __repr__(self):
        return f"{self.value!s}"
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
# ====== Specific implementations of the "internal" nodes ======
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
class AddNode(BinaryNode):
    """Binary node for the sum ``left + right``."""

    def __call__(self, **kwargs):
        """Evaluate both operands (left first) with the same keyword arguments and add them."""
        lhs = self.left(**kwargs)
        rhs = self.right(**kwargs)
        return lhs + rhs

    @property
    def op_symbol(self):
        return "+"
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
class SubtractNode(BinaryNode):
    """Node that represents the subtraction operation ``left - right``."""

    @property
    def op_symbol(self):
        return "-"

    def __call__(self, **kwargs):
        """Evaluate both operands (left first) and subtract the right result from the left."""
        return self.left(**kwargs) - self.right(**kwargs)

    def __repr__(self):
        # Subtraction is not associative. Without grouping, ``x - (y - z)`` and
        # ``x - (y + z)`` would print as ``x - y - z`` / ``x - y + z``, which
        # denote different expressions. Wrap an additive right operand.
        right = str(self.right)
        if isinstance(self.right, (AddNode, SubtractNode)):
            right = f"({right})"
        return f"{self.left} {self.op_symbol} {right}"
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
class MultiplyNode(BinaryNodeWithParenthesis):
    """Binary node for the product ``left * right`` (printed with ``*``)."""

    def __call__(self, **kwargs):
        """Evaluate both operands (left first) with the same keyword arguments and multiply them."""
        lhs, rhs = self.left(**kwargs), self.right(**kwargs)
        return lhs * rhs

    @property
    def op_symbol(self):
        return "*"
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
class DivideNode(BinaryNodeWithParenthesis):
    """Binary node for true division ``left / right`` (printed with ``/``)."""

    def __call__(self, **kwargs):
        """Evaluate both operands (left first) and divide the left result by the right one."""
        numerator, denominator = self.left(**kwargs), self.right(**kwargs)
        return numerator / denominator

    @property
    def op_symbol(self):
        return "/"
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
class PowerNode(BinaryNodeWithParenthesisNoSpace):
    """Binary node for exponentiation ``left ** right``; printed with ``^`` and no spaces."""

    def __call__(self, **kwargs):
        """Evaluate base then exponent and raise the base to that power."""
        base, exponent = self.left(**kwargs), self.right(**kwargs)
        return base ** exponent

    @property
    def op_symbol(self):
        return "^"
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
class GetItemNode(BinaryNode):
    """Node that represents the get item operation ``left[right]``.

    The left node is the indexed object and the right node is the index.
    """

    def __call__(self, **kwargs):
        """Evaluate the object and the index (in that order) and return ``object[index]``."""
        return self.left(**kwargs)[self.right(**kwargs)]

    def __repr__(self):
        left = str(self.left)
        # Parentheses are needed around infix expressions so ``(x + y)[0]`` is
        # not shown as ``x + y[0]``. They are also needed around a leading minus,
        # since ``-x[0]`` reads as ``-(x[0])`` while the tree indexes ``-x``.
        # Chained indexing needs none: show ``x[0][1]``, not ``(x[0])[1]``.
        needs_parens = (
            isinstance(self.left, BinaryNode) and not isinstance(self.left, GetItemNode)
        ) or left.startswith("-")
        if needs_parens:
            left = f"({left})"
        return f"{left}[{self.right}]"

    @property
    def op_symbol(self):
        """Unused: indexing is printed with brackets by ``__repr__``, so there is no infix symbol."""
        return None
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
class NegateNode(UnaryNode):
    """Unary node for arithmetic negation ``-child``."""

    def __call__(self, **kwargs):
        """Evaluate the child and return its negation."""
        value = self.child(**kwargs)
        return -value

    def __repr__(self):
        # Compound operands are parenthesized so the minus applies to the whole expression.
        if isinstance(self.child, (BinaryNode, UnaryNode)):
            return f"-({self.child})"
        return f"-{self.child!s}"
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
class AbsNode(UnaryNode):
    """Unary node for the absolute value ``abs(child)``."""

    def __call__(self, **kwargs):
        """Evaluate the child and apply the builtin ``abs`` to the result."""
        value = self.child(**kwargs)
        return abs(value)

    def __repr__(self):
        return "abs({})".format(self.child)
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
class MatMulNode(BinaryNodeWithParenthesis):
    """Binary node for the matrix product ``left @ right``."""

    def __call__(self, **kwargs):
        """Evaluate both operands (left first) and combine them with the ``@`` operator."""
        lhs, rhs = self.left(**kwargs), self.right(**kwargs)
        return lhs @ rhs

    @property
    def op_symbol(self):
        return "@"
|
|
@@ -3,7 +3,7 @@ from scipy.linalg.interpolative import estimate_spectral_norm
|
|
|
3
3
|
from scipy.sparse.linalg import LinearOperator as scipyLinearOperator
|
|
4
4
|
import numpy as np
|
|
5
5
|
import cuqi
|
|
6
|
-
from cuqi.solver import CGLS, FISTA
|
|
6
|
+
from cuqi.solver import CGLS, FISTA, ADMM
|
|
7
7
|
from cuqi.experimental.mcmc import Sampler
|
|
8
8
|
|
|
9
9
|
|
|
@@ -161,6 +161,13 @@ class RegularizedLinearRTO(LinearRTO):
|
|
|
161
161
|
Regularized Linear RTO (Randomize-Then-Optimize) sampler.
|
|
162
162
|
|
|
163
163
|
Samples posterior related to the inverse problem with Gaussian likelihood and implicit Gaussian prior, and where the forward model is Linear.
|
|
164
|
+
The sampler works by repeatedly solving regularized linear least squares problems for perturbed data.
|
|
165
|
+
The solver for these optimization problems is chosen based on how the regularization is provided in the implicit Gaussian prior.
|
|
166
|
+
Currently we use the following solvers:
|
|
167
|
+
FISTA: [1] Beck, Amir, and Marc Teboulle. "A fast iterative shrinkage-thresholding algorithm for linear inverse problems." SIAM journal on imaging sciences 2.1 (2009): 183-202.
|
|
168
|
+
Used when prior.proximal is callable.
|
|
169
|
+
ADMM: [2] Boyd et al. "Distributed optimization and statistical learning via the alternating direction method of multipliers." Foundations and Trends® in Machine learning, 2011.
|
|
170
|
+
Used when prior.proximal is a list of penalty terms.
|
|
164
171
|
|
|
165
172
|
Parameters
|
|
166
173
|
------------
|
|
@@ -171,12 +178,19 @@ class RegularizedLinearRTO(LinearRTO):
|
|
|
171
178
|
Initial point for the sampler. *Optional*.
|
|
172
179
|
|
|
173
180
|
maxit : int
|
|
174
|
-
Maximum number of iterations of the inner FISTA solver. *Optional*.
|
|
181
|
+
Maximum number of iterations of the inner FISTA/ADMM solver. *Optional*.
|
|
182
|
+
|
|
183
|
+
inner_max_it : int
|
|
184
|
+
Maximum number of iterations of the CGLS solver used within the ADMM solver. *Optional*.
|
|
175
185
|
|
|
176
186
|
stepsize : string or float
|
|
177
187
|
If stepsize is a string and equals either "automatic", then the stepsize is automatically estimated based on the spectral norm.
|
|
178
188
|
If stepsize is a float, then this stepsize is used.
|
|
179
189
|
|
|
190
|
+
penalty_parameter : int
|
|
191
|
+
Penalty parameter of the inner ADMM solver. *Optional*.
|
|
192
|
+
See [2] or `cuqi.solver.ADMM`
|
|
193
|
+
|
|
180
194
|
abstol : float
|
|
181
195
|
Absolute tolerance of the inner FISTA solver. *Optional*.
|
|
182
196
|
|
|
@@ -190,7 +204,7 @@ class RegularizedLinearRTO(LinearRTO):
|
|
|
190
204
|
An example is shown in demos/demo31_callback.py.
|
|
191
205
|
|
|
192
206
|
"""
|
|
193
|
-
def __init__(self, target=None, initial_point=None, maxit=100, stepsize="automatic", abstol=1e-10, adaptive=True, **kwargs):
|
|
207
|
+
def __init__(self, target=None, initial_point=None, maxit=100, inner_max_it=10, stepsize="automatic", penalty_parameter=10, abstol=1e-10, adaptive=True, **kwargs):
|
|
194
208
|
|
|
195
209
|
super().__init__(target=target, initial_point=initial_point, **kwargs)
|
|
196
210
|
|
|
@@ -199,10 +213,13 @@ class RegularizedLinearRTO(LinearRTO):
|
|
|
199
213
|
self.abstol = abstol
|
|
200
214
|
self.adaptive = adaptive
|
|
201
215
|
self.maxit = maxit
|
|
216
|
+
self.inner_max_it = inner_max_it
|
|
217
|
+
self.penalty_parameter = penalty_parameter
|
|
202
218
|
|
|
203
219
|
def _initialize(self):
|
|
204
220
|
super()._initialize()
|
|
205
|
-
self.
|
|
221
|
+
if self._inner_solver == "FISTA":
|
|
222
|
+
self._stepsize = self._choose_stepsize()
|
|
206
223
|
|
|
207
224
|
@property
|
|
208
225
|
def proximal(self):
|
|
@@ -212,8 +229,7 @@ class RegularizedLinearRTO(LinearRTO):
|
|
|
212
229
|
super().validate_target()
|
|
213
230
|
if not isinstance(self.target.prior, (cuqi.implicitprior.RegularizedGaussian, cuqi.implicitprior.RegularizedGMRF)):
|
|
214
231
|
raise TypeError("Prior needs to be RegularizedGaussian or RegularizedGMRF")
|
|
215
|
-
if
|
|
216
|
-
raise TypeError("Proximal needs to be callable")
|
|
232
|
+
self._inner_solver = "FISTA" if callable(self.proximal) else "ADMM"
|
|
217
233
|
|
|
218
234
|
def _choose_stepsize(self):
|
|
219
235
|
if isinstance(self.stepsize, str):
|
|
@@ -237,8 +253,16 @@ class RegularizedLinearRTO(LinearRTO):
|
|
|
237
253
|
|
|
238
254
|
def step(self):
|
|
239
255
|
y = self.b_tild + np.random.randn(len(self.b_tild))
|
|
240
|
-
|
|
241
|
-
|
|
256
|
+
|
|
257
|
+
if self._inner_solver == "FISTA":
|
|
258
|
+
sim = FISTA(self.M, y, self.proximal,
|
|
259
|
+
self.current_point, maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
|
|
260
|
+
elif self._inner_solver == "ADMM":
|
|
261
|
+
sim = ADMM(self.M, y, self.proximal,
|
|
262
|
+
self.current_point, self.penalty_parameter, maxit = self.maxit, inner_max_it = self.inner_max_it, adaptive = self.adaptive)
|
|
263
|
+
else:
|
|
264
|
+
raise ValueError("Choice of solver not supported.")
|
|
265
|
+
|
|
242
266
|
self.current_point, _ = sim.solve()
|
|
243
267
|
acc = 1
|
|
244
268
|
return acc
|
{cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/implicitprior/_regularizedGMRF.py
RENAMED
|
@@ -63,6 +63,7 @@ class RegularizedGMRF(RegularizedGaussian):
|
|
|
63
63
|
|
|
64
64
|
# Underlying explicit GMRF
|
|
65
65
|
self._gaussian = GMRF(mean, prec, bc_type=bc_type, order=order, **kwargs)
|
|
66
|
+
kwargs.pop("geometry", None)
|
|
66
67
|
|
|
67
68
|
# Init from abstract distribution class
|
|
68
69
|
super(Distribution, self).__init__(**kwargs)
|
{cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/implicitprior/_regularizedGaussian.py
RENAMED
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
from cuqi.utilities import get_non_default_args
|
|
2
2
|
from cuqi.distribution import Distribution, Gaussian
|
|
3
3
|
from cuqi.solver import ProjectNonnegative, ProjectBox, ProximalL1
|
|
4
|
+
from cuqi.geometry import Continuous1D, Continuous2D, Image2D
|
|
5
|
+
from cuqi.operator import FirstOrderFiniteDifference
|
|
4
6
|
|
|
5
7
|
import numpy as np
|
|
6
8
|
|
|
@@ -39,17 +41,22 @@ class RegularizedGaussian(Distribution):
|
|
|
39
41
|
sqrtprec
|
|
40
42
|
See :class:`~cuqi.distribution.Gaussian` for details.
|
|
41
43
|
|
|
42
|
-
proximal : callable f(x, scale) or None
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
44
|
+
proximal : callable f(x, scale), list of tuples (callable proximal operator of f_i, linear operator L_i) or None
|
|
45
|
+
If callable:
|
|
46
|
+
Euclidean proximal operator f of the regularization function g, that is, a solver for the optimization problem
|
|
47
|
+
min_z 0.5||x-z||_2^2 + scale*g(z).
|
|
48
|
+
If list of tuples (callable proximal operator of f_i, linear operator L_i):
|
|
49
|
+
Each callable proximal operator of f_i accepts two arguments (x, p) and should return the minimizer over z of p/2||x-z||^2 + f_i(z).
|
|
50
|
+
The corresponding regularization takes the form
|
|
51
|
+
sum_i f_i(L_i x),
|
|
52
|
+
where the sum ranges from 1 to an arbitrary n.
|
|
46
53
|
|
|
47
54
|
projector : callable f(x) or None
|
|
48
55
|
Euclidean projection onto the constraint C, that is, a solver for the optimization problem
|
|
49
56
|
min_(z in C) 0.5||x-z||_2^2.
|
|
50
57
|
|
|
51
58
|
constraint : string or None
|
|
52
|
-
Preset constraints. Can be set to "nonnegativity" and "box". Required for use in Gibbs.
|
|
59
|
+
Preset constraints that generate the corresponding proximal parameter. Can be set to "nonnegativity" and "box". Required for use in Gibbs.
|
|
53
60
|
For "box", the following additional parameters can be passed:
|
|
54
61
|
lower_bound : array_like or None
|
|
55
62
|
Lower bound of box, defaults to zero
|
|
@@ -57,10 +64,10 @@ class RegularizedGaussian(Distribution):
|
|
|
57
64
|
Upper bound of box, defaults to one
|
|
58
65
|
|
|
59
66
|
regularization : string or None
|
|
60
|
-
Preset regularization. Can be set to "l1". Required for use in Gibbs in future update.
|
|
61
|
-
For "l1", the following additional parameters can be passed:
|
|
67
|
+
Preset regularization that generate the corresponding proximal parameter. Can be set to "l1" or 'tv'. Required for use in Gibbs in future update.
|
|
68
|
+
For "l1" or "tv", the following additional parameters can be passed:
|
|
62
69
|
strength : scalar
|
|
63
|
-
Regularization parameter, i.e., strength*||
|
|
70
|
+
Regularization parameter, i.e., strength*||Lx||_1, defaults to one
|
|
64
71
|
|
|
65
72
|
"""
|
|
66
73
|
|
|
@@ -75,6 +82,7 @@ class RegularizedGaussian(Distribution):
|
|
|
75
82
|
|
|
76
83
|
# We init the underlying Gaussian first for geometry and dimensionality handling
|
|
77
84
|
self._gaussian = Gaussian(mean=mean, cov=cov, prec=prec, sqrtcov=sqrtcov, sqrtprec=sqrtprec, **kwargs)
|
|
85
|
+
kwargs.pop("geometry", None)
|
|
78
86
|
|
|
79
87
|
# Init from abstract distribution class
|
|
80
88
|
super().__init__(**kwargs)
|
|
@@ -88,12 +96,6 @@ class RegularizedGaussian(Distribution):
|
|
|
88
96
|
if (proximal is not None) + (projector is not None) + (constraint is not None) + (regularization is not None) != 1:
|
|
89
97
|
raise ValueError("Precisely one of proximal, projector, constraint or regularization needs to be provided.")
|
|
90
98
|
|
|
91
|
-
if proximal is not None:
|
|
92
|
-
if not callable(proximal):
|
|
93
|
-
raise ValueError("Proximal needs to be callable.")
|
|
94
|
-
if len(get_non_default_args(proximal)) != 2:
|
|
95
|
-
raise ValueError("Proximal should take 2 arguments.")
|
|
96
|
-
|
|
97
99
|
if projector is not None:
|
|
98
100
|
if not callable(projector):
|
|
99
101
|
raise ValueError("Projector needs to be callable.")
|
|
@@ -104,7 +106,8 @@ class RegularizedGaussian(Distribution):
|
|
|
104
106
|
self._preset = None
|
|
105
107
|
|
|
106
108
|
if proximal is not None:
|
|
107
|
-
|
|
109
|
+
# No need to generate the proximal and associated information
|
|
110
|
+
self.proximal = proximal
|
|
108
111
|
elif projector is not None:
|
|
109
112
|
self._proximal = lambda z, gamma: projector(z)
|
|
110
113
|
elif (isinstance(constraint, str) and constraint.lower() == "nonnegativity"):
|
|
@@ -113,15 +116,48 @@ class RegularizedGaussian(Distribution):
|
|
|
113
116
|
elif (isinstance(constraint, str) and constraint.lower() == "box"):
|
|
114
117
|
lower = optional_regularization_parameters["lower_bound"]
|
|
115
118
|
upper = optional_regularization_parameters["upper_bound"]
|
|
116
|
-
self._proximal = lambda z,
|
|
119
|
+
self._proximal = lambda z, _: ProjectBox(z, lower, upper)
|
|
117
120
|
self._preset = "box" # Not supported in Gibbs
|
|
118
121
|
elif (isinstance(regularization, str) and regularization.lower() in ["l1"]):
|
|
119
|
-
|
|
120
|
-
self._proximal = lambda z, gamma: ProximalL1(z, gamma*
|
|
122
|
+
self._strength = optional_regularization_parameters["strength"]
|
|
123
|
+
self._proximal = lambda z, gamma: ProximalL1(z, gamma*self._strength)
|
|
121
124
|
self._preset = "l1"
|
|
125
|
+
elif (isinstance(regularization, str) and regularization.lower() in ["tv"]):
|
|
126
|
+
self._strength = optional_regularization_parameters["strength"]
|
|
127
|
+
if isinstance(self.geometry, (Continuous1D, Continuous2D, Image2D)):
|
|
128
|
+
self._transformation = FirstOrderFiniteDifference(self.geometry.fun_shape, bc_type='neumann')
|
|
129
|
+
else:
|
|
130
|
+
raise ValueError("Geometry not supported for total variation")
|
|
131
|
+
|
|
132
|
+
self._regularization_prox = lambda z, gamma: ProximalL1(z, gamma*self._strength)
|
|
133
|
+
self._regularization_oper = self._transformation
|
|
134
|
+
|
|
135
|
+
self._proximal = [(self._regularization_prox, self._regularization_oper)]
|
|
136
|
+
self._preset = "tv"
|
|
122
137
|
else:
|
|
123
138
|
raise ValueError("Regularization not supported")
|
|
124
139
|
|
|
140
|
+
|
|
141
|
+
@property
|
|
142
|
+
def transformation(self):
|
|
143
|
+
return self._transformation
|
|
144
|
+
|
|
145
|
+
@property
|
|
146
|
+
def strength(self):
|
|
147
|
+
return self._strength
|
|
148
|
+
|
|
149
|
+
@strength.setter
|
|
150
|
+
def strength(self, value):
|
|
151
|
+
if self._preset not in self.regularization_options():
|
|
152
|
+
raise TypeError("Strength is only used when the regularization is set to l1 or TV.")
|
|
153
|
+
|
|
154
|
+
self._strength = value
|
|
155
|
+
if self._preset == "tv":
|
|
156
|
+
self._regularization_prox = lambda z, gamma: ProximalL1(z, gamma*self._strength)
|
|
157
|
+
self._proximal = [(self._regularization_prox, self._regularization_oper)]
|
|
158
|
+
elif self._preset == "l1":
|
|
159
|
+
self._proximal = lambda z, gamma: ProximalL1(z, gamma*self._strength)
|
|
160
|
+
|
|
125
161
|
# This is a getter only attribute for the underlying Gaussian
|
|
126
162
|
# It also ensures that the name of the underlying Gaussian
|
|
127
163
|
# matches the name of the implicit regularized Gaussian
|
|
@@ -135,6 +171,25 @@ class RegularizedGaussian(Distribution):
|
|
|
135
171
|
def proximal(self):
|
|
136
172
|
return self._proximal
|
|
137
173
|
|
|
174
|
+
@proximal.setter
|
|
175
|
+
def proximal(self, value):
|
|
176
|
+
if callable(value):
|
|
177
|
+
if len(get_non_default_args(value)) != 2:
|
|
178
|
+
raise ValueError("Proximal should take 2 arguments.")
|
|
179
|
+
elif isinstance(value, list):
|
|
180
|
+
for (prox, op) in value:
|
|
181
|
+
if len(get_non_default_args(prox)) != 2:
|
|
182
|
+
raise ValueError("Proximal should take 2 arguments.")
|
|
183
|
+
if op.shape[1] != self.geometry.par_dim:
|
|
184
|
+
raise ValueError("Incorrect shape of linear operator in proximal list.")
|
|
185
|
+
else:
|
|
186
|
+
raise ValueError("Proximal needs to be callable or a list. See documentation.")
|
|
187
|
+
|
|
188
|
+
self._proximal = value
|
|
189
|
+
|
|
190
|
+
# For all the presets, self._proximal is set directly,
|
|
191
|
+
self._preset = None
|
|
192
|
+
|
|
138
193
|
@property
|
|
139
194
|
def preset(self):
|
|
140
195
|
return self._preset
|
|
@@ -154,7 +209,7 @@ class RegularizedGaussian(Distribution):
|
|
|
154
209
|
|
|
155
210
|
@staticmethod
|
|
156
211
|
def regularization_options():
|
|
157
|
-
return ["l1"]
|
|
212
|
+
return ["l1", "tv"]
|
|
158
213
|
|
|
159
214
|
|
|
160
215
|
# --- Defer behavior of the underlying Gaussian --- #
|
|
@@ -206,16 +261,18 @@ class RegularizedGaussian(Distribution):
|
|
|
206
261
|
def sqrtcov(self, value):
|
|
207
262
|
self.gaussian.sqrtcov = value
|
|
208
263
|
|
|
209
|
-
def get_conditioning_variables(self):
|
|
210
|
-
return self.gaussian.get_conditioning_variables()
|
|
211
|
-
|
|
212
264
|
def get_mutable_variables(self):
|
|
213
|
-
|
|
265
|
+
mutable_vars = self.gaussian.get_mutable_variables().copy()
|
|
266
|
+
if self.preset in self.regularization_options():
|
|
267
|
+
mutable_vars += ["strength"]
|
|
268
|
+
return mutable_vars
|
|
214
269
|
|
|
215
270
|
# Overwrite the condition method such that the underlying Gaussian is conditioned in general, except when conditioning on self.name
|
|
216
271
|
# which means we convert Distribution to Likelihood or EvaluatedDensity.
|
|
217
272
|
def _condition(self, *args, **kwargs):
|
|
218
|
-
|
|
273
|
+
if self.preset in self.regularization_options():
|
|
274
|
+
return super()._condition(*args, **kwargs)
|
|
275
|
+
|
|
219
276
|
# Handle positional arguments (similar code as in Distribution._condition)
|
|
220
277
|
cond_vars = self.get_conditioning_variables()
|
|
221
278
|
kwargs = self._parse_args_add_to_kwargs(cond_vars, *args, **kwargs)
|
|
@@ -275,7 +332,7 @@ class ConstrainedGaussian(RegularizedGaussian):
|
|
|
275
332
|
min_(z in C) 0.5||x-z||_2^2.
|
|
276
333
|
|
|
277
334
|
constraint : string or None
|
|
278
|
-
Preset constraints. Can be set to "nonnegativity" and "box". Required for use in Gibbs.
|
|
335
|
+
Preset constraints that generate the corresponding proximal parameter. Can be set to "nonnegativity" and "box". Required for use in Gibbs.
|
|
279
336
|
For "box", the following additional parameters can be passed:
|
|
280
337
|
lower_bound : array_like or None
|
|
281
338
|
Lower bound of box, defaults to zero
|
|
@@ -669,7 +669,7 @@ class ADMM(object):
|
|
|
669
669
|
- flag=2 indicates multiplication of the transpose of A with vector x, that is A.T @ x.
|
|
670
670
|
b : ndarray.
|
|
671
671
|
penalty_terms : List of tuples (callable proximal operator of f_i, linear operator L_i)
|
|
672
|
-
Each callable proximal operator f_i accepts two arguments (x, p) and should return the minimizer of p/2||x-z||^2 + f(x) over z for some f.
|
|
672
|
+
Each callable proximal operator of f_i accepts two arguments (x, p) and should return the minimizer of p/2||x-z||^2 + f(x) over z for some f.
|
|
673
673
|
x0 : ndarray. Initial guess.
|
|
674
674
|
penalty_parameter : Trade-off between linear least squares and regularization term in the solver iterates. Denoted as "rho" in [1].
|
|
675
675
|
maxit : The maximum number of iterations.
|
|
@@ -16,7 +16,7 @@ def test_RegularizedGaussian_guarding_statements():
|
|
|
16
16
|
cuqi.implicitprior.RegularizedGaussian(np.zeros(5), 1, proximal=lambda s,z: s, constraint="nonnegativity")
|
|
17
17
|
|
|
18
18
|
# Proximal
|
|
19
|
-
with pytest.raises(ValueError, match="Proximal needs to be callable"):
|
|
19
|
+
with pytest.raises(ValueError, match="Proximal needs to be callable or a list. See documentation."):
|
|
20
20
|
cuqi.implicitprior.RegularizedGaussian(np.zeros(5), 1, proximal=1)
|
|
21
21
|
|
|
22
22
|
with pytest.raises(ValueError, match="Proximal should take 2 arguments"):
|
|
@@ -104,3 +104,37 @@ def test_RegularizedUnboundedUniform_is_RegularizedGaussian():
|
|
|
104
104
|
x = cuqi.implicitprior.RegularizedUnboundedUniform(cuqi.geometry.Continuous1D(5), regularization="l1", strength = 5.0)
|
|
105
105
|
|
|
106
106
|
assert np.allclose(x.gaussian.sqrtprec, 0.0)
|
|
107
|
+
|
|
108
|
+
def test_RegularizedGaussian_conditioning_constrained():
|
|
109
|
+
""" Test that conditioning the implicit regularized Gaussian works as expected """
|
|
110
|
+
|
|
111
|
+
x = cuqi.implicitprior.RegularizedGMRF(lambda a:a*np.ones(2**2),
|
|
112
|
+
prec = lambda b:5*b,
|
|
113
|
+
constraint = "nonnegativity",
|
|
114
|
+
geometry = cuqi.geometry.Image2D((2,2)))
|
|
115
|
+
|
|
116
|
+
assert x.get_mutable_variables() == ['mean', 'prec']
|
|
117
|
+
assert x.get_conditioning_variables() == ['a', 'b']
|
|
118
|
+
|
|
119
|
+
x = x(a=1, b=2)
|
|
120
|
+
|
|
121
|
+
assert np.allclose(x.mean, [1, 1, 1, 1])
|
|
122
|
+
assert np.allclose(x.prec, 10)
|
|
123
|
+
|
|
124
|
+
def test_RegularizedGaussian_conditioning_strength():
|
|
125
|
+
""" Test that conditioning the implicit regularized Gaussian works as expected """
|
|
126
|
+
|
|
127
|
+
x = cuqi.implicitprior.RegularizedGMRF(lambda a:a*np.ones(2**2),
|
|
128
|
+
prec = lambda b:5*b,
|
|
129
|
+
regularization = "tv",
|
|
130
|
+
strength = lambda c:c*2,
|
|
131
|
+
geometry = cuqi.geometry.Image2D((2,2)))
|
|
132
|
+
|
|
133
|
+
assert x.get_mutable_variables() == ['mean', 'prec', 'strength']
|
|
134
|
+
assert x.get_conditioning_variables() == ['a', 'b', 'c']
|
|
135
|
+
|
|
136
|
+
x = x(a=1, b=2, c=3)
|
|
137
|
+
|
|
138
|
+
assert np.allclose(x.mean, [1, 1, 1, 1])
|
|
139
|
+
assert np.allclose(x.prec, 10)
|
|
140
|
+
assert np.allclose(x.strength, 6)
|
{cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/CUQIpy.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_joint_distribution.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_modifiedhalfnormal.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_smoothed_laplace.py
RENAMED
|
File without changes
|
{cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/distribution/_truncated_normal.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/experimental/mcmc/_conjugate.py
RENAMED
|
File without changes
|
{cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/experimental/mcmc/_conjugate_approx.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/experimental/mcmc/_utilities.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/sampler/_laplace_approximation.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/cuqi/utilities/_get_python_variable_name.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{cuqipy-1.2.0.post0.dev314 → cuqipy-1.2.0.post0.dev342}/tests/test_abstract_distribution_density.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|