CUQIpy 1.1.1.post0.dev36__py3-none-any.whl → 1.4.1.post0.dev124__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of CUQIpy might be problematic; see the package registry page for more details.

Files changed (92)
  1. cuqi/__init__.py +2 -0
  2. cuqi/_version.py +3 -3
  3. cuqi/algebra/__init__.py +2 -0
  4. cuqi/algebra/_abstract_syntax_tree.py +358 -0
  5. cuqi/algebra/_ordered_set.py +82 -0
  6. cuqi/algebra/_random_variable.py +457 -0
  7. cuqi/array/_array.py +4 -13
  8. cuqi/config.py +7 -0
  9. cuqi/density/_density.py +9 -1
  10. cuqi/distribution/__init__.py +3 -2
  11. cuqi/distribution/_beta.py +7 -11
  12. cuqi/distribution/_cauchy.py +2 -2
  13. cuqi/distribution/_custom.py +0 -6
  14. cuqi/distribution/_distribution.py +31 -45
  15. cuqi/distribution/_gamma.py +7 -3
  16. cuqi/distribution/_gaussian.py +2 -12
  17. cuqi/distribution/_inverse_gamma.py +4 -10
  18. cuqi/distribution/_joint_distribution.py +112 -15
  19. cuqi/distribution/_lognormal.py +0 -7
  20. cuqi/distribution/{_modifiedhalfnormal.py → _modified_half_normal.py} +23 -23
  21. cuqi/distribution/_normal.py +34 -7
  22. cuqi/distribution/_posterior.py +9 -0
  23. cuqi/distribution/_truncated_normal.py +129 -0
  24. cuqi/distribution/_uniform.py +47 -1
  25. cuqi/experimental/__init__.py +2 -2
  26. cuqi/experimental/_recommender.py +216 -0
  27. cuqi/geometry/__init__.py +2 -0
  28. cuqi/geometry/_geometry.py +15 -1
  29. cuqi/geometry/_product_geometry.py +181 -0
  30. cuqi/implicitprior/__init__.py +5 -3
  31. cuqi/implicitprior/_regularized_gaussian.py +483 -0
  32. cuqi/implicitprior/{_regularizedGMRF.py → _regularized_gmrf.py} +4 -2
  33. cuqi/implicitprior/{_regularizedUnboundedUniform.py → _regularized_unbounded_uniform.py} +3 -2
  34. cuqi/implicitprior/_restorator.py +269 -0
  35. cuqi/legacy/__init__.py +2 -0
  36. cuqi/{experimental/mcmc → legacy/sampler}/__init__.py +7 -11
  37. cuqi/legacy/sampler/_conjugate.py +55 -0
  38. cuqi/legacy/sampler/_conjugate_approx.py +52 -0
  39. cuqi/legacy/sampler/_cwmh.py +196 -0
  40. cuqi/legacy/sampler/_gibbs.py +231 -0
  41. cuqi/legacy/sampler/_hmc.py +335 -0
  42. cuqi/{experimental/mcmc → legacy/sampler}/_langevin_algorithm.py +82 -111
  43. cuqi/legacy/sampler/_laplace_approximation.py +184 -0
  44. cuqi/legacy/sampler/_mh.py +190 -0
  45. cuqi/legacy/sampler/_pcn.py +244 -0
  46. cuqi/{experimental/mcmc → legacy/sampler}/_rto.py +132 -90
  47. cuqi/legacy/sampler/_sampler.py +182 -0
  48. cuqi/likelihood/_likelihood.py +9 -1
  49. cuqi/model/__init__.py +1 -1
  50. cuqi/model/_model.py +1361 -359
  51. cuqi/pde/__init__.py +4 -0
  52. cuqi/pde/_observation_map.py +36 -0
  53. cuqi/pde/_pde.py +134 -33
  54. cuqi/problem/_problem.py +93 -87
  55. cuqi/sampler/__init__.py +120 -8
  56. cuqi/sampler/_conjugate.py +376 -35
  57. cuqi/sampler/_conjugate_approx.py +40 -16
  58. cuqi/sampler/_cwmh.py +132 -138
  59. cuqi/{experimental/mcmc → sampler}/_direct.py +1 -1
  60. cuqi/sampler/_gibbs.py +288 -130
  61. cuqi/sampler/_hmc.py +328 -201
  62. cuqi/sampler/_langevin_algorithm.py +284 -100
  63. cuqi/sampler/_laplace_approximation.py +87 -117
  64. cuqi/sampler/_mh.py +47 -157
  65. cuqi/sampler/_pcn.py +65 -213
  66. cuqi/sampler/_rto.py +211 -142
  67. cuqi/sampler/_sampler.py +553 -136
  68. cuqi/samples/__init__.py +1 -1
  69. cuqi/samples/_samples.py +24 -18
  70. cuqi/solver/__init__.py +6 -4
  71. cuqi/solver/_solver.py +230 -26
  72. cuqi/testproblem/_testproblem.py +2 -3
  73. cuqi/utilities/__init__.py +6 -1
  74. cuqi/utilities/_get_python_variable_name.py +2 -2
  75. cuqi/utilities/_utilities.py +182 -2
  76. {CUQIpy-1.1.1.post0.dev36.dist-info → cuqipy-1.4.1.post0.dev124.dist-info}/METADATA +10 -6
  77. cuqipy-1.4.1.post0.dev124.dist-info/RECORD +101 -0
  78. {CUQIpy-1.1.1.post0.dev36.dist-info → cuqipy-1.4.1.post0.dev124.dist-info}/WHEEL +1 -1
  79. CUQIpy-1.1.1.post0.dev36.dist-info/RECORD +0 -92
  80. cuqi/experimental/mcmc/_conjugate.py +0 -197
  81. cuqi/experimental/mcmc/_conjugate_approx.py +0 -81
  82. cuqi/experimental/mcmc/_cwmh.py +0 -191
  83. cuqi/experimental/mcmc/_gibbs.py +0 -268
  84. cuqi/experimental/mcmc/_hmc.py +0 -470
  85. cuqi/experimental/mcmc/_laplace_approximation.py +0 -156
  86. cuqi/experimental/mcmc/_mh.py +0 -78
  87. cuqi/experimental/mcmc/_pcn.py +0 -89
  88. cuqi/experimental/mcmc/_sampler.py +0 -561
  89. cuqi/experimental/mcmc/_utilities.py +0 -17
  90. cuqi/implicitprior/_regularizedGaussian.py +0 -323
  91. {CUQIpy-1.1.1.post0.dev36.dist-info → cuqipy-1.4.1.post0.dev124.dist-info/licenses}/LICENSE +0 -0
  92. {CUQIpy-1.1.1.post0.dev36.dist-info → cuqipy-1.4.1.post0.dev124.dist-info}/top_level.txt +0 -0
@@ -1,81 +0,0 @@
1
- import numpy as np
2
- from cuqi.experimental.mcmc import Conjugate
3
- from cuqi.experimental.mcmc._conjugate import _ConjugatePair, _get_conjugate_parameter, _check_conjugate_parameter_is_scalar_reciprocal
4
- from cuqi.distribution import LMRF, Gamma
5
- import scipy as sp
6
-
7
class ConjugateApprox(Conjugate):
    """ Approximate Conjugate sampler

    Samples a posterior whose likelihood and prior are not exactly conjugate,
    but can be approximated by a conjugate pair.

    Supported pairs:

    - (LMRF, Gamma): approximated by (Gaussian, Gamma), where the Gamma prior
      is placed on the inverse of the scale parameter of the LMRF distribution.

    Requirements: the Gamma distribution must be univariate and the LMRF
    likelihood must have zero mean.

    For more details on conjugacy see :class:`Conjugate`.
    """

    def _set_conjugatepair(self):
        """ Select the conjugate-pair implementation for the current target.

        Requires ``self.target`` to be set.

        Raises
        ------
        ValueError
            If no approximate conjugate pair exists for the combination of
            likelihood distribution and prior.
        """
        likelihood_dist = self.target.likelihood.distribution
        prior = self.target.prior

        is_lmrf_gamma = isinstance(likelihood_dist, LMRF) and isinstance(prior, Gamma)
        if not is_lmrf_gamma:
            raise ValueError(f"Conjugacy is not defined for likelihood {type(likelihood_dist)} and prior {type(prior)}, in CUQIpy")

        self._conjugatepair = _LMRFGammaPair(self.target)
30
-
31
-
32
class _LMRFGammaPair(_ConjugatePair):
    """ Implementation of the approximate conjugate pair (LMRF, Gamma).

    The LMRF likelihood is approximated by a Gaussian, so that a Gamma prior
    on the inverse of the LMRF scale parameter yields a Gamma conditional
    that can be sampled directly.
    """

    def validate_target(self):
        """ Check that the target fits the (LMRF, Gamma) approximation.

        Requirements: an LMRF likelihood with zero location, a univariate
        Gamma prior, and the Gamma prior placed on the reciprocal of the LMRF
        ``scale`` parameter.

        Raises
        ------
        ValueError
            If any of the requirements is violated.
        """
        if not isinstance(self.target.likelihood.distribution, LMRF):
            raise ValueError("Approximate conjugate sampler only works with LMRF likelihood function")

        if not isinstance(self.target.prior, Gamma):
            raise ValueError("Approximate conjugate sampler with LMRF likelihood only works with Gamma prior")

        if not self.target.prior.dim == 1:
            raise ValueError("Approximate conjugate sampler only works with univariate Gamma prior")

        # Check every component of the location. Checking only the sum would
        # accept non-zero locations whose entries cancel, e.g. [1, -1].
        if np.any(np.asarray(self.target.likelihood.distribution.location) != 0):
            raise ValueError("Approximate conjugate sampler only works with zero mean LMRF likelihood")

        key, value = _get_conjugate_parameter(self.target)
        if key == "scale":
            if not _check_conjugate_parameter_is_scalar_reciprocal(value):
                raise ValueError("Approximate conjugate sampler only works with Gamma prior on the inverse of the scale parameter of the LMRF likelihood")
        else:
            raise ValueError(f"No approximate conjugacy defined for likelihood {type(self.target.likelihood.distribution)} and prior {type(self.target.prior)}, in CUQIpy")

    def sample(self):
        """ Draw one sample of the Gamma-distributed parameter.

        The LMRF is approximated by a Gaussian (see Uribe et al. (2022)):
        with ``L = W^{1/2} D`` and ``W = diag(1/sqrt((D x)^2 + eps))``,
        ``||L x||^2 = sum_i (D x)_i^2 / sqrt((D x)_i^2 + eps)``. This
        approximates ``||D x||_1`` for small ``eps``.

        Returns
        -------
        The result of ``Gamma(shape=d+alpha, rate=||L x||^2+beta).sample()``,
        where ``d`` is the length of the data and ``alpha``/``beta`` are the
        prior shape/rate.
        """
        # Difference operator of the LMRF likelihood
        D = self.target.likelihood.distribution._diff_op
        n = D.shape[0]

        # Smoothing term keeping the weights finite where (D x)_i == 0.
        # NOTE: the approximation assumes a zero-mean LMRF likelihood
        # (enforced in validate_target).
        eps = 1e-5

        def Lk_fun(x_k):
            # Weighted difference operator W^{1/2} D evaluated at x_k
            dd = 1/np.sqrt((D @ x_k)**2 + eps*np.ones(n))
            W = sp.sparse.diags(dd)
            return W.sqrt() @ D

        x = self.target.likelihood.data
        d = len(x)
        Lx = Lk_fun(x) @ x
        alpha = self.target.prior.shape
        beta = self.target.prior.rate

        # Gamma conditional of the approximated conjugate pair
        dist = Gamma(shape=d+alpha, rate=np.linalg.norm(Lx)**2+beta)

        return dist.sample()
@@ -1,191 +0,0 @@
1
- import numpy as np
2
- import cuqi
3
- from cuqi.experimental.mcmc import ProposalBasedSampler
4
- from cuqi.array import CUQIarray
5
- from numbers import Number
6
-
7
class CWMH(ProposalBasedSampler):
    """Component-wise Metropolis Hastings sampler.

    Allows sampling of a target distribution by a component-wise random-walk
    sampling of a proposal distribution along with an accept/reject step.

    Parameters
    ----------

    target : `cuqi.density.Density`
        The target density to sample. It must be an instance of
        :class:`cuqi.density.Density` whose log density is not NaN at the
        default initial point. Custom logpdfs are supported by using
        a :class:`cuqi.distribution.UserDefinedDistribution`.

    proposal : `cuqi.distribution.Distribution`
        The proposal distribution. It must be symmetric and is conditioned on
        ``location`` (the current point) and ``scale`` in every step.
        Defaults to a Normal proposal with mean ``location`` and standard
        deviation ``scale``. *Optional*.

    scale : float or ndarray
        Scale parameter used to define correlation between previous and proposed
        sample in random-walk. *Optional*. If float, the same scale is used for
        all dimensions. If ndarray, a (possibly) different scale is used for
        each dimension.

    initial_point : ndarray
        Initial parameters. *Optional*

    callback : callable, *Optional*
        If set this function will be called after every sample.
        The signature of the callback function is
        `callback(sample, sample_index)`, where `sample` is the current sample
        and `sample_index` is the index of the sample.
        An example is shown in demos/demo31_callback.py.

    kwargs : dict
        Additional keyword arguments to be passed to the base class
        :class:`ProposalBasedSampler`.

    Example
    -------
    .. code-block:: python

        import numpy as np
        import cuqi
        # Parameters
        dim = 5 # Dimension of distribution
        mu = np.arange(dim) # Mean of Gaussian
        std = 1 # standard deviation of Gaussian

        # Logpdf function (unnormalized Gaussian log density)
        logpdf_func = lambda x: -0.5/(std**2)*np.sum((x-mu)**2)

        # Define distribution from logpdf as UserDefinedDistribution (sample
        # and gradients also supported as inputs to UserDefinedDistribution)
        target = cuqi.distribution.UserDefinedDistribution(
            dim=dim, logpdf_func=logpdf_func)

        # Set up sampler
        sampler = cuqi.experimental.mcmc.CWMH(target, scale=1)

        # Sample
        samples = sampler.sample(2000).get_samples()

    """

    _STATE_KEYS = ProposalBasedSampler._STATE_KEYS.union(['_scale_temp'])

    def __init__(self, target:cuqi.density.Density=None, proposal=None, scale=1,
                 initial_point=None, **kwargs):
        super().__init__(target, proposal=proposal, scale=scale,
                         initial_point=initial_point, **kwargs)

    def _initialize(self):
        """ Expand a scalar scale to one value per component and reset acceptance bookkeeping. """
        if isinstance(self.scale, Number):
            self.scale = np.ones(self.dim)*self.scale
        self._acc = [np.ones((self.dim))] # Overwrite acc from ProposalBasedSampler with list of arrays

        # Handling of temporary scale parameter due to possible bug in old CWMH
        self._scale_temp = self.scale.copy()

    @property
    def scale(self):
        """ Get the scale parameter. """
        return self._scale

    @scale.setter
    def scale(self, value):
        """ Set the scale parameter; once initialized, a scalar is expanded to one value per component. """
        if self._is_initialized and isinstance(value, Number):
            value = np.ones(self.dim)*value
        self._scale = value

    def validate_target(self):
        """ Check the target is a Density with a non-NaN log density at the default initial point. """
        if not isinstance(self.target, cuqi.density.Density):
            # Use the class's own name; ``Density.__class__.__name__`` would
            # report the metaclass name instead of "Density".
            raise ValueError(
                "Target should be an instance of "+\
                f"{cuqi.density.Density.__name__}")
        # Fail when there is no log density, which is currently assumed to be the case in case NaN is returned.
        if np.isnan(self.target.logd(self._get_default_initial_point(self.dim))):
            raise ValueError("Target does not have valid logd")

    def validate_proposal(self):
        """ Check the proposal is a symmetric cuqi Distribution. """
        if not isinstance(self.proposal, cuqi.distribution.Distribution):
            raise ValueError("Proposal must be a cuqi.distribution.Distribution object")
        if not self.proposal.is_symmetric:
            raise ValueError("Proposal must be symmetric")

    @property
    def proposal(self):
        """ The proposal distribution; lazily defaults to Normal(mean=location, std=scale). """
        if self._proposal is None:
            self._proposal = cuqi.distribution.Normal(
                mean=lambda location: location,
                std=lambda scale: scale,
                geometry=self.dim,
            )
        return self._proposal

    @proposal.setter
    def proposal(self, value):
        self._proposal = value

    def step(self):
        """ Perform one component-wise Metropolis-Hastings sweep.

        One full proposal is drawn. Then, for each component ``j`` in turn,
        the current state with component ``j`` replaced by the proposed value
        is accepted or rejected on its own.

        Returns
        -------
        ndarray
            Per-component acceptance indicators (1 if accepted, 0 otherwise).
        """
        # Initialize x_t which is used to store the current CWMH sample
        x_t = self.current_point.copy()

        # Initialize x_star which is used to store the proposed sample by
        # updating the current sample component-by-component
        x_star = self.current_point.copy()

        # Propose a sample x_all_components from the proposal distribution
        # for all the components
        target_eval_t = self.current_target_logd
        if isinstance(self.proposal,cuqi.distribution.Distribution):
            x_all_components = self.proposal(
                location= self.current_point, scale=self.scale).sample()
        else:
            x_all_components = self.proposal(self.current_point, self.scale)

        # Initialize acceptance rate
        acc = np.zeros(self.dim)

        # Loop over all the components of the sample and accept/reject
        # each component update.
        for j in range(self.dim):
            # propose state x_star by updating the j-th component
            x_star[j] = x_all_components[j]

            # evaluate target
            target_eval_star = self.target.logd(x_star)

            # compute Metropolis acceptance ratio (log scale; symmetric proposal)
            alpha = min(0, target_eval_star - target_eval_t)

            # accept/reject
            u_theta = np.log(np.random.rand())
            if (u_theta <= alpha): # accept
                x_t[j] = x_all_components[j]
                target_eval_t = target_eval_star
                acc[j] = 1

            # Next component starts from the (possibly updated) current state
            x_star = x_t.copy()

        self.current_target_logd = target_eval_t
        self.current_point = x_t

        return acc

    def tune(self, skip_len, update_count):
        """ Adapt the per-component scale towards the target acceptance rate.

        Uses the mean acceptance over the last ``skip_len`` steps and the
        target rate ``0.21/dim + 0.23``. The resulting scale is capped at 1.
        """
        # Store update_count in variable i for readability
        i = update_count

        # Optimal acceptance rate for CWMH
        star_acc = 0.21/self.dim + 0.23

        # Mean of acceptance rate over the last skip_len samples
        hat_acc = np.mean(self._acc[i*skip_len:(i+1)*skip_len], axis=0)

        # Compute new intermediate scaling parameter scale_temp
        # Factor zeta ensures that the variation of the scale update vanishes
        zeta = 1/np.sqrt(update_count+1)
        scale_temp = np.exp(
            np.log(self._scale_temp) + zeta*(hat_acc-star_acc))

        # Update the scale parameter
        self.scale = np.minimum(scale_temp, np.ones(self.dim))
        self._scale_temp = scale_temp
@@ -1,268 +0,0 @@
1
- from cuqi.distribution import JointDistribution
2
- from cuqi.experimental.mcmc import Sampler
3
- from cuqi.samples import Samples
4
- from typing import Dict
5
- import numpy as np
6
- import warnings
7
-
8
try:
    from tqdm import tqdm
except ImportError:
    # Progress bars are optional. Without tqdm, fall back to a pass-through
    # wrapper that warns on every call and hands the iterable back unchanged.
    def tqdm(iterable, **kwargs):
        message = "Module mcmc: tqdm not found. Install tqdm to get sampling progress."
        warnings.warn(message)
        return iterable
14
-
15
# Not subclassed from Sampler as Gibbs handles multiple samplers and samples multiple parameters
# Similar approach as for JointDistribution
class HybridGibbs:
    """
    Hybrid Gibbs sampler for sampling a joint distribution.

    Gibbs sampling samples the variables of the distribution sequentially,
    one variable at a time. When a variable represents a random vector, the
    whole vector is sampled simultaneously.

    The sampling of each variable is done by sampling from the conditional
    distribution of that variable given the values of the other variables.
    This is often a very efficient way of sampling from a joint distribution
    if the conditional distributions are easy to sample from.

    Hybrid Gibbs sampler is a generalization of the Gibbs sampler where the
    conditional distributions are sampled using different MCMC samplers.

    When the conditionals are sampled exactly, the samples from the Gibbs
    sampler converge to the joint distribution. See e.g.
    Gelman et al. "Bayesian Data Analysis" (2014), Third Edition
    for more details.

    In each Gibbs step, the corresponding sampler has the initial_point
    and initial_scale (if applicable) set to the value of the previous step
    and the sampler is reinitialized. This means that the sampling is not
    fully stateful at this point. This means samplers like NUTS will lose
    their internal state between Gibbs steps.

    Parameters
    ----------
    target : cuqi.distribution.JointDistribution
        Target distribution to sample from.

    sampling_strategy : dict
        Dictionary of sampling strategies for each variable.
        Keys are variable names.
        Values are sampler objects. The sampler objects themselves are not
        copied: their targets, initial points and state are modified during
        sampling.

    num_sampling_steps : dict, *optional*
        Dictionary of number of sampling steps for each variable.
        The sampling steps are defined as the number of times the sampler
        will call its step method in each Gibbs step.
        Default is 1 for all variables. The passed dictionary is not modified.

    Example
    -------
    .. code-block:: python

        import cuqi
        import numpy as np

        # Model and data
        A, y_obs, probinfo = cuqi.testproblem.Deconvolution1D(phantom='square').get_components()
        n = A.domain_dim

        # Define distributions
        d = cuqi.distribution.Gamma(1, 1e-4)
        l = cuqi.distribution.Gamma(1, 1e-4)
        x = cuqi.distribution.GMRF(np.zeros(n), lambda d: d)
        y = cuqi.distribution.Gaussian(A, lambda l: 1/l)

        # Combine into a joint distribution and create posterior
        joint = cuqi.distribution.JointDistribution(d, l, x, y)
        posterior = joint(y=y_obs)

        # Define sampling strategy
        sampling_strategy = {
            'x': cuqi.experimental.mcmc.LinearRTO(maxit=15),
            'd': cuqi.experimental.mcmc.Conjugate(),
            'l': cuqi.experimental.mcmc.Conjugate(),
        }

        # Define Gibbs sampler
        sampler = cuqi.experimental.mcmc.HybridGibbs(posterior, sampling_strategy)

        # Run sampler: 200 warmup (tuning) steps followed by 1000 sampling steps.
        # Note: get_samples() also contains the 200 warmup samples.
        sampler.warmup(200)
        sampler.sample(1000)
        samples = sampler.get_samples()

        # Plot results
        samples['x'].plot_ci(exact=probinfo.exactSolution)
        samples['d'].plot_trace(figsize=(8,2))
        samples['l'].plot_trace(figsize=(8,2))

    """

    def __init__(self, target: JointDistribution, sampling_strategy: Dict[str, Sampler], num_sampling_steps: Dict[str, int] = None):

        # Store target and allow conditioning to reduce to a single density
        self.target = target() # Create a copy of target distribution (to avoid modifying the original)

        # Store sampler instances. NOTE: this is only a shallow copy of the
        # dict; the sampler objects are shared with the caller and are
        # modified (targets, initial points, state) during sampling.
        self.samplers = sampling_strategy.copy()

        # Store number of sampling steps for each parameter (copied and
        # completed with defaults in _initialize_num_sampling_steps)
        self.num_sampling_steps = num_sampling_steps

        # Store parameter names
        self.par_names = self.target.get_parameter_names()

        # Initialize sampler (after target is set)
        self._initialize()

    def _initialize(self):
        """ Initialize sampler """

        # Initial points
        self.current_samples = self._get_initial_points()

        # Initialize sampling steps
        self._initialize_num_sampling_steps()

        # Allocate samples
        self._allocate_samples()

        # Set targets
        self._set_targets()

        # Initialize the samplers
        self._initialize_samplers()

        # Run over pre-sample methods for samplers that have it
        # TODO. Some samplers (NUTS) seem to require to run _pre_warmup before _pre_sample
        # This is not ideal and should be fixed in the future
        for sampler in self.samplers.values():
            self._pre_warmup_and_pre_sample_sampler(sampler)

        # Validate all targets for samplers.
        self.validate_targets()

    # ------------ Public methods ------------
    def validate_targets(self):
        """ Validate each of the conditional targets used in the Gibbs steps """
        if not isinstance(self.target, JointDistribution):
            raise ValueError('Target distribution must be a JointDistribution.')
        for sampler in self.samplers.values():
            sampler.validate_target()

    def sample(self, Ns) -> 'HybridGibbs':
        """ Sample from the joint distribution using Gibbs sampling.

        Parameters
        ----------
        Ns : int
            Number of Gibbs steps. The current value of every parameter is
            stored after each step.

        Returns
        -------
        HybridGibbs
            The sampler itself, so calls can be chained, e.g.
            ``sampler.sample(Ns).get_samples()``.
        """
        for _ in tqdm(range(Ns)):
            self.step()
            self._store_samples()
        return self

    def warmup(self, Nb) -> 'HybridGibbs':
        """ Warmup (tune) the Gibbs sampler.

        Parameters
        ----------
        Nb : int
            Number of warmup steps. Each step is followed by tuning of every
            sampler, and the resulting values are stored as samples.

        Returns
        -------
        HybridGibbs
            The sampler itself, so calls can be chained.
        """
        for idx in tqdm(range(Nb)):
            self.step()
            self.tune(idx)
            self._store_samples()
        return self

    def get_samples(self) -> Dict[str, Samples]:
        """ Return all stored samples (warmup and sampling) as a dict of Samples objects keyed by parameter name. """
        samples_object = {}
        for par_name in self.par_names:
            samples_array = np.array(self.samples[par_name]).T
            samples_object[par_name] = Samples(samples_array, self.target.get_density(par_name).geometry)
        return samples_object

    def step(self):
        """ Sequentially go through all parameters and sample them conditionally on each other """

        # Sample from each conditional distribution
        for par_name in self.par_names:

            # Set target for current parameter
            self._set_target(par_name)

            # Get sampler
            sampler = self.samplers[par_name]

            # Set initial parameters using current point and scale (subset of state)
            # This does not store the full state from e.g. NUTS sampler
            # But works on samplers like MH, PCN, ULA, MALA, LinearRTO, UGLA, CWMH
            # that only use initial_point and initial_scale
            sampler.initial_point = self.current_samples[par_name]
            if hasattr(sampler, 'initial_scale'): sampler.initial_scale = sampler.scale

            # Reinitialize sampler
            # This makes the sampler lose all of its state.
            # This is only OK because we set the initial values above from the previous state
            sampler.reinitialize()

            # Run pre_warmup and pre_sample methods for sampler
            # TODO. Some samplers (NUTS) seem to require to run _pre_warmup before _pre_sample
            self._pre_warmup_and_pre_sample_sampler(sampler)

            # Take MCMC steps
            for _ in range(self.num_sampling_steps[par_name]):
                sampler.step()

            # Extract samples (Ensure even 1-dimensional samples are 1D arrays)
            self.current_samples[par_name] = sampler.current_point.reshape(-1)

    def tune(self, idx):
        """ Tune each of the samplers """
        for par_name in self.par_names:
            self.samplers[par_name].tune(skip_len=1, update_count=idx)

    # ------------ Private methods ------------
    def _initialize_samplers(self):
        """ Initialize samplers """
        for sampler in self.samplers.values():
            sampler.initialize()

    def _initialize_num_sampling_steps(self):
        """ Initialize the number of sampling steps for each sampler. Defaults to 1 if not set by user.

        Works on a copy so that a dictionary passed in by the user is not mutated.
        """
        steps = {} if self.num_sampling_steps is None else dict(self.num_sampling_steps)
        for par_name in self.par_names:
            steps.setdefault(par_name, 1)
        self.num_sampling_steps = steps

    def _pre_warmup_and_pre_sample_sampler(self, sampler):
        """ Call the sampler's _pre_warmup and _pre_sample hooks if it defines them. """
        if hasattr(sampler, '_pre_warmup'): sampler._pre_warmup()
        if hasattr(sampler, '_pre_sample'): sampler._pre_sample()

    def _set_targets(self):
        """ Set targets for all samplers using the current samples """
        par_names = self.par_names
        for par_name in par_names:
            self._set_target(par_name)

    def _set_target(self, par_name):
        """ Set target conditional distribution for a single parameter using the current samples """
        # Get all other conditional parameters other than the current parameter and update the target
        # This defines - from a joint p(x,y,z) - the conditional distribution p(x|y,z) or p(y|x,z) or p(z|x,y)
        conditional_params = {par_name_: self.current_samples[par_name_] for par_name_ in self.par_names if par_name_ != par_name}
        self.samplers[par_name].target = self.target(**conditional_params)

    def _allocate_samples(self):
        """ Allocate memory for samples """
        samples = {}
        for par_name in self.par_names:
            samples[par_name] = []
        self.samples = samples

    def _get_initial_points(self):
        """ Get initial points for each parameter (sets a default initial point on samplers that lack one) """
        initial_points = {}
        for par_name in self.par_names:
            sampler = self.samplers[par_name]
            if sampler.initial_point is None:
                sampler.initial_point = sampler._get_default_initial_point(self.target.get_density(par_name).dim)
            initial_points[par_name] = sampler.initial_point

        return initial_points

    def _store_samples(self):
        """ Append the current value of every parameter to its list of stored samples """
        for par_name in self.par_names:
            self.samples[par_name].append(self.current_samples[par_name])