CUQIpy 1.3.0__py3-none-any.whl → 1.4.0.post0.dev61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. cuqi/__init__.py +1 -0
  2. cuqi/_version.py +3 -3
  3. cuqi/density/_density.py +9 -1
  4. cuqi/distribution/__init__.py +1 -1
  5. cuqi/distribution/_beta.py +1 -1
  6. cuqi/distribution/_cauchy.py +2 -2
  7. cuqi/distribution/_distribution.py +24 -15
  8. cuqi/distribution/_joint_distribution.py +97 -12
  9. cuqi/distribution/_posterior.py +9 -0
  10. cuqi/distribution/_truncated_normal.py +3 -3
  11. cuqi/distribution/_uniform.py +36 -2
  12. cuqi/experimental/__init__.py +1 -1
  13. cuqi/experimental/_recommender.py +216 -0
  14. cuqi/experimental/geometry/_productgeometry.py +3 -3
  15. cuqi/geometry/_geometry.py +12 -1
  16. cuqi/implicitprior/__init__.py +1 -1
  17. cuqi/implicitprior/_regularizedGaussian.py +40 -4
  18. cuqi/implicitprior/_restorator.py +35 -1
  19. cuqi/legacy/__init__.py +2 -0
  20. cuqi/legacy/sampler/__init__.py +11 -0
  21. cuqi/legacy/sampler/_conjugate.py +55 -0
  22. cuqi/legacy/sampler/_conjugate_approx.py +52 -0
  23. cuqi/legacy/sampler/_cwmh.py +196 -0
  24. cuqi/legacy/sampler/_gibbs.py +231 -0
  25. cuqi/legacy/sampler/_hmc.py +335 -0
  26. cuqi/legacy/sampler/_langevin_algorithm.py +198 -0
  27. cuqi/legacy/sampler/_laplace_approximation.py +184 -0
  28. cuqi/legacy/sampler/_mh.py +190 -0
  29. cuqi/legacy/sampler/_pcn.py +244 -0
  30. cuqi/{experimental/mcmc → legacy/sampler}/_rto.py +134 -152
  31. cuqi/legacy/sampler/_sampler.py +182 -0
  32. cuqi/likelihood/_likelihood.py +1 -1
  33. cuqi/model/_model.py +1248 -357
  34. cuqi/pde/__init__.py +4 -0
  35. cuqi/pde/_observation_map.py +36 -0
  36. cuqi/pde/_pde.py +133 -32
  37. cuqi/problem/_problem.py +88 -82
  38. cuqi/sampler/__init__.py +120 -8
  39. cuqi/sampler/_conjugate.py +376 -35
  40. cuqi/sampler/_conjugate_approx.py +40 -16
  41. cuqi/sampler/_cwmh.py +132 -138
  42. cuqi/{experimental/mcmc → sampler}/_direct.py +1 -1
  43. cuqi/sampler/_gibbs.py +269 -130
  44. cuqi/sampler/_hmc.py +328 -201
  45. cuqi/sampler/_langevin_algorithm.py +282 -98
  46. cuqi/sampler/_laplace_approximation.py +87 -117
  47. cuqi/sampler/_mh.py +47 -157
  48. cuqi/sampler/_pcn.py +56 -211
  49. cuqi/sampler/_rto.py +206 -140
  50. cuqi/sampler/_sampler.py +540 -135
  51. cuqi/solver/_solver.py +6 -2
  52. cuqi/testproblem/_testproblem.py +2 -3
  53. cuqi/utilities/__init__.py +3 -1
  54. cuqi/utilities/_utilities.py +94 -12
  55. {CUQIpy-1.3.0.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/METADATA +6 -4
  56. cuqipy-1.4.0.post0.dev61.dist-info/RECORD +102 -0
  57. {CUQIpy-1.3.0.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/WHEEL +1 -1
  58. CUQIpy-1.3.0.dist-info/RECORD +0 -100
  59. cuqi/experimental/mcmc/__init__.py +0 -123
  60. cuqi/experimental/mcmc/_conjugate.py +0 -345
  61. cuqi/experimental/mcmc/_conjugate_approx.py +0 -76
  62. cuqi/experimental/mcmc/_cwmh.py +0 -193
  63. cuqi/experimental/mcmc/_gibbs.py +0 -318
  64. cuqi/experimental/mcmc/_hmc.py +0 -464
  65. cuqi/experimental/mcmc/_langevin_algorithm.py +0 -392
  66. cuqi/experimental/mcmc/_laplace_approximation.py +0 -156
  67. cuqi/experimental/mcmc/_mh.py +0 -80
  68. cuqi/experimental/mcmc/_pcn.py +0 -89
  69. cuqi/experimental/mcmc/_sampler.py +0 -566
  70. cuqi/experimental/mcmc/_utilities.py +0 -17
  71. {CUQIpy-1.3.0.dist-info → cuqipy-1.4.0.post0.dev61.dist-info/licenses}/LICENSE +0 -0
  72. {CUQIpy-1.3.0.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/top_level.txt +0 -0
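The headline change visible in the file list is a reorganization of the sampler modules: the samplers that lived under cuqi.experimental.mcmc in 1.3.0 are promoted to cuqi.sampler, while the previous cuqi.sampler implementations are kept under a new cuqi.legacy.sampler package and now emit a deprecation warning (see the legacy Sampler base class further down). A minimal sketch of how imports would migrate, assuming the module paths follow the file moves above and that the sampler classes are re-exported from the package __init__ files; LinearRTO is used as the example because its diff is shown below:

    # CUQIpy 1.3.0: the new-style sampler lived in the experimental namespace
    # from cuqi.experimental.mcmc import LinearRTO

    # CUQIpy 1.4.0: the new-style sampler is promoted to the main namespace
    from cuqi.sampler import LinearRTO

    # The old-style implementation remains available, with a UserWarning, under cuqi.legacy
    from cuqi.legacy.sampler import LinearRTO as LegacyLinearRTO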
cuqi/{experimental/mcmc → legacy/sampler}/_rto.py

@@ -3,8 +3,8 @@ from scipy.linalg.interpolative import estimate_spectral_norm
  from scipy.sparse.linalg import LinearOperator as scipyLinearOperator
  import numpy as np
  import cuqi
- from cuqi.solver import CGLS, FISTA, ADMM, ScipyLinearLSQ
- from cuqi.experimental.mcmc import Sampler
+ from cuqi.solver import CGLS, FISTA
+ from cuqi.legacy.sampler import Sampler


  class LinearRTO(Sampler):
@@ -27,7 +27,7 @@ class LinearRTO(Sampler):
  P_mean: is the prior mean.
  P_sqrtprec: is the squareroot of the precision matrix of the Gaussian mean.

- initial_point : `np.ndarray`
+ x0 : `np.ndarray`
  Initial point for the sampler. *Optional*.

  maxit : int
@@ -43,51 +43,60 @@ class LinearRTO(Sampler):
  An example is shown in demos/demo31_callback.py.

  """
- def __init__(self, target=None, initial_point=None, maxit=10, tol=1e-6, **kwargs):
-
- super().__init__(target=target, initial_point=initial_point, **kwargs)
+ def __init__(self, target, x0=None, maxit=10, tol=1e-6, shift=0, **kwargs):
+
+ # Accept tuple of inputs and construct posterior
+ if isinstance(target, tuple) and len(target) == 5:
+ # Structure (data, model, L_sqrtprec, P_mean, P_sqrtprec)
+ data = target[0]
+ model = target[1]
+ L_sqrtprec = target[2]
+ P_mean = target[3]
+ P_sqrtprec = target[4]
+
+ # If numpy matrix convert to CUQI model
+ if isinstance(model, np.ndarray) and len(model.shape) == 2:
+ model = cuqi.model.LinearModel(model)
+
+ # Check model input
+ if not isinstance(model, cuqi.model.AffineModel):
+ raise TypeError("Model needs to be cuqi.model.AffineModel or matrix")
+
+ # Likelihood
+ L = cuqi.distribution.Gaussian(model, sqrtprec=L_sqrtprec).to_likelihood(data)
+
+ # Prior TODO: allow multiple priors stacked
+ #if isinstance(P_mean, list) and isinstance(P_sqrtprec, list):
+ # P = cuqi.distribution.JointGaussianSqrtPrec(P_mean, P_sqrtprec)
+ #else:
+ P = cuqi.distribution.Gaussian(P_mean, sqrtprec=P_sqrtprec)
+
+ # Construct posterior
+ target = cuqi.distribution.Posterior(L, P)
+
+ super().__init__(target, x0=x0, **kwargs)
+
+ self._check_posterior()
+
+ # Modify initial guess
+ if x0 is not None:
+ self.x0 = x0
+ else:
+ self.x0 = np.zeros(self.prior.dim)

  # Other parameters
  self.maxit = maxit
- self.tol = tol
-
- def _initialize(self):
- self._precompute()
-
- @property
- def prior(self):
- return self.target.prior
-
- @property
- def likelihood(self):
- return self.target.likelihood
-
- @property
- def likelihoods(self):
- if isinstance(self.target, cuqi.distribution.Posterior):
- return [self.target.likelihood]
- elif isinstance(self.target, cuqi.distribution.MultipleLikelihoodPosterior):
- return self.target.likelihoods
-
- @property
- def model(self):
- return self.target.model
-
- @property
- def models(self):
- if isinstance(self.target, cuqi.distribution.Posterior):
- return [self.target.model]
- elif isinstance(self.target, cuqi.distribution.MultipleLikelihoodPosterior):
- return self.target.models
-
- def _precompute(self):
+ self.tol = tol
+ self.shift = 0
+
  L1 = [likelihood.distribution.sqrtprec for likelihood in self.likelihoods]
  L2 = self.prior.sqrtprec
  L2mu = self.prior.sqrtprecTimesMean

  # pre-computations
- self.n = self.prior.dim
- self.b_tild = np.hstack([L@(likelihood.data - model._shift) for (L, likelihood, model) in zip(L1, self.likelihoods, self.models)]+ [L2mu]) # With shift from AffineModel
+ self.n = len(self.x0)
+ self.b_tild = np.hstack([L@(likelihood.data - model._shift) for (L, likelihood, model) in zip(L1, self.likelihoods, self.models)]+ [L2mu])
+
  callability = [callable(likelihood.model) for likelihood in self.likelihoods]
  notcallability = [not c for c in callability]
  if all(notcallability):
@@ -105,26 +114,64 @@ class LinearRTO(Sampler):
  out1 = np.zeros(self.n)
  for likelihood in self.likelihoods:
  idx_end += len(likelihood.data)
- out1 += likelihood.model._adjoint_func_no_shift(likelihood.distribution.sqrtprec.T@x[idx_start:idx_end])
+ out1 += likelihood.model._adjoint_func_no_shift(likelihood.distribution.sqrtprec.T@x[idx_start:idx_end]) # Use adjoint function which excludes shift
  idx_start = idx_end
  out2 = L2.T @ x[idx_end:]
  out = out1 + out2
  return out
  self.M = M
  else:
- raise TypeError("All likelihoods need to be callable or none need to be callable.")
+ raise TypeError("All likelihoods need to be callable or none need to be callable.")

- def step(self):
- y = self.b_tild + np.random.randn(len(self.b_tild))
- sim = CGLS(self.M, y, self.current_point, self.maxit, self.tol)
- self.current_point, _ = sim.solve()
- acc = 1
- return acc
+ @property
+ def prior(self):
+ return self.target.prior

- def tune(self, skip_len, update_count):
- pass
+ @property
+ def likelihood(self):
+ return self.target.likelihood

- def validate_target(self):
+ @property
+ def likelihoods(self):
+ if isinstance(self.target, cuqi.distribution.Posterior):
+ return [self.target.likelihood]
+ elif isinstance(self.target, cuqi.distribution.MultipleLikelihoodPosterior):
+ return self.target.likelihoods
+
+ @property
+ def model(self):
+ return self.target.model
+
+ @property
+ def models(self):
+ if isinstance(self.target, cuqi.distribution.Posterior):
+ return [self.target.model]
+ elif isinstance(self.target, cuqi.distribution.MultipleLikelihoodPosterior):
+ return self.target.models
+
+ def _sample(self, N, Nb):
+ Ns = N+Nb # number of simulations
+ samples = np.empty((self.n, Ns))
+
+ # initial state
+ samples[:, 0] = self.x0
+ for s in range(Ns-1):
+ y = self.b_tild + np.random.randn(len(self.b_tild))
+ sim = CGLS(self.M, y, samples[:, s], self.maxit, self.tol, self.shift)
+ samples[:, s+1], _ = sim.solve()
+
+ self._print_progress(s+2,Ns) #s+2 is the sample number, s+1 is index assuming x0 is the first sample
+ self._call_callback(samples[:, s+1], s+1)
+
+ # remove burn-in
+ samples = samples[:, Nb:]
+
+ return samples, None, None
+
+ def _sample_adapt(self, N, Nb):
+ return self._sample(N,Nb)
+
+ def _check_posterior(self):
  # Check target type
  if not isinstance(self.target, (cuqi.distribution.Posterior, cuqi.distribution.MultipleLikelihoodPosterior)):
  raise ValueError(f"To initialize an object of type {self.__class__}, 'target' need to be of type 'cuqi.distribution.Posterior' or 'cuqi.distribution.MultipleLikelihoodPosterior'.")
@@ -132,15 +179,15 @@ class LinearRTO(Sampler):
  # Check Linear model and Gaussian likelihood(s)
  if isinstance(self.target, cuqi.distribution.Posterior):
  if not isinstance(self.model, cuqi.model.AffineModel):
- raise TypeError("Model needs to be linear or more generally affine")
+ raise TypeError("Model needs to be linear or affine")

  if not hasattr(self.likelihood.distribution, "sqrtprec"):
  raise TypeError("Distribution in Likelihood must contain a sqrtprec attribute")

  elif isinstance(self.target, cuqi.distribution.MultipleLikelihoodPosterior): # Elif used for further alternatives, e.g., stacked posterior
  for likelihood in self.likelihoods:
- if not isinstance(likelihood.model, cuqi.model.AffineModel):
- raise TypeError("Model needs to be linear or more generally affine")
+ if not isinstance(likelihood.model, cuqi.model.LinearModel):
+ raise TypeError("Model needs to be linear")

  if not hasattr(likelihood.distribution, "sqrtprec"):
  raise TypeError("Distribution in Likelihood must contain a sqrtprec attribute")
@@ -151,58 +198,31 @@ class LinearRTO(Sampler):

  if not hasattr(self.prior, "sqrtprecTimesMean"):
  raise TypeError("Prior must contain a sqrtprecTimesMean attribute")
-
- def _get_default_initial_point(self, dim):
- """ Get the default initial point for the sampler. Defaults to an array of zeros. """
- return np.zeros(dim)
+

  class RegularizedLinearRTO(LinearRTO):
  """
  Regularized Linear RTO (Randomize-Then-Optimize) sampler.

  Samples posterior related to the inverse problem with Gaussian likelihood and implicit Gaussian prior, and where the forward model is Linear.
- The sampler works by repeatedly solving regularized linear least squares problems for perturbed data.
- The solver for these optimization problems is chosen based on how the regularized is provided in the implicit Gaussian prior.
- Currently we use the following solvers:
- FISTA: [1] Beck, Amir, and Marc Teboulle. "A fast iterative shrinkage-thresholding algorithm for linear inverse problems." SIAM journal on imaging sciences 2.1 (2009): 183-202.
- Used when prior.proximal is callable.
- ADMM: [2] Boyd et al. "Distributed optimization and statistical learning via the alternating direction method of multipliers."Foundations and Trends® in Machine learning, 2011.
- Used when prior.proximal is a list of penalty terms.
- ScipyLinearLSQ: Wrapper for Scipy's lsq_linear for the Trust Region Reflective algorithm. Optionally used when the constraint is either "nonnegativity" or "box".

  Parameters
  ------------
  target : `cuqi.distribution.Posterior`
- See `cuqi.sampler.LinearRTO`
+ See `cuqi.legacy.sampler.LinearRTO`

- initial_point : `np.ndarray`
+ x0 : `np.ndarray`
  Initial point for the sampler. *Optional*.

  maxit : int
- Maximum number of iterations of the FISTA/ADMM/ScipyLinearLSQ solver. *Optional*.
-
- inner_max_it : int
- Maximum number of iterations of the CGLS solver used within the ADMM solver. *Optional*.
+ Maximum number of iterations of the inner FISTA solver. *Optional*.

  stepsize : string or float
  If stepsize is a string and equals either "automatic", then the stepsize is automatically estimated based on the spectral norm.
  If stepsize is a float, then this stepsize is used.

- penalty_parameter : int
- Penalty parameter of the ADMM solver. *Optional*.
- See [2] or `cuqi.solver.ADMM`
-
  abstol : float
- Absolute tolerance of the FISTA/ScipyLinearLSQ solver. *Optional*.
-
- inner_abstol : float
- Tolerance parameter for ScipyLinearLSQ's inner solve of the unbounded least-squares problem. *Optional*.
-
- adaptive : bool
- If True, FISTA is used as solver, otherwise ISTA is used. *Optional*.
-
- solver : string
- If set to "ScipyLinearLSQ", solver is set to cuqi.solver.ScipyLinearLSQ, otherwise FISTA/ISTA or ADMM is used. Note "ScipyLinearLSQ" can only be used with `RegularizedGaussian` of `box` or `nonnegativity` constraint. *Optional*.
+ Absolute tolerance of the inner FISTA solver. *Optional*.

  callback : callable, *Optional*
  If set this function will be called after every sample.
@@ -211,51 +231,27 @@ class RegularizedLinearRTO(LinearRTO):
  An example is shown in demos/demo31_callback.py.

  """
- def __init__(self, target=None, initial_point=None, maxit=100, inner_max_it=10, stepsize="automatic", penalty_parameter=10, abstol=1e-10, adaptive=True, solver=None, inner_abstol=None, **kwargs):
+ def __init__(self, target, x0=None, maxit=100, stepsize = "automatic", abstol=1e-10, adaptive = True, **kwargs):
+
+ if not callable(target.prior.proximal):
+ raise TypeError("Projector needs to be callable")

- super().__init__(target=target, initial_point=initial_point, **kwargs)
+ super().__init__(target, x0=x0, maxit=100, **kwargs)

  # Other parameters
  self.stepsize = stepsize
- self.abstol = abstol
- self.inner_abstol = inner_abstol
+ self.abstol = abstol
  self.adaptive = adaptive
- self.maxit = maxit
- self.inner_max_it = inner_max_it
- self.penalty_parameter = penalty_parameter
- self.solver = solver
-
- def _initialize(self):
- super()._initialize()
- if self.solver is None:
- self.solver = "FISTA" if callable(self.proximal) else "ADMM"
- if self.solver == "FISTA":
- self._stepsize = self._choose_stepsize()
+ self.proximal = target.prior.proximal

  @property
- def solver(self):
- return self._solver
-
- @solver.setter
- def solver(self, value):
- if value == "ScipyLinearLSQ":
- if (self.target.prior.preset["constraint"] == "nonnegativity" or self.target.prior.preset["constraint"] == "box"):
- self._solver = value
- else:
- raise ValueError("ScipyLinearLSQ only supports RegularizedGaussian with box or nonnegativity constraint.")
- else:
- self._solver = value
-
- @property
- def proximal(self):
- return self.target.prior.proximal
-
- def validate_target(self):
- super().validate_target()
- if not isinstance(self.target.prior, (cuqi.implicitprior.RegularizedGaussian, cuqi.implicitprior.RegularizedGMRF)):
- raise TypeError("Prior needs to be RegularizedGaussian or RegularizedGMRF")
+ def prior(self):
+ return self.target.prior.gaussian

- def _choose_stepsize(self):
+ def _sample(self, N, Nb):
+ Ns = N+Nb # number of simulations
+ samples = np.empty((self.n, Ns))
+
  if isinstance(self.stepsize, str):
  if self.stepsize in ["automatic"]:
  if not callable(self.M):
@@ -269,34 +265,20 @@ class RegularizedLinearRTO(LinearRTO):
  raise ValueError("Stepsize choice not supported")
  else:
  _stepsize = self.stepsize
- return _stepsize
-
- @property
- def prior(self):
- return self.target.prior.gaussian
-
- def step(self):
- y = self.b_tild + np.random.randn(len(self.b_tild))
-
- if self.solver == "FISTA":
+
+ # initial state
+ samples[:, 0] = self.x0
+ for s in range(Ns-1):
+ y = self.b_tild + np.random.randn(len(self.b_tild))
  sim = FISTA(self.M, y, self.proximal,
- self.current_point, maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
- elif self.solver == "ADMM":
- sim = ADMM(self.M, y, self.proximal,
- self.current_point, self.penalty_parameter, maxit = self.maxit, inner_max_it = self.inner_max_it, adaptive = self.adaptive)
- elif self.solver == "ScipyLinearLSQ":
- A_op = sp.sparse.linalg.LinearOperator((sum([llh.dim for llh in self.likelihoods])+self.target.prior.dim, self.target.prior.dim),
- matvec=lambda x: self.M(x, 1),
- rmatvec=lambda x: self.M(x, 2)
- )
- sim = ScipyLinearLSQ(A_op, y, self.target.prior._box_bounds,
- max_iter = self.maxit,
- lsmr_maxiter = self.inner_max_it,
- tol = self.abstol,
- lsmr_tol = self.inner_abstol)
- else:
- raise ValueError("Choice of solver not supported.")
+ samples[:, s], maxit = self.maxit, stepsize = _stepsize, abstol = self.abstol, adaptive = self.adaptive)
+ samples[:, s+1], _ = sim.solve()
+
+ self._print_progress(s+2,Ns) #s+2 is the sample number, s+1 is index assuming x0 is the first sample
+ self._call_callback(samples[:, s+1], s+1)
+ # remove burn-in
+ samples = samples[:, Nb:]
+
+ return samples, None, None
+

- self.current_point, _ = sim.solve()
- acc = 1
- return acc
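For orientation, here is a minimal usage sketch of the legacy LinearRTO shown above. It relies on the 5-tuple target form handled in __init__, with structure (data, model, L_sqrtprec, P_mean, P_sqrtprec); the matrix and data values below are hypothetical placeholders, and the import assumes LinearRTO is re-exported from cuqi.legacy.sampler:

    import numpy as np
    from cuqi.legacy.sampler import LinearRTO  # assumed re-export, see the import rewrite above

    # Hypothetical 2x2 linear inverse problem (placeholder values)
    A = np.array([[1.0, 0.5],
                  [0.2, 1.0]])       # a 2D numpy array is wrapped into cuqi.model.LinearModel internally
    data = np.array([1.0, 2.0])
    L_sqrtprec = 10.0 * np.eye(2)    # square root of the likelihood precision
    P_mean = np.zeros(2)             # prior mean
    P_sqrtprec = np.eye(2)           # square root of the prior precision

    # 5-tuple form accepted by the legacy __init__ shown above
    sampler = LinearRTO((data, A, L_sqrtprec, P_mean, P_sqrtprec), maxit=50)
    samples = sampler.sample(200, Nb=20)   # burn-in is discarded inside _sample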
cuqi/legacy/sampler/_sampler.py

@@ -0,0 +1,182 @@
+ from abc import ABC, abstractmethod
+ import sys
+ import numpy as np
+ import cuqi
+ from cuqi.samples import Samples
+ import warnings
+
+ class Sampler(ABC):
+
+ def __init__(self, target, x0=None, dim=None, callback=None):
+
+ warnings.warn(f"\nYou are using the legacy sampler '{self.__class__.__name__}'.\n"
+ f"This will be removed in a future release of CUQIpy.\n"
+ f"Please consider using the new samplers in the 'cuqi.sampler' module.\n", UserWarning, stacklevel=2)
+
+ self._dim = dim
+ if hasattr(target,'dim'):
+ if self._dim is None:
+ self._dim = target.dim
+ elif self._dim != target.dim:
+ raise ValueError("'dim' need to be None or equal to 'target.dim'")
+ elif x0 is not None:
+ self._dim = len(x0)
+
+ self.target = target
+
+ if x0 is None:
+ x0 = np.ones(self.dim)
+ self.x0 = x0
+
+ self.callback = callback
+
+ def step(self, x):
+ """
+ Perform a single MCMC step
+ """
+ # Currently a hack to get step method for any sampler
+ self.x0 = x
+ return self.sample(2).samples[:,-1]
+
+ def step_tune(self, x, *args, **kwargs):
+ """
+ Perform a single MCMC step and tune the sampler. This is used during burn-in.
+ """
+ # Currently a hack to get step method for any sampler
+ out = self.step(x)
+ self.tune(*args, *kwargs)
+ return out
+
+ def tune(self):
+ """
+ Tune the sampler parameters.
+ """
+ pass
+
+
+ @property
+ def geometry(self):
+ if hasattr(self, 'target') and hasattr(self.target, 'geometry'):
+ geom = self.target.geometry
+ else:
+ geom = cuqi.geometry._DefaultGeometry1D(self.dim)
+ return geom
+
+ @property
+ def target(self):
+ return self._target
+
+ @target.setter
+ def target(self, value):
+ if not isinstance(value, cuqi.distribution.Distribution) and callable(value):
+ # obtain self.dim
+ if self.dim is not None:
+ dim = self.dim
+ else:
+ raise ValueError(f"If 'target' is a lambda function, the parameter 'dim' need to be specified when initializing {self.__class__}.")
+
+ # set target
+ self._target = cuqi.distribution.UserDefinedDistribution(logpdf_func=value, dim = dim)
+
+ elif isinstance(value, cuqi.distribution.Distribution):
+ self._target = value
+ else:
+ raise ValueError("'target' need to be either a lambda function or of type 'cuqi.distribution.Distribution'")
+
+
+ @property
+ def dim(self):
+ if hasattr(self,'target') and hasattr(self.target,'dim'):
+ self._dim = self.target.dim
+ return self._dim
+
+
+ def sample(self,N,Nb=0):
+ # Get samples from the samplers sample method
+ result = self._sample(N,Nb)
+ return self._create_Sample_object(result,N+Nb)
+
+ def sample_adapt(self,N,Nb=0):
+ # Get samples from the samplers sample method
+ result = self._sample_adapt(N,Nb)
+ return self._create_Sample_object(result,N+Nb)
+
+ def _create_Sample_object(self,result,N):
+ loglike_eval = None
+ acc_rate = None
+ if isinstance(result,tuple):
+ #Unpack samples+loglike+acc_rate
+ s = result[0]
+ if len(result)>1: loglike_eval = result[1]
+ if len(result)>2: acc_rate = result[2]
+ if len(result)>3: raise TypeError("Expected tuple of at most 3 elements from sampling method.")
+ else:
+ s = result
+
+ #Store samples in cuqi samples object if more than 1 sample
+ if N==1:
+ if len(s) == 1 and isinstance(s,np.ndarray): #Extract single value from numpy array
+ s = s.ravel()[0]
+ else:
+ s = s.flatten()
+ else:
+ s = Samples(s, self.geometry)#, geometry = self.geometry)
+ s.loglike_eval = loglike_eval
+ s.acc_rate = acc_rate
+ return s
+
+ @abstractmethod
+ def _sample(self,N,Nb):
+ pass
+
+ @abstractmethod
+ def _sample_adapt(self,N,Nb):
+ pass
+
+ def _print_progress(self,s,Ns):
+ """Prints sampling progress"""
+ if Ns > 2:
+ if (s % (max(Ns//100,1))) == 0:
+ msg = f'Sample {s} / {Ns}'
+ sys.stdout.write('\r'+msg)
+ if s==Ns:
+ msg = f'Sample {s} / {Ns}'
+ sys.stdout.write('\r'+msg+'\n')
+
+ def _call_callback(self, sample, sample_index):
+ """ Calls the callback function. Assumes input is sample and sample index"""
+ if self.callback is not None:
+ self.callback(sample, sample_index)
+
+ class ProposalBasedSampler(Sampler,ABC):
+ def __init__(self, target, proposal=None, scale=1, x0=None, dim=None, **kwargs):
+ #TODO: after fixing None dim
+ #if dim is None and hasattr(proposal,'dim'):
+ # dim = proposal.dim
+ super().__init__(target, x0=x0, dim=dim, **kwargs)
+
+ self.proposal =proposal
+ self.scale = scale
+
+
+ @property
+ def proposal(self):
+ return self._proposal
+
+ @proposal.setter
+ def proposal(self, value):
+ self._proposal = value
+
+ @property
+ def geometry(self):
+ geom1, geom2 = None, None
+ if hasattr(self, 'proposal') and hasattr(self.proposal, 'geometry') and self.proposal.geometry.par_dim is not None:
+ geom1= self.proposal.geometry
+ if hasattr(self, 'target') and hasattr(self.target, 'geometry') and self.target.geometry.par_dim is not None:
+ geom2 = self.target.geometry
+ if not isinstance(geom1,cuqi.geometry._DefaultGeometry) and geom1 is not None:
+ return geom1
+ elif not isinstance(geom2,cuqi.geometry._DefaultGeometry) and geom2 is not None:
+ return geom2
+ else:
+ return cuqi.geometry._DefaultGeometry1D(self.dim)
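To illustrate the contract of the legacy Sampler base class added above, where subclasses implement _sample and _sample_adapt while the base class handles Samples construction, progress printing, the callback, and the deprecation warning, here is a toy subclass. It is a sketch for illustration only and is not part of CUQIpy:

    import numpy as np
    import cuqi
    from cuqi.legacy.sampler import Sampler  # assumed re-export of the base class above

    class PriorSampler(Sampler):
        """Toy sampler: draws independent samples directly from the target distribution."""
        def _sample(self, N, Nb):
            Ns = N + Nb
            samples = np.empty((self.dim, Ns))
            for s in range(Ns):
                samples[:, s] = np.asarray(self.target.sample()).ravel()
                self._print_progress(s + 1, Ns)
                self._call_callback(samples[:, s], s)
            return samples[:, Nb:], None, None   # (samples, loglike_eval, acc_rate)

        def _sample_adapt(self, N, Nb):
            return self._sample(N, Nb)

    # Usage: triggers the legacy UserWarning from __init__ and exercises the two-argument callback
    target = cuqi.distribution.Gaussian(np.zeros(3), 1.0)
    log = []
    samples = PriorSampler(target, callback=lambda sample, idx: log.append(idx)).sample(100, Nb=10)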
cuqi/likelihood/_likelihood.py

@@ -212,4 +212,4 @@ class UserDefinedLikelihood(object):
  return get_non_default_args(self.logpdf_func)

  def __repr__(self) -> str:
- return "CUQI {} function. Parameters {}.".format(self.__class__.__name__,self.get_parameter_names())
+ return "CUQI {} function. Parameters {}.".format(self.__class__.__name__,self.get_parameter_names())