CUQIpy 1.3.0.post0.dev298__tar.gz → 1.3.0.post0.dev362__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of CUQIpy might be problematic. Click here for more details.
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/CUQIpy.egg-info/PKG-INFO +1 -1
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/PKG-INFO +1 -1
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/_version.py +3 -3
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_distribution.py +24 -15
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/likelihood/_likelihood.py +1 -1
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/model/_model.py +212 -77
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_joint_distribution.py +92 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_likelihood.py +60 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_model.py +375 -8
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_posterior.py +45 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/CUQIpy.egg-info/SOURCES.txt +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/CUQIpy.egg-info/dependency_links.txt +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/CUQIpy.egg-info/requires.txt +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/CUQIpy.egg-info/top_level.txt +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/LICENSE +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/README.md +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/_messages.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/array/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/array/_array.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/config.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/data/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/data/_data.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/data/astronaut.npz +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/data/camera.npz +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/data/cat.npz +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/data/cookie.png +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/data/satellite.mat +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/density/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/density/_density.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/diagnostics.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_beta.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_cauchy.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_cmrf.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_custom.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_gamma.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_gaussian.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_gmrf.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_inverse_gamma.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_joint_distribution.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_laplace.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_lmrf.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_lognormal.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_modifiedhalfnormal.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_normal.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_posterior.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_smoothed_laplace.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_truncated_normal.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/distribution/_uniform.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/_recommender.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/algebra/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/algebra/_ast.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/algebra/_orderedset.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/algebra/_randomvariable.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/geometry/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/geometry/_productgeometry.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/mcmc/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/mcmc/_conjugate.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/mcmc/_conjugate_approx.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/mcmc/_cwmh.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/mcmc/_direct.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/mcmc/_gibbs.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/mcmc/_hmc.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/mcmc/_langevin_algorithm.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/mcmc/_laplace_approximation.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/mcmc/_mh.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/mcmc/_pcn.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/mcmc/_rto.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/experimental/mcmc/_sampler.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/geometry/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/geometry/_geometry.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/implicitprior/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/implicitprior/_regularizedGMRF.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/implicitprior/_regularizedGaussian.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/implicitprior/_regularizedUnboundedUniform.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/implicitprior/_restorator.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/likelihood/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/model/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/operator/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/operator/_operator.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/pde/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/pde/_pde.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/problem/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/problem/_problem.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/sampler/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/sampler/_conjugate.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/sampler/_conjugate_approx.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/sampler/_cwmh.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/sampler/_gibbs.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/sampler/_hmc.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/sampler/_langevin_algorithm.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/sampler/_laplace_approximation.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/sampler/_mh.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/sampler/_pcn.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/sampler/_rto.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/sampler/_sampler.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/samples/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/samples/_samples.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/solver/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/solver/_solver.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/testproblem/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/testproblem/_testproblem.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/utilities/__init__.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/utilities/_get_python_variable_name.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/cuqi/utilities/_utilities.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/pyproject.toml +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/requirements.txt +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/setup.cfg +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/setup.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_MRFs.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_abstract_distribution_density.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_bayesian_inversion.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_density.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_distribution.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_distributions_shape.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_geometry.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_implicit_priors.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_pde.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_problem.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_sampler.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_samples.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_solver.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_testproblem.py +0 -0
- {cuqipy-1.3.0.post0.dev298 → cuqipy-1.3.0.post0.dev362}/tests/test_utilities.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: CUQIpy
|
|
3
|
-
Version: 1.3.0.post0.
|
|
3
|
+
Version: 1.3.0.post0.dev362
|
|
4
4
|
Summary: Computational Uncertainty Quantification for Inverse problems in Python
|
|
5
5
|
Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
|
|
6
6
|
License: Apache License
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: CUQIpy
|
|
3
|
-
Version: 1.3.0.post0.
|
|
3
|
+
Version: 1.3.0.post0.dev362
|
|
4
4
|
Summary: Computational Uncertainty Quantification for Inverse problems in Python
|
|
5
5
|
Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
|
|
6
6
|
License: Apache License
|
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "2025-09-
|
|
11
|
+
"date": "2025-09-12T10:23:05+0300",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "1.3.0.post0.
|
|
14
|
+
"full-revisionid": "f6a73b0b32186614fe5451781567c1abcd48452a",
|
|
15
|
+
"version": "1.3.0.post0.dev362"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
|
@@ -105,7 +105,7 @@ class Distribution(Density, ABC):
|
|
|
105
105
|
f"Inconsistent distribution geometry attribute {self._geometry} and inferred "
|
|
106
106
|
f"dimension from distribution variables {inferred_dim}."
|
|
107
107
|
)
|
|
108
|
-
|
|
108
|
+
|
|
109
109
|
# If Geometry dimension is None, update it with the inferred dimension
|
|
110
110
|
if inferred_dim and self._geometry.par_dim is None:
|
|
111
111
|
self.geometry = inferred_dim
|
|
@@ -117,7 +117,7 @@ class Distribution(Density, ABC):
|
|
|
117
117
|
# We do not use self.name to potentially infer it from python stack.
|
|
118
118
|
if self._name:
|
|
119
119
|
self._geometry._variable_name = self._name
|
|
120
|
-
|
|
120
|
+
|
|
121
121
|
return self._geometry
|
|
122
122
|
|
|
123
123
|
@geometry.setter
|
|
@@ -160,7 +160,7 @@ class Distribution(Density, ABC):
|
|
|
160
160
|
f"{self.logd.__qualname__}: To evaluate the log density all conditioning variables and main"
|
|
161
161
|
f" parameter must be specified. Conditioning variables are: {cond_vars}"
|
|
162
162
|
)
|
|
163
|
-
|
|
163
|
+
|
|
164
164
|
# Check if all conditioning variables are specified
|
|
165
165
|
all_cond_vars_specified = all([key in kwargs for key in cond_vars])
|
|
166
166
|
if not all_cond_vars_specified:
|
|
@@ -168,7 +168,7 @@ class Distribution(Density, ABC):
|
|
|
168
168
|
f"{self.logd.__qualname__}: To evaluate the log density all conditioning variables must be"
|
|
169
169
|
f" specified. Conditioning variables are: {cond_vars}"
|
|
170
170
|
)
|
|
171
|
-
|
|
171
|
+
|
|
172
172
|
# Extract exactly the conditioning variables from kwargs
|
|
173
173
|
cond_kwargs = {key: kwargs[key] for key in cond_vars}
|
|
174
174
|
|
|
@@ -186,7 +186,7 @@ class Distribution(Density, ABC):
|
|
|
186
186
|
# Not conditional distribution, simply evaluate log density directly
|
|
187
187
|
else:
|
|
188
188
|
return super().logd(*args, **kwargs)
|
|
189
|
-
|
|
189
|
+
|
|
190
190
|
def _logd(self, *args):
|
|
191
191
|
return self.logpdf(*args) # Currently all distributions implement logpdf so we simply call this method.
|
|
192
192
|
|
|
@@ -216,7 +216,7 @@ class Distribution(Density, ABC):
|
|
|
216
216
|
# Get samples from the distribution sample method
|
|
217
217
|
s = self._sample(N,*args,**kwargs)
|
|
218
218
|
|
|
219
|
-
#Store samples in cuqi samples object if more than 1 sample
|
|
219
|
+
# Store samples in cuqi samples object if more than 1 sample
|
|
220
220
|
if N==1:
|
|
221
221
|
if len(s) == 1 and isinstance(s,np.ndarray): #Extract single value from numpy array
|
|
222
222
|
s = s.ravel()[0]
|
|
@@ -264,7 +264,7 @@ class Distribution(Density, ABC):
|
|
|
264
264
|
# Go through every mutable variable and assign value from kwargs if present
|
|
265
265
|
for var_key in mutable_vars:
|
|
266
266
|
|
|
267
|
-
#If keyword directly specifies new value of variable we simply reassign
|
|
267
|
+
# If keyword directly specifies new value of variable we simply reassign
|
|
268
268
|
if var_key in kwargs:
|
|
269
269
|
setattr(new_dist, var_key, kwargs.get(var_key))
|
|
270
270
|
processed_kwargs.add(var_key)
|
|
@@ -291,9 +291,18 @@ class Distribution(Density, ABC):
|
|
|
291
291
|
|
|
292
292
|
elif len(var_args)>0: #Some keywords found
|
|
293
293
|
# Define new partial function with partially defined args
|
|
294
|
-
|
|
294
|
+
if (
|
|
295
|
+
hasattr(var_val, "_supports_partial_eval")
|
|
296
|
+
and var_val._supports_partial_eval
|
|
297
|
+
):
|
|
298
|
+
func = var_val(**var_args)
|
|
299
|
+
else:
|
|
300
|
+
# If the callable does not support partial evaluation,
|
|
301
|
+
# we use the partial function to set the variable
|
|
302
|
+
func = partial(var_val, **var_args)
|
|
303
|
+
|
|
295
304
|
setattr(new_dist, var_key, func)
|
|
296
|
-
|
|
305
|
+
|
|
297
306
|
# Store processed keywords
|
|
298
307
|
processed_kwargs.update(var_args.keys())
|
|
299
308
|
|
|
@@ -329,7 +338,7 @@ class Distribution(Density, ABC):
|
|
|
329
338
|
|
|
330
339
|
def get_conditioning_variables(self):
|
|
331
340
|
"""Return the conditioning variables of this distribution (if any)."""
|
|
332
|
-
|
|
341
|
+
|
|
333
342
|
# Get all mutable variables
|
|
334
343
|
mutable_vars = self.get_mutable_variables()
|
|
335
344
|
|
|
@@ -338,7 +347,7 @@ class Distribution(Density, ABC):
|
|
|
338
347
|
|
|
339
348
|
# Add any variables defined through callable functions
|
|
340
349
|
cond_vars += get_indirect_variables(self)
|
|
341
|
-
|
|
350
|
+
|
|
342
351
|
return cond_vars
|
|
343
352
|
|
|
344
353
|
def get_mutable_variables(self):
|
|
@@ -347,10 +356,10 @@ class Distribution(Density, ABC):
|
|
|
347
356
|
# If mutable variables are already cached, return them
|
|
348
357
|
if hasattr(self, '_mutable_vars'):
|
|
349
358
|
return self._mutable_vars
|
|
350
|
-
|
|
359
|
+
|
|
351
360
|
# Define list of ignored attributes and properties
|
|
352
361
|
ignore_vars = ['name', 'is_symmetric', 'geometry', 'dim']
|
|
353
|
-
|
|
362
|
+
|
|
354
363
|
# Get public attributes
|
|
355
364
|
attributes = get_writeable_attributes(self)
|
|
356
365
|
|
|
@@ -396,7 +405,7 @@ class Distribution(Density, ABC):
|
|
|
396
405
|
raise ValueError(f"{self._condition.__qualname__}: {ordered_keys[index]} passed as both argument and keyword argument.\nArguments follow the listed conditioning variable order: {self.get_conditioning_variables()}")
|
|
397
406
|
kwargs[ordered_keys[index]] = arg
|
|
398
407
|
return kwargs
|
|
399
|
-
|
|
408
|
+
|
|
400
409
|
def _check_geometry_consistency(self):
|
|
401
410
|
""" Checks that the geometry of the distribution is consistent by calling the geometry property. Should be called at the end of __init__ of subclasses. """
|
|
402
411
|
self.geometry
|
|
@@ -411,4 +420,4 @@ class Distribution(Density, ABC):
|
|
|
411
420
|
def rv(self):
|
|
412
421
|
""" Return a random variable object representing the distribution. """
|
|
413
422
|
from cuqi.experimental.algebra import RandomVariable
|
|
414
|
-
return RandomVariable(self)
|
|
423
|
+
return RandomVariable(self)
|
|
@@ -212,4 +212,4 @@ class UserDefinedLikelihood(object):
|
|
|
212
212
|
return get_non_default_args(self.logpdf_func)
|
|
213
213
|
|
|
214
214
|
def __repr__(self) -> str:
|
|
215
|
-
return "CUQI {} function. Parameters {}.".format(self.__class__.__name__,self.get_parameter_names())
|
|
215
|
+
return "CUQI {} function. Parameters {}.".format(self.__class__.__name__,self.get_parameter_names())
|
|
@@ -132,6 +132,10 @@ class Model(object):
|
|
|
132
132
|
print(model(1, 1))
|
|
133
133
|
print(model.gradient(np.array([1]), 1, 1))
|
|
134
134
|
"""
|
|
135
|
+
|
|
136
|
+
_supports_partial_eval = True
|
|
137
|
+
"""Flag indicating that partial evaluation of Model objects is supported, i.e., calling the model object with only some of the inputs specified returns a model that can be called with the remaining inputs."""
|
|
138
|
+
|
|
135
139
|
def __init__(self, forward, range_geometry, domain_geometry, gradient=None, jacobian=None):
|
|
136
140
|
|
|
137
141
|
# Check if input is callable
|
|
@@ -311,7 +315,12 @@ class Model(object):
|
|
|
311
315
|
"Gradient needs to be callable function or tuple of callable functions."
|
|
312
316
|
)
|
|
313
317
|
|
|
314
|
-
expected_func_non_default_args =
|
|
318
|
+
expected_func_non_default_args = (
|
|
319
|
+
self._non_default_args
|
|
320
|
+
if not hasattr(self, "_original_non_default_args")
|
|
321
|
+
else self._original_non_default_args
|
|
322
|
+
)
|
|
323
|
+
|
|
315
324
|
if func_type.lower() == "gradient":
|
|
316
325
|
# prepend 'direction' to the expected gradient non default args
|
|
317
326
|
expected_func_non_default_args = [
|
|
@@ -613,52 +622,43 @@ class Model(object):
|
|
|
613
622
|
if non_default_args is None:
|
|
614
623
|
non_default_args = self._non_default_args
|
|
615
624
|
|
|
616
|
-
#
|
|
625
|
+
# Either args or kwargs can be provided but not both
|
|
626
|
+
if len(args) > 0 and len(kwargs) > 0:
|
|
627
|
+
raise ValueError(
|
|
628
|
+
"The "
|
|
629
|
+
+ map_name.lower()
|
|
630
|
+
+ " input is specified both as positional and keyword arguments. This is not supported."
|
|
631
|
+
)
|
|
632
|
+
|
|
633
|
+
len_input = len(args) + len(kwargs)
|
|
634
|
+
|
|
635
|
+
# If partial evaluation, make sure input is not of type Samples
|
|
636
|
+
if len_input < len(non_default_args):
|
|
637
|
+
# If the argument is a Sample object, splitting or partial
|
|
638
|
+
# evaluation of the model is not supported
|
|
639
|
+
temp_args = args if len(args) > 0 else list(kwargs.values())
|
|
640
|
+
if any(isinstance(arg, Samples) for arg in temp_args):
|
|
641
|
+
raise ValueError(("When using Samples objects as input, the"
|
|
642
|
+
+" user should provide a Samples object for"
|
|
643
|
+
+f" each non_default_args {non_default_args}"
|
|
644
|
+
+" of the model. That is, partial evaluation"
|
|
645
|
+
+" or splitting is not supported for input"
|
|
646
|
+
+" of type Samples."))
|
|
647
|
+
|
|
648
|
+
# If args are given, add them to kwargs
|
|
617
649
|
if len(args) > 0:
|
|
618
|
-
if len(kwargs) > 0:
|
|
619
|
-
raise ValueError(
|
|
620
|
-
"The "
|
|
621
|
-
+ map_name.lower()
|
|
622
|
-
+ " input is specified both as positional and keyword arguments. This is not supported."
|
|
623
|
-
)
|
|
624
650
|
|
|
625
|
-
appending_error_message = ""
|
|
626
651
|
# Check if the input is for multiple input case and is stacked,
|
|
627
652
|
# then split it
|
|
628
|
-
if len(args)
|
|
629
|
-
|
|
630
|
-
if isinstance(args[0], Samples):
|
|
631
|
-
raise ValueError(
|
|
632
|
-
"The "
|
|
633
|
-
+ map_name.lower()
|
|
634
|
-
+ f" input is specified by a Samples object that cannot be split into multiple arguments corresponding to the non_default_args {non_default_args}."
|
|
635
|
-
)
|
|
636
|
-
split_succeeded, split_args = self._is_stacked_args(*args, is_par=is_par)
|
|
637
|
-
if split_succeeded:
|
|
638
|
-
args = split_args
|
|
639
|
-
else:
|
|
640
|
-
appending_error_message = (
|
|
641
|
-
" Additionally, the "
|
|
642
|
-
+ map_name.lower()
|
|
643
|
-
+ f" input is specified by a single argument that cannot be split into multiple arguments matching the expected non_default_args {non_default_args}."
|
|
644
|
-
)
|
|
645
|
-
|
|
646
|
-
# Check if the number of args does not match the number of
|
|
647
|
-
# non_default_args of the model
|
|
648
|
-
if len(args) != len(non_default_args):
|
|
649
|
-
raise ValueError(
|
|
650
|
-
"The number of positional arguments does not match the number of non-default arguments of the "
|
|
651
|
-
+ map_name.lower()
|
|
652
|
-
+ "."
|
|
653
|
-
+ appending_error_message
|
|
654
|
-
)
|
|
653
|
+
if len(args) < len(non_default_args):
|
|
654
|
+
args = self._split_in_case_of_stacked_args(*args, is_par=is_par)
|
|
655
655
|
|
|
656
656
|
# Add args to kwargs following the order of non_default_args
|
|
657
657
|
for idx, arg in enumerate(args):
|
|
658
658
|
kwargs[non_default_args[idx]] = arg
|
|
659
|
-
|
|
659
|
+
|
|
660
660
|
# Check kwargs matches non_default_args
|
|
661
|
-
if set(list(kwargs.keys()))
|
|
661
|
+
if not (set(list(kwargs.keys())) <= set(non_default_args)):
|
|
662
662
|
if map_name == "gradient":
|
|
663
663
|
error_msg = f"The gradient input is specified by a direction and keywords arguments {list(kwargs.keys())} that does not match the non_default_args of the model {non_default_args}."
|
|
664
664
|
else:
|
|
@@ -673,53 +673,41 @@ class Model(object):
|
|
|
673
673
|
raise ValueError(error_msg)
|
|
674
674
|
|
|
675
675
|
# Make sure order of kwargs is the same as non_default_args
|
|
676
|
-
kwargs = {k: kwargs[k] for k in non_default_args}
|
|
676
|
+
kwargs = {k: kwargs[k] for k in non_default_args if k in kwargs}
|
|
677
677
|
|
|
678
678
|
return kwargs
|
|
679
679
|
|
|
680
|
-
def
|
|
681
|
-
"""Private function that checks if the input
|
|
682
|
-
and splits
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
if len(args) > 1:
|
|
686
|
-
return False, args
|
|
687
|
-
|
|
688
|
-
# Type of args should be parameter
|
|
689
|
-
if not is_par:
|
|
690
|
-
return False, args
|
|
680
|
+
def _split_in_case_of_stacked_args(self, *args, is_par=True):
|
|
681
|
+
"""Private function that checks if the input args is a stacked
|
|
682
|
+
CUQIarray or numpy array and splits it into multiple arguments based on
|
|
683
|
+
the domain geometry of the model. Otherwise, it returns the input args
|
|
684
|
+
unchanged."""
|
|
691
685
|
|
|
692
|
-
#
|
|
686
|
+
# Check conditions for splitting and split if all conditions are met
|
|
693
687
|
is_CUQIarray = isinstance(args[0], CUQIarray)
|
|
694
688
|
is_numpy_array = isinstance(args[0], np.ndarray)
|
|
695
|
-
if not is_CUQIarray and not is_numpy_array:
|
|
696
|
-
return False, args
|
|
697
689
|
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
if is_CUQIarray:
|
|
713
|
-
split_args = [
|
|
714
|
-
CUQIarray(arg, is_par=True, geometry=self.domain_geometry.geometries[i])
|
|
715
|
-
for i, arg in enumerate(split_args)
|
|
716
|
-
]
|
|
690
|
+
if ((is_CUQIarray or is_numpy_array) and
|
|
691
|
+
is_par and
|
|
692
|
+
len(args) == 1 and
|
|
693
|
+
args[0].shape == (self.domain_dim,) and
|
|
694
|
+
isinstance(self.domain_geometry, cuqi.experimental.geometry._ProductGeometry)):
|
|
695
|
+
# Split the stacked input
|
|
696
|
+
split_args = np.split(args[0], self.domain_geometry.stacked_par_split_indices)
|
|
697
|
+
# Convert split args to CUQIarray if input is CUQIarray
|
|
698
|
+
if is_CUQIarray:
|
|
699
|
+
split_args = [
|
|
700
|
+
CUQIarray(arg, is_par=True, geometry=self.domain_geometry.geometries[i])
|
|
701
|
+
for i, arg in enumerate(split_args)
|
|
702
|
+
]
|
|
703
|
+
return split_args
|
|
717
704
|
|
|
718
|
-
|
|
705
|
+
else:
|
|
706
|
+
return args
|
|
719
707
|
|
|
720
708
|
def forward(self, *args, is_par=True, **kwargs):
|
|
721
709
|
""" Forward function of the model.
|
|
722
|
-
|
|
710
|
+
|
|
723
711
|
Forward converts the input to function values (if needed) using the domain geometry of the model. Then it applies the forward operator to the function values and converts the output to parameters using the range geometry of the model.
|
|
724
712
|
|
|
725
713
|
Parameters
|
|
@@ -733,7 +721,7 @@ class Model(object):
|
|
|
733
721
|
If True, the inputs in `args` or `kwargs` are assumed to be parameters.
|
|
734
722
|
If False, the inputs in `args` or `kwargs` are assumed to be function values.
|
|
735
723
|
If `is_par` is a tuple of bools, the inputs are assumed to be parameters or function values based on the corresponding boolean value in the tuple.
|
|
736
|
-
|
|
724
|
+
|
|
737
725
|
**kwargs : keyword arguments
|
|
738
726
|
keyword arguments for the forward operator. The forward operator input can be specified as either positional arguments or keyword arguments but not both.
|
|
739
727
|
|
|
@@ -750,19 +738,31 @@ class Model(object):
|
|
|
750
738
|
kwargs = self._parse_args_add_to_kwargs(
|
|
751
739
|
*args, **kwargs, is_par=is_par, map_name="model"
|
|
752
740
|
)
|
|
753
|
-
|
|
754
|
-
# extract args from kwargs
|
|
741
|
+
# Extract args from kwargs
|
|
755
742
|
args = list(kwargs.values())
|
|
756
743
|
|
|
744
|
+
if len(kwargs) == 0:
|
|
745
|
+
return self
|
|
746
|
+
|
|
747
|
+
partial_arguments = len(kwargs) < len(self._non_default_args)
|
|
748
|
+
|
|
757
749
|
# If input is a distribution, we simply change the parameter name of
|
|
758
750
|
# model to match the distribution name
|
|
759
751
|
if all(isinstance(x, cuqi.distribution.Distribution)
|
|
760
752
|
for x in kwargs.values()):
|
|
753
|
+
if partial_arguments:
|
|
754
|
+
raise ValueError(
|
|
755
|
+
"Partial evaluation of the model is not supported for distributions."
|
|
756
|
+
)
|
|
761
757
|
return self._handle_case_when_model_input_is_distributions(kwargs)
|
|
762
758
|
|
|
763
759
|
# If input is a random variable, we handle it separately
|
|
764
760
|
elif all(isinstance(x, cuqi.experimental.algebra.RandomVariable)
|
|
765
761
|
for x in kwargs.values()):
|
|
762
|
+
if partial_arguments:
|
|
763
|
+
raise ValueError(
|
|
764
|
+
"Partial evaluation of the model is not supported for random variables."
|
|
765
|
+
)
|
|
766
766
|
return self._handle_case_when_model_input_is_random_variables(kwargs)
|
|
767
767
|
|
|
768
768
|
# If input is a Node from internal abstract syntax tree, we let the Node handle the operation
|
|
@@ -772,6 +772,21 @@ class Model(object):
|
|
|
772
772
|
elif any(isinstance(args_i, cuqi.experimental.algebra.Node) for args_i in args):
|
|
773
773
|
return NotImplemented
|
|
774
774
|
|
|
775
|
+
# if input is partial, we create a new model with the partial input
|
|
776
|
+
if partial_arguments:
|
|
777
|
+
# Create is_par_partial from the is_par to contain only the relevant parts
|
|
778
|
+
if isinstance(is_par, (list, tuple)):
|
|
779
|
+
is_par_partial = tuple(
|
|
780
|
+
is_par[i]
|
|
781
|
+
for i in range(self.number_of_inputs)
|
|
782
|
+
if self._non_default_args[i] in kwargs.keys()
|
|
783
|
+
)
|
|
784
|
+
else:
|
|
785
|
+
is_par_partial = is_par
|
|
786
|
+
# Build a partial model with the given kwargs
|
|
787
|
+
partial_model = self._build_partial_model(kwargs, is_par_partial)
|
|
788
|
+
return partial_model
|
|
789
|
+
|
|
775
790
|
# Else we apply the forward operator
|
|
776
791
|
# if model has _original_non_default_args, we use it to replace the
|
|
777
792
|
# kwargs keys so that it matches self._forward_func signature
|
|
@@ -797,6 +812,126 @@ class Model(object):
|
|
|
797
812
|
else:
|
|
798
813
|
return False
|
|
799
814
|
|
|
815
|
+
def _build_partial_model(self, kwargs, is_par):
    """Private function that builds a partial model substituting the given
    keyword arguments with their values. The created partial model will have
    as inputs the non-default arguments that are not in the kwargs.

    Parameters
    ----------
    kwargs : dict
        The subset of model inputs to fix, keyed by names from
        ``self._non_default_args``.
    is_par : bool or tuple of bool
        Forwarded to ``self._2fun`` when converting the fixed inputs to
        function values (whether the given values are parameters).

    Returns
    -------
    Model
        A new model whose forward (and, if available, gradient) functions
        have the given inputs bound, and whose inputs are the remaining
        non-default arguments.

    Raises
    ------
    NotImplementedError
        If this model's gradient is a single callable instead of a tuple
        of callables.
    """
    has_original_args = hasattr(self, "_original_non_default_args")

    # Names of the inputs as expected by the signature of self._forward_func.
    # When present, _original_non_default_args holds those names, while
    # _non_default_args holds the names exposed by this model.
    original_non_default_args = (
        self._original_non_default_args
        if has_original_args
        else self._non_default_args
    )

    # For every input: True if it is fixed (substituted) by kwargs.
    is_substituted = [
        self._non_default_args[i] in kwargs
        for i in range(self.number_of_inputs)
    ]
    remaining_idx = [i for i, sub in enumerate(is_substituted) if not sub]
    substituted_idx = [i for i, sub in enumerate(is_substituted) if sub]

    if has_original_args:
        # Rename the kwargs keys to the original names so that they match
        # the signature of self._forward_func. The mapping is built per
        # input index (not by zipping with kwargs insertion order) so that
        # each value is bound to the correct argument whatever order the
        # kwargs were given in.
        kwargs = {
            original_non_default_args[i]: kwargs[self._non_default_args[i]]
            for i in substituted_idx
        }

    def _domain_geometry_of(indices):
        # Product of the domain geometries of the given inputs; a single
        # geometry is returned unwrapped.
        product = cuqi.experimental.geometry._ProductGeometry(
            *[self.domain_geometry.geometries[i] for i in indices]
        )
        if len(product.geometries) == 1:
            return product.geometries[0]
        return product

    # Geometry of the remaining (unspecified) inputs -> domain of the new model
    partial_domain_geometry = _domain_geometry_of(remaining_idx)
    # Geometry of the specified inputs -> used to convert them to functions
    substituted_domain_geometry = _domain_geometry_of(substituted_idx)

    # First, we convert the specified inputs to function values
    kwargs = self._2fun(geometry=substituted_domain_geometry, is_par=is_par, **kwargs)

    # Second, we create a partial function for the forward operator
    partial_forward = partial(self._forward_func, **kwargs)

    # Third, if applicable, we create a partial function for the gradient
    if isinstance(self._gradient_func, tuple):
        # One partial gradient per remaining input (None entries are kept)
        partial_gradient = tuple(
            partial(self._gradient_func[i], **kwargs)
            if self._gradient_func[i] is not None
            else None
            for i in remaining_idx
        )
        if len(partial_gradient) == 1:
            partial_gradient = partial_gradient[0]

    elif callable(self._gradient_func):
        raise NotImplementedError(
            "Partial forward model is only supported for gradient/jacobian functions that are tuples of callable functions."
        )

    else:
        partial_gradient = None

    # Lastly, we create the partial model with the partial forward
    # operator (we set the gradient function later)
    partial_model = Model(
        forward=partial_forward,
        range_geometry=self.range_geometry,
        domain_geometry=partial_domain_geometry,
    )

    # Set the _original_non_default_args (if applicable) and
    # _stored_non_default_args of the partial model
    if has_original_args:
        partial_model._original_non_default_args = [
            original_non_default_args[i] for i in remaining_idx
        ]
    partial_model._stored_non_default_args = [
        self._non_default_args[i] for i in remaining_idx
    ]

    # Set the gradient function of the partial model
    partial_model._check_correct_gradient_jacobian_form(
        partial_gradient, "gradient"
    )
    partial_model._gradient_func = partial_gradient

    return partial_model
|
|
934
|
+
|
|
800
935
|
def _handle_case_when_model_input_is_distributions(self, kwargs):
|
|
801
936
|
"""Private function that handles the case of the input being a
|
|
802
937
|
distribution or multiple distributions."""
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import cuqi
|
|
2
2
|
import numpy as np
|
|
3
3
|
import pytest
|
|
4
|
+
from .test_model import MultipleInputTestModel
|
|
4
5
|
|
|
5
6
|
def test_joint_dist_dim_geometry():
|
|
6
7
|
""" Test the dimension and geometry properties of a joint distribution """
|
|
@@ -393,3 +394,94 @@ def test_logd_consistency_when_conditioning(joint, variables):
|
|
|
393
394
|
|
|
394
395
|
# Add current variable to the variables that need to be conditioned
|
|
395
396
|
cond_vars[key] = value
|
|
397
|
+
|
|
398
|
+
def test_joint_distribution_with_multiple_inputs_model_has_correct_parameter_names():
    """Check the parameter names reported by joint distributions built on a
    three-input model, and by the likelihoods inside them, after
    conditioning on every subset of the variables."""

    test_model = MultipleInputTestModel.helper_build_three_input_test_model()
    model = cuqi.model.Model(
        test_model.forward_map,
        gradient=test_model.gradient_form2,
        domain_geometry=test_model.domain_geometry,
        range_geometry=test_model.range_geometry,
    )

    # Priors for the three model inputs. The variable names are significant:
    # they become the parameter names checked below.
    x_dist = cuqi.distribution.Gaussian(mean=np.zeros(3), cov=np.eye(3))
    y_dist = cuqi.distribution.Gaussian(mean=np.zeros(2), cov=np.eye(2))
    z_dist = cuqi.distribution.Gaussian(mean=np.zeros(3), cov=np.eye(3))

    # Data distribution whose mean is the model applied to the three priors
    data_dist = cuqi.distribution.Gaussian(mean=model(x_dist, y_dist, z_dist), cov=1.0)

    # Observed data and fixed values for the inputs
    data_val = np.array([2, 2, 3])
    likelihood = data_dist(data_dist=data_val)

    x_val = np.array([1, 2, 3])
    y_val = np.array([4, 5])
    z_val = np.array([6, 7, 8])

    # --- Joint distribution built from a likelihood ------------------------
    posterior = cuqi.distribution.JointDistribution(likelihood, x_dist, y_dist, z_dist)

    posterior_cases = [
        ({}, ['x_dist', 'y_dist', 'z_dist']),
        ({'x_dist': x_val}, ['y_dist', 'z_dist']),
        ({'y_dist': y_val}, ['x_dist', 'z_dist']),
        ({'z_dist': z_val}, ['x_dist', 'y_dist']),
        ({'y_dist': y_val, 'z_dist': z_val}, ['x_dist']),
        ({'x_dist': x_val, 'z_dist': z_val}, ['y_dist']),
        ({'x_dist': x_val, 'y_dist': y_val}, ['z_dist']),
        ({'x_dist': x_val, 'y_dist': y_val, 'z_dist': z_val}, []),
    ]
    for conditioning, expected in posterior_cases:
        conditioned = posterior(**conditioning) if conditioning else posterior
        assert conditioned.get_parameter_names() == expected, conditioning

    # --- Joint distribution built from the data distribution ---------------
    joint_dist = cuqi.distribution.JointDistribution(data_dist, x_dist, y_dist, z_dist)

    joint_cases = [
        ({}, ['data_dist', 'x_dist', 'y_dist', 'z_dist']),
        ({'x_dist': x_val}, ['data_dist', 'y_dist', 'z_dist']),
        ({'y_dist': y_val}, ['data_dist', 'x_dist', 'z_dist']),
        ({'z_dist': z_val}, ['data_dist', 'x_dist', 'y_dist']),
        ({'data_dist': data_val}, ['x_dist', 'y_dist', 'z_dist']),
        ({'x_dist': x_val, 'data_dist': data_val}, ['y_dist', 'z_dist']),
        ({'y_dist': y_val, 'data_dist': data_val}, ['x_dist', 'z_dist']),
        ({'z_dist': z_val, 'data_dist': data_val}, ['x_dist', 'y_dist']),
        ({'x_dist': x_val, 'y_dist': y_val}, ['data_dist', 'z_dist']),
        ({'x_dist': x_val, 'z_dist': z_val}, ['data_dist', 'y_dist']),
        ({'y_dist': y_val, 'z_dist': z_val}, ['data_dist', 'x_dist']),
        ({'x_dist': x_val, 'y_dist': y_val, 'z_dist': z_val}, ['data_dist']),
        ({'x_dist': x_val, 'y_dist': y_val, 'data_dist': data_val}, ['z_dist']),
        ({'x_dist': x_val, 'z_dist': z_val, 'data_dist': data_val}, ['y_dist']),
        ({'y_dist': y_val, 'z_dist': z_val, 'data_dist': data_val}, ['x_dist']),
    ]
    for conditioning, expected in joint_cases:
        conditioned = joint_dist(**conditioning) if conditioning else joint_dist
        assert conditioned.get_parameter_names() == expected, conditioning

    # --- Likelihood held by the conditioned joint distribution -------------
    # With several parameters left, the result is still a joint distribution
    # (likelihood accessed via _likelihoods[0]) ...
    multi_param_cases = [
        ({'data_dist': data_val}, ['x_dist', 'y_dist', 'z_dist']),
        ({'x_dist': x_val, 'data_dist': data_val}, ['y_dist', 'z_dist']),
        ({'y_dist': y_val, 'data_dist': data_val}, ['x_dist', 'z_dist']),
        ({'z_dist': z_val, 'data_dist': data_val}, ['x_dist', 'y_dist']),
    ]
    for conditioning, expected in multi_param_cases:
        lik = joint_dist(**conditioning)._likelihoods[0]
        assert lik.get_parameter_names() == expected, conditioning

    # ... while with a single parameter left it exposes `.likelihood`.
    single_param_cases = [
        ({'x_dist': x_val, 'y_dist': y_val, 'data_dist': data_val}, ['z_dist']),
        ({'x_dist': x_val, 'z_dist': z_val, 'data_dist': data_val}, ['y_dist']),
        ({'y_dist': y_val, 'z_dist': z_val, 'data_dist': data_val}, ['x_dist']),
    ]
    for conditioning, expected in single_param_cases:
        lik = joint_dist(**conditioning).likelihood
        assert lik.get_parameter_names() == expected, conditioning
|