CUQIpy 1.3.0.post0.dev401__py3-none-any.whl → 1.4.0.post0.dev41__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (50)
  1. cuqi/__init__.py +1 -0
  2. cuqi/_version.py +3 -3
  3. cuqi/density/_density.py +9 -1
  4. cuqi/distribution/_joint_distribution.py +96 -11
  5. cuqi/experimental/__init__.py +1 -2
  6. cuqi/experimental/_recommender.py +4 -4
  7. cuqi/legacy/__init__.py +2 -0
  8. cuqi/legacy/sampler/__init__.py +11 -0
  9. cuqi/legacy/sampler/_conjugate.py +55 -0
  10. cuqi/legacy/sampler/_conjugate_approx.py +52 -0
  11. cuqi/legacy/sampler/_cwmh.py +196 -0
  12. cuqi/legacy/sampler/_gibbs.py +231 -0
  13. cuqi/legacy/sampler/_hmc.py +335 -0
  14. cuqi/legacy/sampler/_langevin_algorithm.py +198 -0
  15. cuqi/legacy/sampler/_laplace_approximation.py +184 -0
  16. cuqi/legacy/sampler/_mh.py +190 -0
  17. cuqi/legacy/sampler/_pcn.py +244 -0
  18. cuqi/legacy/sampler/_rto.py +284 -0
  19. cuqi/legacy/sampler/_sampler.py +182 -0
  20. cuqi/problem/_problem.py +87 -80
  21. cuqi/sampler/__init__.py +120 -8
  22. cuqi/sampler/_conjugate.py +376 -35
  23. cuqi/sampler/_conjugate_approx.py +40 -16
  24. cuqi/sampler/_cwmh.py +132 -138
  25. cuqi/{experimental/mcmc → sampler}/_direct.py +1 -1
  26. cuqi/sampler/_gibbs.py +269 -130
  27. cuqi/sampler/_hmc.py +328 -201
  28. cuqi/sampler/_langevin_algorithm.py +282 -98
  29. cuqi/sampler/_laplace_approximation.py +87 -117
  30. cuqi/sampler/_mh.py +47 -157
  31. cuqi/sampler/_pcn.py +56 -211
  32. cuqi/sampler/_rto.py +206 -140
  33. cuqi/sampler/_sampler.py +540 -135
  34. {cuqipy-1.3.0.post0.dev401.dist-info → cuqipy-1.4.0.post0.dev41.dist-info}/METADATA +1 -1
  35. {cuqipy-1.3.0.post0.dev401.dist-info → cuqipy-1.4.0.post0.dev41.dist-info}/RECORD +38 -37
  36. cuqi/experimental/mcmc/__init__.py +0 -122
  37. cuqi/experimental/mcmc/_conjugate.py +0 -396
  38. cuqi/experimental/mcmc/_conjugate_approx.py +0 -76
  39. cuqi/experimental/mcmc/_cwmh.py +0 -190
  40. cuqi/experimental/mcmc/_gibbs.py +0 -366
  41. cuqi/experimental/mcmc/_hmc.py +0 -462
  42. cuqi/experimental/mcmc/_langevin_algorithm.py +0 -382
  43. cuqi/experimental/mcmc/_laplace_approximation.py +0 -154
  44. cuqi/experimental/mcmc/_mh.py +0 -80
  45. cuqi/experimental/mcmc/_pcn.py +0 -89
  46. cuqi/experimental/mcmc/_rto.py +0 -350
  47. cuqi/experimental/mcmc/_sampler.py +0 -582
  48. {cuqipy-1.3.0.post0.dev401.dist-info → cuqipy-1.4.0.post0.dev41.dist-info}/WHEEL +0 -0
  49. {cuqipy-1.3.0.post0.dev401.dist-info → cuqipy-1.4.0.post0.dev41.dist-info}/licenses/LICENSE +0 -0
  50. {cuqipy-1.3.0.post0.dev401.dist-info → cuqipy-1.4.0.post0.dev41.dist-info}/top_level.txt +0 -0
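
Taken together, the file list shows the experimental MCMC module being promoted: `cuqi/experimental/mcmc/_direct.py` is renamed into `cuqi/sampler/`, the remaining `cuqi/experimental/mcmc/*` modules are removed, and the previous sampler implementations reappear under `cuqi/legacy/sampler/`. A minimal migration sketch based only on these moves (class names such as `CWMH` are assumed to carry over unchanged; not verified against the 1.4.0 API):

```python
import numpy as np
import cuqi

# Small test target, reused from the CWMH docstring shown in the diff below
dim = 5
mu = np.arange(dim)
target = cuqi.distribution.UserDefinedDistribution(
    dim=dim, logpdf_func=lambda x: -np.sum((x - mu)**2))

# 1.3.x (module removed in this release):
#   sampler = cuqi.experimental.mcmc.CWMH(target, scale=1)
# 1.4.x (the file moves above suggest the same sampler now lives in cuqi.sampler;
#        the pre-1.4 implementations appear to move to cuqi.legacy.sampler):
sampler = cuqi.sampler.CWMH(target, scale=1)
samples = sampler.sample(2000).get_samples()
```
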
cuqi/experimental/mcmc/_conjugate_approx.py (deleted)
@@ -1,76 +0,0 @@
-import numpy as np
-from cuqi.experimental.mcmc import Conjugate
-from cuqi.experimental.mcmc._conjugate import _ConjugatePair, _get_conjugate_parameter, _check_conjugate_parameter_is_scalar_reciprocal
-from cuqi.distribution import LMRF, Gamma
-import scipy as sp
-
-class ConjugateApprox(Conjugate):
-    """ Approximate Conjugate sampler
-
-    Sampler for sampling a posterior distribution where the likelihood and prior can be approximated
-    by a conjugate pair.
-
-    Currently supported pairs are:
-    (LMRF, Gamma): Approximated by (Gaussian, Gamma) where Gamma is defined on the inverse of the scale parameter of the LMRF distribution.
-
-    Gamma distribution must be univariate.
-
-    LMRF likelihood must have zero mean.
-
-    For more details on conjugacy see :class:`Conjugate`.
-
-    """
-
-    def _set_conjugatepair(self):
-        """ Set the conjugate pair based on the likelihood and prior. This requires target to be set. """
-        if isinstance(self.target.likelihood.distribution, LMRF) and isinstance(self.target.prior, Gamma):
-            self._conjugatepair = _LMRFGammaPair(self.target)
-        else:
-            raise ValueError(f"Conjugacy is not defined for likelihood {type(self.target.likelihood.distribution)} and prior {type(self.target.prior)}, in CUQIpy")
-
-
-class _LMRFGammaPair(_ConjugatePair):
-    """ Implementation of the conjugate pair (LMRF, Gamma) """
-
-    def validate_target(self):
-        if not self.target.prior.dim == 1:
-            raise ValueError("Approximate conjugate sampler only works with univariate Gamma prior")
-
-        if np.sum(self.target.likelihood.distribution.location) != 0:
-            raise ValueError("Approximate conjugate sampler only works with zero mean LMRF likelihood")
-
-        key_value_pairs = _get_conjugate_parameter(self.target)
-        if len(key_value_pairs) != 1:
-            raise ValueError(f"Multiple references to conjugate parameter {self.target.prior.name} found in likelihood. Only one occurance is supported.")
-        for key, value in key_value_pairs:
-            if key == "scale":
-                if not _check_conjugate_parameter_is_scalar_reciprocal(value):
-                    raise ValueError("Approximate conjugate sampler only works with Gamma prior on the inverse of the scale parameter of the LMRF likelihood")
-            else:
-                raise ValueError(f"No approximate conjugacy defined for likelihood {type(self.target.likelihood.distribution)} and prior {type(self.target.prior)}, in CUQIpy")
-
-    def conjugate_distribution(self):
-        # Extract variables
-        # Here we approximate the LMRF with a Gaussian
-
-        # Extract diff_op from target likelihood
-        D = self.target.likelihood.distribution._diff_op
-        n = D.shape[0]
-
-        # Gaussian approximation of LMRF prior as function of x_k
-        # See Uribe et al. (2022) for details
-        # Current has a zero mean assumption on likelihood! TODO
-        beta=1e-5
-        def Lk_fun(x_k):
-            dd = 1/np.sqrt((D @ x_k)**2 + beta*np.ones(n))
-            W = sp.sparse.diags(dd)
-            return W.sqrt() @ D
-
-        x = self.target.likelihood.data #x
-        d = len(x) #d
-        Lx = Lk_fun(x)@x #Lx
-        alpha = self.target.prior.shape #alpha
-        beta = self.target.prior.rate #beta
-
-        # Create Gamma distribution and sample
-        return Gamma(shape=d+alpha, rate=np.linalg.norm(Lx)**2+beta)
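
The closing lines of the removed `conjugate_distribution` method spell out the approximate conjugate update: the LMRF factor is replaced by a Gaussian through IRLS-style weights, and the resulting conditional is a Gamma with shape `d + alpha` and rate `||L(x) x||^2 + beta`. A standalone NumPy/SciPy sketch of just that arithmetic (not the CUQIpy API; the difference operator and helper name below are illustrative):

```python
import numpy as np
import scipy.sparse as sps

def lmrf_gamma_conditional(x, D, alpha, beta, beta_smooth=1e-5):
    """Return (shape, rate) of the approximate Gamma conditional,
    mirroring the removed conjugate_distribution() above."""
    # Gaussian (IRLS) approximation of the LMRF factor: L(x) = W(x)^(1/2) D
    w = 1.0 / np.sqrt((D @ x)**2 + beta_smooth)
    L = sps.diags(np.sqrt(w)) @ D
    d = len(x)
    return d + alpha, np.linalg.norm(L @ x)**2 + beta

# Example with a first-order difference operator on a small vector
n = 20
D = sps.diags([np.ones(n), -np.ones(n - 1)], [0, -1])
x = np.random.randn(n)
shape, rate = lmrf_gamma_conditional(x, D, alpha=1.0, beta=1e-4)
draw = np.random.gamma(shape, 1.0 / rate)  # NumPy's gamma takes scale = 1/rate
```
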
cuqi/experimental/mcmc/_cwmh.py (deleted)
@@ -1,190 +0,0 @@
-import numpy as np
-import cuqi
-from cuqi.experimental.mcmc import ProposalBasedSampler
-from cuqi.array import CUQIarray
-from numbers import Number
-
-class CWMH(ProposalBasedSampler):
-    """Component-wise Metropolis Hastings sampler.
-
-    Allows sampling of a target distribution by a component-wise random-walk
-    sampling of a proposal distribution along with an accept/reject step.
-
-    Parameters
-    ----------
-
-    target : `cuqi.distribution.Distribution` or lambda function
-        The target distribution to sample. Custom logpdfs are supported by using
-        a :class:`cuqi.distribution.UserDefinedDistribution`.
-
-    proposal : `cuqi.distribution.Distribution` or callable method
-        The proposal to sample from. If a callable method it should provide a
-        single independent sample from proposal distribution. Defaults to a
-        Gaussian proposal. *Optional*.
-
-    scale : float or ndarray
-        Scale parameter used to define correlation between previous and proposed
-        sample in random-walk. *Optional*. If float, the same scale is used for
-        all dimensions. If ndarray, a (possibly) different scale is used for
-        each dimension.
-
-    initial_point : ndarray
-        Initial parameters. *Optional*
-
-    callback : callable, optional
-        A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
-        The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
-
-    kwargs : dict
-        Additional keyword arguments to be passed to the base class
-        :class:`ProposalBasedSampler`.
-
-    Example
-    -------
-    .. code-block:: python
-        import numpy as np
-        import cuqi
-        # Parameters
-        dim = 5 # Dimension of distribution
-        mu = np.arange(dim) # Mean of Gaussian
-        std = 1 # standard deviation of Gaussian
-
-        # Logpdf function
-        logpdf_func = lambda x: -1/(std**2)*np.sum((x-mu)**2)
-
-        # Define distribution from logpdf as UserDefinedDistribution (sample
-        # and gradients also supported as inputs to UserDefinedDistribution)
-        target = cuqi.distribution.UserDefinedDistribution(
-            dim=dim, logpdf_func=logpdf_func)
-
-        # Set up sampler
-        sampler = cuqi.experimental.mcmc.CWMH(target, scale=1)
-
-        # Sample
-        samples = sampler.sample(2000).get_samples()
-
-    """
-
-    _STATE_KEYS = ProposalBasedSampler._STATE_KEYS.union(['_scale_temp'])
-
-    def __init__(self, target:cuqi.density.Density=None, proposal=None, scale=1,
-                 initial_point=None, **kwargs):
-        super().__init__(target, proposal=proposal, scale=scale,
-                         initial_point=initial_point, **kwargs)
-
-    def _initialize(self):
-        if isinstance(self.scale, Number):
-            self.scale = np.ones(self.dim)*self.scale
-        self._acc = [np.ones((self.dim))] # Overwrite acc from ProposalBasedSampler with list of arrays
-
-        # Handling of temporary scale parameter due to possible bug in old CWMH
-        self._scale_temp = self.scale.copy()
-
-    @property
-    def scale(self):
-        """ Get the scale parameter. """
-        return self._scale
-
-    @scale.setter
-    def scale(self, value):
-        """ Set the scale parameter. """
-        if self._is_initialized and isinstance(value, Number):
-            value = np.ones(self.dim)*value
-        self._scale = value
-
-    def validate_target(self):
-        if not isinstance(self.target, cuqi.density.Density):
-            raise ValueError(
-                "Target should be an instance of "+\
-                f"{cuqi.density.Density.__class__.__name__}")
-        # Fail when there is no log density, which is currently assumed to be the case in case NaN is returned.
-        if np.isnan(self.target.logd(self._get_default_initial_point(self.dim))):
-            raise ValueError("Target does not have valid logd")
-
-    def validate_proposal(self):
-        if not isinstance(self.proposal, cuqi.distribution.Distribution):
-            raise ValueError("Proposal must be a cuqi.distribution.Distribution object")
-        if not self.proposal.is_symmetric:
-            raise ValueError("Proposal must be symmetric")
-
-    @property
-    def proposal(self):
-        if self._proposal is None:
-            self._proposal = cuqi.distribution.Normal(
-                mean=lambda location: location,
-                std=lambda scale: scale,
-                geometry=self.dim,
-            )
-        return self._proposal
-
-    @proposal.setter
-    def proposal(self, value):
-        self._proposal = value
-
-    def step(self):
-        # Initialize x_t which is used to store the current CWMH sample
-        x_t = self.current_point.copy()
-
-        # Initialize x_star which is used to store the proposed sample by
-        # updating the current sample component-by-component
-        x_star = self.current_point.copy()
-
-        # Propose a sample x_all_components from the proposal distribution
-        # for all the components
-        target_eval_t = self.current_target_logd
-        if isinstance(self.proposal,cuqi.distribution.Distribution):
-            x_all_components = self.proposal(
-                location= self.current_point, scale=self.scale).sample()
-        else:
-            x_all_components = self.proposal(self.current_point, self.scale)
-
-        # Initialize acceptance rate
-        acc = np.zeros(self.dim)
-
-        # Loop over all the components of the sample and accept/reject
-        # each component update.
-        for j in range(self.dim):
-            # propose state x_star by updating the j-th component
-            x_star[j] = x_all_components[j]
-
-            # evaluate target
-            target_eval_star = self.target.logd(x_star)
-
-            # compute Metropolis acceptance ratio
-            alpha = min(0, target_eval_star - target_eval_t)
-
-            # accept/reject
-            u_theta = np.log(np.random.rand())
-            if (u_theta <= alpha) and \
-                (not np.isnan(target_eval_star)) and \
-                (not np.isinf(target_eval_star)):
-                x_t[j] = x_all_components[j]
-                target_eval_t = target_eval_star
-                acc[j] = 1
-
-            x_star = x_t.copy()
-
-        self.current_target_logd = target_eval_t
-        self.current_point = x_t
-
-        return acc
-
-    def tune(self, skip_len, update_count):
-        # Store update_count in variable i for readability
-        i = update_count
-
-        # Optimal acceptance rate for CWMH
-        star_acc = 0.21/self.dim + 0.23
-
-        # Mean of acceptance rate over the last skip_len samples
-        hat_acc = np.mean(self._acc[i*skip_len:(i+1)*skip_len], axis=0)
-
-        # Compute new intermediate scaling parameter scale_temp
-        # Factor zeta ensures that the variation of the scale update vanishes
-        zeta = 1/np.sqrt(update_count+1)
-        scale_temp = np.exp(
-            np.log(self._scale_temp) + zeta*(hat_acc-star_acc))
-
-        # Update the scale parameter
-        self.scale = np.minimum(scale_temp, np.ones(self.dim))
-        self._scale_temp = scale_temp
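
The removed `tune` method above encodes the scale adaptation used during warmup: a per-component target acceptance rate of `0.21/dim + 0.23`, a log-scale update damped by `1/sqrt(update_count + 1)`, and a cap of the applied scale at 1. A self-contained sketch of that rule (variable names are mine, not the CUQIpy API):

```python
import numpy as np

def adapt_cwmh_scale(scale_temp, acc_window, dim, update_count):
    """One scale-adaptation step, mirroring the removed CWMH.tune()."""
    star_acc = 0.21 / dim + 0.23              # target acceptance rate
    hat_acc = np.mean(acc_window, axis=0)     # per-component acceptance over the window
    zeta = 1.0 / np.sqrt(update_count + 1)    # vanishing adaptation factor
    scale_temp = np.exp(np.log(scale_temp) + zeta * (hat_acc - star_acc))
    scale = np.minimum(scale_temp, 1.0)       # applied scale is capped at 1
    return scale, scale_temp                  # the uncapped value is kept as state

# Example: 5 components, acceptance history from the last 100 steps
dim = 5
acc_window = np.random.binomial(1, 0.3, size=(100, dim)).astype(float)
scale, scale_temp = adapt_cwmh_scale(np.ones(dim), acc_window, dim, update_count=0)
```
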
cuqi/experimental/mcmc/_gibbs.py (deleted)
@@ -1,366 +0,0 @@
-from cuqi.distribution import JointDistribution, Posterior
-from cuqi.experimental.mcmc import Sampler
-from cuqi.samples import Samples, JointSamples
-from typing import Dict
-import numpy as np
-import warnings
-
-try:
-    from tqdm import tqdm
-except ImportError:
-    def tqdm(iterable, **kwargs):
-        warnings.warn("Module mcmc: tqdm not found. Install tqdm to get sampling progress.")
-        return iterable
-
-# Not subclassed from Sampler as Gibbs handles multiple samplers and samples multiple parameters
-# Similar approach as for JointDistribution
-class HybridGibbs:
-    """
-    Hybrid Gibbs sampler for sampling a joint distribution.
-
-    Gibbs sampling samples the variables of the distribution sequentially,
-    one variable at a time. When a variable represents a random vector, the
-    whole vector is sampled simultaneously.
-
-    The sampling of each variable is done by sampling from the conditional
-    distribution of that variable given the values of the other variables.
-    This is often a very efficient way of sampling from a joint distribution
-    if the conditional distributions are easy to sample from.
-
-    Hybrid Gibbs sampler is a generalization of the Gibbs sampler where the
-    conditional distributions are sampled using different MCMC samplers.
-
-    When the conditionals are sampled exactly, the samples from the Gibbs
-    sampler converge to the joint distribution. See e.g.
-    Gelman et al. "Bayesian Data Analysis" (2014), Third Edition
-    for more details.
-
-    In each Gibbs step, the corresponding sampler state and history are stored,
-    then the sampler is reinitialized. After reinitialization, the sampler state
-    and history are set back to the stored values. This ensures preserving the
-    statefulness of the samplers.
-
-    The order in which the conditionals are sampled is the order of the
-    variables in the sampling strategy, unless a different sampling order
-    is specified by the parameter `scan_order`
-
-    Parameters
-    ----------
-    target : cuqi.distribution.JointDistribution
-        Target distribution to sample from.
-
-    sampling_strategy : dict
-        Dictionary of sampling strategies for each variable.
-        Keys are variable names.
-        Values are sampler objects.
-
-    num_sampling_steps : dict, *optional*
-        Dictionary of number of sampling steps for each variable.
-        The sampling steps are defined as the number of times the sampler
-        will call its step method in each Gibbs step.
-        Default is 1 for all variables.
-
-    scan_order : list or str, *optional*
-        Order in which the conditional distributions are sampled.
-        If set to "random", use a random ordering at each step.
-        If not specified, it will be the order in the sampling_strategy.
-
-    callback : callable, optional
-        A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
-        The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
-
-    Example
-    -------
-    .. code-block:: python
-
-        import cuqi
-        import numpy as np
-
-        # Model and data
-        A, y_obs, probinfo = cuqi.testproblem.Deconvolution1D(phantom='sinc').get_components()
-        n = A.domain_dim
-
-        # Define distributions
-        d = cuqi.distribution.Gamma(1, 1e-4)
-        l = cuqi.distribution.Gamma(1, 1e-4)
-        x = cuqi.distribution.GMRF(np.zeros(n), lambda d: d)
-        y = cuqi.distribution.Gaussian(A, lambda l: 1/l)
-
-        # Combine into a joint distribution and create posterior
-        joint = cuqi.distribution.JointDistribution(d, l, x, y)
-        posterior = joint(y=y_obs)
-
-        # Define sampling strategy
-        sampling_strategy = {
-            'x': cuqi.experimental.mcmc.LinearRTO(maxit=15),
-            'd': cuqi.experimental.mcmc.Conjugate(),
-            'l': cuqi.experimental.mcmc.Conjugate(),
-        }
-
-        # Define Gibbs sampler
-        sampler = cuqi.experimental.mcmc.HybridGibbs(posterior, sampling_strategy)
-
-        # Run sampler
-        sampler.warmup(200)
-        sampler.sample(1000)
-
-        # Get samples removing burn-in
-        samples = sampler.get_samples().burnthin(200)
-
-        # Plot results
-        samples['x'].plot_ci(exact=probinfo.exactSolution)
-        samples['d'].plot_trace(figsize=(8,2))
-        samples['l'].plot_trace(figsize=(8,2))
-
-    """
-
-    def __init__(self, target: JointDistribution, sampling_strategy: Dict[str, Sampler], num_sampling_steps: Dict[str, int] = None, scan_order = None, callback=None):
-
-        # Store target and allow conditioning to reduce to a single density
-        self.target = target() # Create a copy of target distribution (to avoid modifying the original)
-
-        # Store sampler instances (again as a copy to avoid modifying the original)
-        self.samplers = sampling_strategy.copy()
-
-        # Store number of sampling steps for each parameter
-        self.num_sampling_steps = num_sampling_steps
-
-        # Store parameter names
-        self.par_names = self.target.get_parameter_names()
-
-        # Store the scan order
-        self._scan_order = scan_order
-
-        # Check that the parameters of the target align with the sampling_strategy and scan_order
-        if set(self.par_names) != set(self.scan_order):
-            raise ValueError("Parameter names in JointDistribution do not equal the names in the scan order.")
-
-        # Initialize sampler (after target is set)
-        self._initialize()
-
-        # Set the callback function
-        self.callback = callback
-
-    def _initialize(self):
-        """ Initialize sampler """
-
-        # Initial points
-        self.current_samples = self._get_initial_points()
-
-        # Initialize sampling steps
-        self._initialize_num_sampling_steps()
-
-        # Allocate samples
-        self._allocate_samples()
-
-        # Set targets
-        self._set_targets()
-
-        # Initialize the samplers
-        self._initialize_samplers()
-
-        # Validate all targets for samplers.
-        self.validate_targets()
-
-    @property
-    def scan_order(self):
-        if self._scan_order is None:
-            return list(self.samplers.keys())
-        if self._scan_order == "random":
-            arr = list(self.samplers.keys())
-            np.random.shuffle(arr) # Shuffle works in-place
-            return arr
-        return self._scan_order
-
-    # ------------ Public methods ------------
-    def validate_targets(self):
-        """ Validate each of the conditional targets used in the Gibbs steps """
-        if not isinstance(self.target, (JointDistribution, Posterior)):
-            raise ValueError('Target distribution must be a JointDistribution or Posterior.')
-        for sampler in self.samplers.values():
-            sampler.validate_target()
-
-    def sample(self, Ns) -> 'HybridGibbs':
-        """ Sample from the joint distribution using Gibbs sampling
-
-        Parameters
-        ----------
-        Ns : int
-            The number of samples to draw.
-
-        """
-        for idx in tqdm(range(Ns), "Sample: "):
-
-            self.step()
-
-            self._store_samples()
-
-            # Call callback function if specified
-            self._call_callback(idx, Ns)
-
-        return self
-
-    def warmup(self, Nb, tune_freq=0.1) -> 'HybridGibbs':
-        """ Warmup (tune) the samplers in the Gibbs sampling scheme
-
-        Parameters
-        ----------
-        Nb : int
-            The number of samples to draw during warmup.
-
-        tune_freq : float, optional
-            Frequency of tuning the samplers. Tuning is performed every tune_freq*Nb steps.
-
-        """
-
-        tune_interval = max(int(tune_freq * Nb), 1)
-
-        for idx in tqdm(range(Nb), "Warmup: "):
-
-            self.step()
-
-            # Tune the sampler at tuning intervals (matching behavior of Sampler class)
-            if (idx + 1) % tune_interval == 0:
-                self.tune(tune_interval, idx // tune_interval)
-
-            self._store_samples()
-
-            # Call callback function if specified
-            self._call_callback(idx, Nb)
-
-        return self
-
-    def get_samples(self) -> Dict[str, Samples]:
-        samples_object = JointSamples()
-        for par_name in self.par_names:
-            samples_array = np.array(self.samples[par_name]).T
-            samples_object[par_name] = Samples(samples_array, self.target.get_density(par_name).geometry)
-        return samples_object
-
-    def step(self):
-        """ Sequentially go through all parameters and sample them conditionally on each other """
-
-        # Sample from each conditional distribution
-        for par_name in self.scan_order:
-
-            # Set target for current parameter
-            self._set_target(par_name)
-
-            # Get sampler
-            sampler = self.samplers[par_name]
-
-            # Instead of simply changing the target of the sampler, we reinitialize it.
-            # This is to ensure that all internal variables are set to match the new target.
-            # To return the sampler to the old state and history, we first extract the state and history
-            # before reinitializing the sampler and then set the state and history back to the sampler
-
-            # Extract state and history from sampler
-            sampler_state = sampler.get_state()
-            sampler_history = sampler.get_history()
-
-            # Reinitialize sampler
-            sampler.reinitialize()
-
-            # Set state and history back to sampler
-            sampler.set_state(sampler_state)
-            sampler.set_history(sampler_history)
-
-            # Allow for multiple sampling steps in each Gibbs step
-            for _ in range(self.num_sampling_steps[par_name]):
-                # Sampling step
-                acc = sampler.step()
-
-                # Store acceptance rate in sampler (matching behavior of Sampler class Sample method)
-                sampler._acc.append(acc)
-
-            # Extract samples (Ensure even 1-dimensional samples are 1D arrays)
-            if isinstance(sampler.current_point, np.ndarray):
-                self.current_samples[par_name] = sampler.current_point.reshape(-1)
-            else:
-                self.current_samples[par_name] = sampler.current_point
-
-    def tune(self, skip_len, update_count):
-        """ Run a single tuning step on each of the samplers in the Gibbs sampling scheme
-
-        Parameters
-        ----------
-        skip_len : int
-            Defines the number of steps in between tuning (i.e. the tuning interval).
-
-        update_count : int
-            The number of times tuning has been performed. Can be used for internal bookkeeping.
-
-        """
-        for par_name in self.par_names:
-            self.samplers[par_name].tune(skip_len=skip_len, update_count=update_count)
-
-    # ------------ Private methods ------------
-    def _call_callback(self, sample_index, num_of_samples):
-        """ Calls the callback function. Assumes input is sampler, sample index, and total number of samples """
-        if self.callback is not None:
-            self.callback(self, sample_index, num_of_samples)
-
-    def _initialize_samplers(self):
-        """ Initialize samplers """
-        for sampler in self.samplers.values():
-            sampler.initialize()
-
-    def _initialize_num_sampling_steps(self):
-        """ Initialize the number of sampling steps for each sampler. Defaults to 1 if not set by user """
-
-        if self.num_sampling_steps is None:
-            self.num_sampling_steps = {par_name: 1 for par_name in self.par_names}
-
-        for par_name in self.par_names:
-            if par_name not in self.num_sampling_steps:
-                self.num_sampling_steps[par_name] = 1
-
-
-    def _set_targets(self):
-        """ Set targets for all samplers using the current samples """
-        par_names = self.par_names
-        for par_name in par_names:
-            self._set_target(par_name)
-
-    def _set_target(self, par_name):
-        """ Set target conditional distribution for a single parameter using the current samples """
-        # Get all other conditional parameters other than the current parameter and update the target
-        # This defines - from a joint p(x,y,z) - the conditional distribution p(x|y,z) or p(y|x,z) or p(z|x,y)
-        conditional_params = {par_name_: self.current_samples[par_name_] for par_name_ in self.par_names if par_name_ != par_name}
-        self.samplers[par_name].target = self.target(**conditional_params)
-
-    def _allocate_samples(self):
-        """ Allocate memory for samples """
-        samples = {}
-        for par_name in self.par_names:
-            samples[par_name] = []
-        self.samples = samples
-
-    def _get_initial_points(self):
-        """ Get initial points for each parameter """
-        initial_points = {}
-        for par_name in self.par_names:
-            sampler = self.samplers[par_name]
-            if sampler.initial_point is None:
-                sampler.initial_point = sampler._get_default_initial_point(self.target.get_density(par_name).dim)
-            initial_points[par_name] = sampler.initial_point
-
-        return initial_points
-
-    def _store_samples(self):
-        """ Store current samples at index i of samples dict """
-        for par_name in self.par_names:
-            self.samples[par_name].append(self.current_samples[par_name])
-
-    def __repr__(self):
-        """ Return a string representation of the sampler. """
-        msg = f"Sampler: {self.__class__.__name__} \n"
-        if self.target is None:
-            msg += f" Target: None \n"
-        else:
-            msg += f" Target: \n \t {self.target} \n\n"
-
-        for key, value in zip(self.samplers.keys(), self.samplers.values()):
-            msg += f" Variable '{key}' with {value} \n"
-
-        return msg
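
The `HybridGibbs` docstring above still imports its samplers from `cuqi.experimental.mcmc`. Given the module moves in this release, the same example would presumably now be written against `cuqi.sampler`; a hedged sketch, assuming `HybridGibbs`, `LinearRTO` and `Conjugate` keep their names after the move (not verified against the 1.4.0 API):

```python
import cuqi
import numpy as np

# Same setup as the docstring example above
A, y_obs, probinfo = cuqi.testproblem.Deconvolution1D(phantom='sinc').get_components()
n = A.domain_dim

d = cuqi.distribution.Gamma(1, 1e-4)
l = cuqi.distribution.Gamma(1, 1e-4)
x = cuqi.distribution.GMRF(np.zeros(n), lambda d: d)
y = cuqi.distribution.Gaussian(A, lambda l: 1/l)

posterior = cuqi.distribution.JointDistribution(d, l, x, y)(y=y_obs)

# Samplers taken from cuqi.sampler rather than cuqi.experimental.mcmc (assumed)
sampling_strategy = {
    'x': cuqi.sampler.LinearRTO(maxit=15),
    'd': cuqi.sampler.Conjugate(),
    'l': cuqi.sampler.Conjugate(),
}

sampler = cuqi.sampler.HybridGibbs(posterior, sampling_strategy)
sampler.warmup(200)
sampler.sample(1000)
samples = sampler.get_samples().burnthin(200)
```
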