CUQIpy 1.0.0.post0.dev84__tar.gz → 1.0.0.post0.dev127__tar.gz

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.

Potentially problematic release.


This version of CUQIpy may be problematic; see the registry's release page for details.

Files changed (108)
  1. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/CUQIpy.egg-info/PKG-INFO +2 -2
  2. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/CUQIpy.egg-info/SOURCES.txt +1 -0
  3. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/CUQIpy.egg-info/requires.txt +1 -1
  4. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/PKG-INFO +2 -2
  5. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/_version.py +3 -3
  6. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/experimental/mcmc/__init__.py +1 -0
  7. CUQIpy-1.0.0.post0.dev127/cuqi/experimental/mcmc/_rto.py +275 -0
  8. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/experimental/mcmc/_sampler.py +5 -0
  9. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/samples/_samples.py +10 -3
  10. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/requirements.txt +1 -1
  11. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/CUQIpy.egg-info/dependency_links.txt +0 -0
  12. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/CUQIpy.egg-info/top_level.txt +0 -0
  13. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/LICENSE +0 -0
  14. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/README.md +0 -0
  15. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/__init__.py +0 -0
  16. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/_messages.py +0 -0
  17. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/array/__init__.py +0 -0
  18. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/array/_array.py +0 -0
  19. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/config.py +0 -0
  20. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/data/__init__.py +0 -0
  21. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/data/_data.py +0 -0
  22. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/data/astronaut.npz +0 -0
  23. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/data/camera.npz +0 -0
  24. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/data/cat.npz +0 -0
  25. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/data/cookie.png +0 -0
  26. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/data/satellite.mat +0 -0
  27. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/density/__init__.py +0 -0
  28. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/density/_density.py +0 -0
  29. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/diagnostics.py +0 -0
  30. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/distribution/__init__.py +0 -0
  31. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/distribution/_beta.py +0 -0
  32. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/distribution/_cauchy.py +0 -0
  33. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/distribution/_cmrf.py +0 -0
  34. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/distribution/_custom.py +0 -0
  35. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/distribution/_distribution.py +0 -0
  36. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/distribution/_gamma.py +0 -0
  37. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/distribution/_gaussian.py +0 -0
  38. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/distribution/_gmrf.py +0 -0
  39. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/distribution/_inverse_gamma.py +0 -0
  40. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/distribution/_joint_distribution.py +0 -0
  41. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/distribution/_laplace.py +0 -0
  42. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/distribution/_lmrf.py +0 -0
  43. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/distribution/_lognormal.py +0 -0
  44. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/distribution/_normal.py +0 -0
  45. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/distribution/_posterior.py +0 -0
  46. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/distribution/_uniform.py +0 -0
  47. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/experimental/__init__.py +0 -0
  48. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/experimental/mcmc/_cwmh.py +0 -0
  49. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/experimental/mcmc/_langevin_algorithm.py +0 -0
  50. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/experimental/mcmc/_mh.py +0 -0
  51. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/experimental/mcmc/_pcn.py +0 -0
  52. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/geometry/__init__.py +0 -0
  53. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/geometry/_geometry.py +0 -0
  54. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/implicitprior/__init__.py +0 -0
  55. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/implicitprior/_regularizedGMRF.py +0 -0
  56. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/implicitprior/_regularizedGaussian.py +0 -0
  57. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/likelihood/__init__.py +0 -0
  58. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/likelihood/_likelihood.py +0 -0
  59. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/model/__init__.py +0 -0
  60. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/model/_model.py +0 -0
  61. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/operator/__init__.py +0 -0
  62. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/operator/_operator.py +0 -0
  63. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/pde/__init__.py +0 -0
  64. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/pde/_pde.py +0 -0
  65. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/problem/__init__.py +0 -0
  66. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/problem/_problem.py +0 -0
  67. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/sampler/__init__.py +0 -0
  68. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/sampler/_conjugate.py +0 -0
  69. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/sampler/_conjugate_approx.py +0 -0
  70. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/sampler/_cwmh.py +0 -0
  71. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/sampler/_gibbs.py +0 -0
  72. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/sampler/_hmc.py +0 -0
  73. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/sampler/_langevin_algorithm.py +0 -0
  74. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/sampler/_laplace_approximation.py +0 -0
  75. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/sampler/_mh.py +0 -0
  76. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/sampler/_pcn.py +0 -0
  77. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/sampler/_rto.py +0 -0
  78. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/sampler/_sampler.py +0 -0
  79. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/samples/__init__.py +0 -0
  80. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/solver/__init__.py +0 -0
  81. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/solver/_solver.py +0 -0
  82. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/testproblem/__init__.py +0 -0
  83. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/testproblem/_testproblem.py +0 -0
  84. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/utilities/__init__.py +0 -0
  85. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/utilities/_get_python_variable_name.py +0 -0
  86. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/cuqi/utilities/_utilities.py +0 -0
  87. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/pyproject.toml +0 -0
  88. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/setup.cfg +0 -0
  89. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/setup.py +0 -0
  90. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_MRFs.py +0 -0
  91. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_abstract_distribution_density.py +0 -0
  92. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_bayesian_inversion.py +0 -0
  93. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_density.py +0 -0
  94. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_distribution.py +0 -0
  95. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_distributions_shape.py +0 -0
  96. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_geometry.py +0 -0
  97. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_implicit_priors.py +0 -0
  98. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_joint_distribution.py +0 -0
  99. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_likelihood.py +0 -0
  100. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_model.py +0 -0
  101. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_pde.py +0 -0
  102. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_posterior.py +0 -0
  103. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_problem.py +0 -0
  104. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_sampler.py +0 -0
  105. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_samples.py +0 -0
  106. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_solver.py +0 -0
  107. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_testproblem.py +0 -0
  108. {CUQIpy-1.0.0.post0.dev84 → CUQIpy-1.0.0.post0.dev127}/tests/test_utilities.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: CUQIpy
3
- Version: 1.0.0.post0.dev84
3
+ Version: 1.0.0.post0.dev127
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -202,7 +202,7 @@ Description-Content-Type: text/markdown
202
202
  License-File: LICENSE
203
203
  Requires-Dist: matplotlib
204
204
  Requires-Dist: numpy>=1.17.0
205
- Requires-Dist: scipy
205
+ Requires-Dist: scipy<1.13
206
206
  Requires-Dist: arviz
207
207
 
208
208
  <div align="center">
@@ -47,6 +47,7 @@ cuqi/experimental/mcmc/_cwmh.py
47
47
  cuqi/experimental/mcmc/_langevin_algorithm.py
48
48
  cuqi/experimental/mcmc/_mh.py
49
49
  cuqi/experimental/mcmc/_pcn.py
50
+ cuqi/experimental/mcmc/_rto.py
50
51
  cuqi/experimental/mcmc/_sampler.py
51
52
  cuqi/geometry/__init__.py
52
53
  cuqi/geometry/_geometry.py
@@ -1,4 +1,4 @@
1
1
  matplotlib
2
2
  numpy>=1.17.0
3
- scipy
3
+ scipy<1.13
4
4
  arviz
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: CUQIpy
3
- Version: 1.0.0.post0.dev84
3
+ Version: 1.0.0.post0.dev127
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -202,7 +202,7 @@ Description-Content-Type: text/markdown
202
202
  License-File: LICENSE
203
203
  Requires-Dist: matplotlib
204
204
  Requires-Dist: numpy>=1.17.0
205
- Requires-Dist: scipy
205
+ Requires-Dist: scipy<1.13
206
206
  Requires-Dist: arviz
207
207
 
208
208
  <div align="center">
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2024-03-15T10:58:55+0100",
11
+ "date": "2024-04-05T11:26:24+0200",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "680cbe76e755e836b16b4ce98e7125d1d158fca2",
15
- "version": "1.0.0.post0.dev84"
14
+ "full-revisionid": "476cc08200034336b46144fac4c819f8298fa587",
15
+ "version": "1.0.0.post0.dev127"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -4,4 +4,5 @@ from ._sampler import SamplerNew, ProposalBasedSamplerNew
4
4
  from ._langevin_algorithm import ULANew, MALANew
5
5
  from ._mh import MHNew
6
6
  from ._pcn import pCNNew
7
+ from ._rto import LinearRTONew, RegularizedLinearRTONew
7
8
  from ._cwmh import CWMHNew
@@ -0,0 +1,275 @@
1
+ import scipy as sp
2
+ from scipy.linalg.interpolative import estimate_spectral_norm
3
+ from scipy.sparse.linalg import LinearOperator as scipyLinearOperator
4
+ import numpy as np
5
+ import cuqi
6
+ from cuqi.solver import CGLS, FISTA
7
+ from cuqi.experimental.mcmc import SamplerNew
8
+ from cuqi.array import CUQIarray
9
+
10
+
11
class LinearRTONew(SamplerNew):
    """
    Linear RTO (Randomize-Then-Optimize) sampler.

    Samples posterior related to the inverse problem with Gaussian likelihood and prior, and where the forward model is Linear.

    Each call to `step` adds standard normal noise to the stacked right-hand
    side ``b_tild`` and solves the resulting least-squares problem with CGLS.

    Parameters
    ------------
    target : `cuqi.distribution.Posterior`, `cuqi.distribution.MultipleLikelihoodPosterior` or 5-dimensional tuple.
        If target is of type cuqi.distribution.Posterior or cuqi.distribution.MultipleLikelihoodPosterior, it represents the posterior distribution.
        If target is a 5-dimensional tuple, it assumes the following structure:
        (data, model, L_sqrtprec, P_mean, P_sqrtprec)

        Here:
        data: is an m-dimensional numpy array containing the measured data.
        model: is an m by n matrix or LinearModel representing the forward model.
        L_sqrtprec: is the square root of the precision matrix of the Gaussian likelihood.
        P_mean: is the prior mean.
        P_sqrtprec: is the square root of the precision matrix of the Gaussian prior.

    initial_point : `np.ndarray`
        Initial point for the sampler. *Optional*. Defaults to a zero vector.

    maxit : int
        Maximum number of iterations of the inner CGLS solver. *Optional*.

    tol : float
        Tolerance of the inner CGLS solver. *Optional*.

    callback : callable, *Optional*
        If set this function will be called after every sample.
        The signature of the callback function is `callback(sample, sample_index)`,
        where `sample` is the current sample and `sample_index` is the index of the sample.
        An example is shown in demos/demo31_callback.py.

    """
    def __init__(self, target, initial_point=None, maxit=10, tol=1e-6, **kwargs):

        super().__init__(target=target, initial_point=initial_point, **kwargs)

        if initial_point is None: #TODO: Replace later with a getter
            self.initial_point = np.zeros(self.dim)

        self.current_point = self.initial_point
        self._acc = [1] # TODO. Check if we need this

        # Other parameters
        self.maxit = maxit
        self.tol = tol

    @property
    def prior(self):
        """ Prior of the target posterior. """
        return self.target.prior

    @property
    def likelihood(self):
        """ Likelihood of the target posterior. """
        return self.target.likelihood

    @property
    def likelihoods(self):
        """ List of likelihoods of the target (a single-element list for a Posterior). """
        if isinstance(self.target, cuqi.distribution.Posterior):
            return [self.target.likelihood]
        elif isinstance(self.target, cuqi.distribution.MultipleLikelihoodPosterior):
            return self.target.likelihoods

    @property
    def model(self):
        """ Forward model of the target posterior. """
        return self.target.model

    @property
    def data(self):
        """ Observed data of the target posterior. """
        return self.target.data

    @SamplerNew.target.setter
    def target(self, value):
        """ Set the target density. Runs validation of the target.

        A 5-tuple (data, model, L_sqrtprec, P_mean, P_sqrtprec) is first
        converted into a Posterior with a Gaussian likelihood and a Gaussian prior.
        """
        # Accept tuple of inputs and construct posterior
        if isinstance(value, tuple) and len(value) == 5:
            # Structure (data, model, L_sqrtprec, P_mean, P_sqrtprec)
            data, model, L_sqrtprec, P_mean, P_sqrtprec = value

            # If numpy matrix convert to CUQI model
            if isinstance(model, np.ndarray) and len(model.shape) == 2:
                model = cuqi.model.LinearModel(model)

            # Check model input
            if not isinstance(model, cuqi.model.LinearModel):
                raise TypeError("Model needs to be cuqi.model.LinearModel or matrix")

            # Likelihood
            L = cuqi.distribution.Gaussian(model, sqrtprec=L_sqrtprec).to_likelihood(data)

            # Prior TODO: allow multiple priors stacked
            #if isinstance(P_mean, list) and isinstance(P_sqrtprec, list):
            #    P = cuqi.distribution.JointGaussianSqrtPrec(P_mean, P_sqrtprec)
            #else:
            P = cuqi.distribution.Gaussian(P_mean, sqrtprec=P_sqrtprec)

            # Construct posterior
            value = cuqi.distribution.Posterior(L, P)
        super(LinearRTONew, type(self)).target.fset(self, value)
        self._precompute()

    def _precompute(self):
        """ Assemble the stacked system used by every RTO step.

        Sets ``self.n`` (prior dimension), ``self.b_tild`` (the stacked
        right-hand side ``[L_i @ data_i, ..., prior.sqrtprecTimesMean]``) and
        ``self.M``. ``self.M`` is either a stacked sparse matrix (when no
        likelihood model is callable) or a function ``M(x, flag)`` that applies
        the stacked operator (``flag == 1``) or its adjoint (``flag == 2``).

        Raises
        ------
        TypeError
            If some, but not all, likelihood models are callable.
        """
        L1 = [likelihood.distribution.sqrtprec for likelihood in self.likelihoods]
        L2 = self.prior.sqrtprec
        L2mu = self.prior.sqrtprecTimesMean

        # pre-computations
        self.n = self.prior.dim
        self.b_tild = np.hstack([L@likelihood.data for (L, likelihood) in zip(L1, self.likelihoods)] + [L2mu])

        callability = [callable(likelihood.model) for likelihood in self.likelihoods]
        if not any(callability):
            self.M = sp.sparse.vstack([L@likelihood.model for (L, likelihood) in zip(L1, self.likelihoods)] + [L2])
        elif all(callability):
            # in this case, model is a function doing forward and backward operations
            def M(x, flag):
                if flag == 1:
                    out1 = [L @ likelihood.model.forward(x) for (L, likelihood) in zip(L1, self.likelihoods)]
                    out2 = L2 @ x
                    return np.hstack(out1 + [out2])
                if flag == 2:
                    # Use the same precomputed sqrtprec matrices (L1) as the
                    # forward branch, so flag 2 is exactly the adjoint of flag 1.
                    idx_start = 0
                    idx_end = 0
                    out1 = np.zeros(self.n)
                    for (L, likelihood) in zip(L1, self.likelihoods):
                        idx_end += len(likelihood.data)
                        out1 += likelihood.model.adjoint(L.T@x[idx_start:idx_end])
                        idx_start = idx_end
                    out2 = L2.T @ x[idx_end:]
                    return out1 + out2
                # Previously fell through to an UnboundLocalError.
                raise ValueError(f"Unknown flag {flag}: expected 1 (forward) or 2 (adjoint).")
            self.M = M
        else:
            raise TypeError("All likelihoods need to be callable or none need to be callable.")

    def step(self):
        """ Draw one RTO sample.

        Perturbs ``b_tild`` with standard normal noise and solves the
        least-squares system with CGLS. The current point is passed to CGLS,
        presumably as the initial guess. Always returns acceptance 1.
        """
        y = self.b_tild + np.random.randn(len(self.b_tild))
        sim = CGLS(self.M, y, self.current_point, self.maxit, self.tol)
        self.current_point, _ = sim.solve()
        acc = 1
        return acc

    def tune(self, skip_len, update_count):
        """ No tuning is performed for this sampler. """
        pass

    def validate_target(self):
        """ Check the target is a (multiple-likelihood) posterior with linear model(s),
        likelihood distribution(s) exposing ``sqrtprec``, and a prior exposing
        ``sqrtprec`` and ``sqrtprecTimesMean``. """
        # Check target type
        if not isinstance(self.target, (cuqi.distribution.Posterior, cuqi.distribution.MultipleLikelihoodPosterior)):
            raise ValueError(f"To initialize an object of type {self.__class__}, 'target' need to be of type 'cuqi.distribution.Posterior' or 'cuqi.distribution.MultipleLikelihoodPosterior'.")

        # Check Linear model and Gaussian likelihood(s)
        if isinstance(self.target, cuqi.distribution.Posterior):
            if not isinstance(self.model, cuqi.model.LinearModel):
                raise TypeError("Model needs to be linear")

            if not hasattr(self.likelihood.distribution, "sqrtprec"):
                raise TypeError("Distribution in Likelihood must contain a sqrtprec attribute")

        elif isinstance(self.target, cuqi.distribution.MultipleLikelihoodPosterior): # Elif used for further alternatives, e.g., stacked posterior
            for likelihood in self.likelihoods:
                if not isinstance(likelihood.model, cuqi.model.LinearModel):
                    raise TypeError("Model needs to be linear")

                if not hasattr(likelihood.distribution, "sqrtprec"):
                    raise TypeError("Distribution in Likelihood must contain a sqrtprec attribute")

        # Check Gaussian prior
        if not hasattr(self.prior, "sqrtprec"):
            raise TypeError("prior must contain a sqrtprec attribute")

        if not hasattr(self.prior, "sqrtprecTimesMean"):
            raise TypeError("Prior must contain a sqrtprecTimesMean attribute")

    def get_state(self): #TODO: LinearRTO only need initial_point for reproducibility?
        """ Return the (minimal) sampler state. """
        return {'sampler_type': 'LinearRTO'}

    def set_state(self, state): #TODO: LinearRTO only need initial_point for reproducibility?
        """ Restoring state is currently a no-op. """
        pass
class RegularizedLinearRTONew(LinearRTONew):
    """
    Regularized Linear RTO (Randomize-Then-Optimize) sampler.

    Samples posterior related to the inverse problem with Gaussian likelihood and implicit Gaussian prior, and where the forward model is Linear.

    Each step perturbs the stacked right-hand side with standard normal noise
    and solves the resulting problem with FISTA, using the prior's ``proximal``
    operator.

    Parameters
    ------------
    target : `cuqi.distribution.Posterior`
        See `cuqi.sampler.LinearRTO`. The prior must expose a callable
        ``proximal`` attribute and a ``gaussian`` attribute.

    initial_point : `np.ndarray`
        Initial point for the sampler. *Optional*.

    maxit : int
        Maximum number of iterations of the inner FISTA solver. *Optional*.

    stepsize : string or float
        If stepsize is the string "automatic", the stepsize is set to
        0.99 / ||M||^2, using an estimate of the spectral norm of the stacked operator M.
        If stepsize is a float, then this stepsize is used.

    abstol : float
        Absolute tolerance of the inner FISTA solver. *Optional*.

    adaptive : bool
        If True, FISTA is used as inner solver, otherwise ISTA is used. *Optional*.

    callback : callable, *Optional*
        If set this function will be called after every sample.
        The signature of the callback function is `callback(sample, sample_index)`,
        where `sample` is the current sample and `sample_index` is the index of the sample.
        An example is shown in demos/demo31_callback.py.

    """
    def __init__(self, target, initial_point=None, maxit=100, stepsize="automatic", abstol=1e-10, adaptive=True, **kwargs):

        # Forward maxit so the parent does not set (and we later overwrite) its own default.
        super().__init__(target=target, initial_point=initial_point, maxit=maxit, **kwargs)

        # Other parameters
        self.stepsize = stepsize
        self.abstol = abstol
        self.adaptive = adaptive
        self.proximal = target.prior.proximal
        self._stepsize = self._choose_stepsize()

    @LinearRTONew.target.setter
    def target(self, value):
        """ Set the target density; its prior must expose a callable ``proximal``.

        Raises
        ------
        TypeError
            If ``value`` has no prior, or its prior has no callable ``proximal``.
            Previously this case surfaced as an ``AttributeError``, e.g. for 5-tuple targets.
        """
        prior = getattr(value, "prior", None)
        if not callable(getattr(prior, "proximal", None)):
            raise TypeError("Projector needs to be callable")
        super(RegularizedLinearRTONew, type(self)).target.fset(self, value)

    def _choose_stepsize(self):
        """ Return the FISTA stepsize: either user-given or estimated automatically.

        Raises
        ------
        ValueError
            If ``stepsize`` is a string other than "automatic".
        """
        if not isinstance(self.stepsize, str):
            return self.stepsize

        if self.stepsize != "automatic":
            raise ValueError("Stepsize choice not supported")

        # Wrap M (stacked matrix or forward/adjoint function) as a scipy operator
        # so its spectral norm can be estimated.
        if not callable(self.M):
            M_op = scipyLinearOperator(self.M.shape, matvec=lambda v: self.M@v, rmatvec=lambda w: self.M.T@w)
        else:
            M_op = scipyLinearOperator((len(self.b_tild), self.n), matvec=lambda v: self.M(v, 1), rmatvec=lambda w: self.M(w, 2))

        # Slightly below 1/||M||^2; presumably chosen to respect the usual
        # (F)ISTA stepsize bound for the least-squares term.
        return 0.99/(estimate_spectral_norm(M_op)**2)

    @property
    def prior(self):
        """ Gaussian part of the implicit prior, used by the LinearRTO precomputation and validation. """
        return self.target.prior.gaussian

    def step(self):
        """ Draw one sample: perturb ``b_tild`` with standard normal noise and solve with FISTA
        using the prior's proximal operator. Always returns acceptance 1. """
        y = self.b_tild + np.random.randn(len(self.b_tild))
        sim = FISTA(self.M, y, self.current_point, self.proximal,
                    maxit=self.maxit, stepsize=self._stepsize, abstol=self.abstol, adaptive=self.adaptive)
        self.current_point, _ = sim.solve()
        acc = 1
        return acc
@@ -83,6 +83,11 @@ class SamplerNew(ABC):
83
83
  """ Return the initial point of the sampler. This is always the first sample. """
84
84
  return self._samples[0]
85
85
 
86
+ @initial_point.setter
87
+ def initial_point(self, value):
88
+ """ Set the initial point of the sampler by overwriting the first stored sample (``self._samples[0]``), which is what the ``initial_point`` getter returns. """
89
+ self._samples[0] = value
90
+
86
91
  @property
87
92
  def dim(self):
88
93
  """ Dimension of the target density. """
@@ -8,13 +8,20 @@ from copy import copy
8
8
  from numbers import Number
9
9
 
10
10
  try:
11
- import arviz # Plotting tool
12
- except ImportError:
11
+ import arviz # Plotting tool
12
+ except ImportError as e:
13
13
  arviz = None
14
+ arviz_import_error = e
15
+
14
16
 
def _check_for_arviz():
    """Raise an informative ImportError if the optional arviz dependency is unavailable.

    Raises
    ------
    ImportError
        If the module-level ``arviz`` import failed (``arviz is None``). The
        message includes the text of the original import error, which is
        also chained as the cause.
    """
    if arviz is None:
        # str() works for any ImportError. ``args[0]`` raised IndexError for
        # argument-less errors and TypeError for non-str arguments.
        msg = "The arviz package is required for this functionality. "\
            + "Please make sure arviz is installed. "\
            + "See below for the original error message:\n"\
            + str(arviz_import_error)

        raise ImportError(msg) from arviz_import_error
18
25
 
19
26
 
20
27
  class Samples(object):
@@ -1,4 +1,4 @@
1
1
  matplotlib
2
2
  numpy>=1.17.0
3
- scipy
3
+ scipy<1.13
4
4
  arviz