CUQIpy 1.1.1.post0.dev36__py3-none-any.whl → 1.4.1.post0.dev124__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of CUQIpy might be problematic. Click here for more details.
- cuqi/__init__.py +2 -0
- cuqi/_version.py +3 -3
- cuqi/algebra/__init__.py +2 -0
- cuqi/algebra/_abstract_syntax_tree.py +358 -0
- cuqi/algebra/_ordered_set.py +82 -0
- cuqi/algebra/_random_variable.py +457 -0
- cuqi/array/_array.py +4 -13
- cuqi/config.py +7 -0
- cuqi/density/_density.py +9 -1
- cuqi/distribution/__init__.py +3 -2
- cuqi/distribution/_beta.py +7 -11
- cuqi/distribution/_cauchy.py +2 -2
- cuqi/distribution/_custom.py +0 -6
- cuqi/distribution/_distribution.py +31 -45
- cuqi/distribution/_gamma.py +7 -3
- cuqi/distribution/_gaussian.py +2 -12
- cuqi/distribution/_inverse_gamma.py +4 -10
- cuqi/distribution/_joint_distribution.py +112 -15
- cuqi/distribution/_lognormal.py +0 -7
- cuqi/distribution/{_modifiedhalfnormal.py → _modified_half_normal.py} +23 -23
- cuqi/distribution/_normal.py +34 -7
- cuqi/distribution/_posterior.py +9 -0
- cuqi/distribution/_truncated_normal.py +129 -0
- cuqi/distribution/_uniform.py +47 -1
- cuqi/experimental/__init__.py +2 -2
- cuqi/experimental/_recommender.py +216 -0
- cuqi/geometry/__init__.py +2 -0
- cuqi/geometry/_geometry.py +15 -1
- cuqi/geometry/_product_geometry.py +181 -0
- cuqi/implicitprior/__init__.py +5 -3
- cuqi/implicitprior/_regularized_gaussian.py +483 -0
- cuqi/implicitprior/{_regularizedGMRF.py → _regularized_gmrf.py} +4 -2
- cuqi/implicitprior/{_regularizedUnboundedUniform.py → _regularized_unbounded_uniform.py} +3 -2
- cuqi/implicitprior/_restorator.py +269 -0
- cuqi/legacy/__init__.py +2 -0
- cuqi/{experimental/mcmc → legacy/sampler}/__init__.py +7 -11
- cuqi/legacy/sampler/_conjugate.py +55 -0
- cuqi/legacy/sampler/_conjugate_approx.py +52 -0
- cuqi/legacy/sampler/_cwmh.py +196 -0
- cuqi/legacy/sampler/_gibbs.py +231 -0
- cuqi/legacy/sampler/_hmc.py +335 -0
- cuqi/{experimental/mcmc → legacy/sampler}/_langevin_algorithm.py +82 -111
- cuqi/legacy/sampler/_laplace_approximation.py +184 -0
- cuqi/legacy/sampler/_mh.py +190 -0
- cuqi/legacy/sampler/_pcn.py +244 -0
- cuqi/{experimental/mcmc → legacy/sampler}/_rto.py +132 -90
- cuqi/legacy/sampler/_sampler.py +182 -0
- cuqi/likelihood/_likelihood.py +9 -1
- cuqi/model/__init__.py +1 -1
- cuqi/model/_model.py +1361 -359
- cuqi/pde/__init__.py +4 -0
- cuqi/pde/_observation_map.py +36 -0
- cuqi/pde/_pde.py +134 -33
- cuqi/problem/_problem.py +93 -87
- cuqi/sampler/__init__.py +120 -8
- cuqi/sampler/_conjugate.py +376 -35
- cuqi/sampler/_conjugate_approx.py +40 -16
- cuqi/sampler/_cwmh.py +132 -138
- cuqi/{experimental/mcmc → sampler}/_direct.py +1 -1
- cuqi/sampler/_gibbs.py +288 -130
- cuqi/sampler/_hmc.py +328 -201
- cuqi/sampler/_langevin_algorithm.py +284 -100
- cuqi/sampler/_laplace_approximation.py +87 -117
- cuqi/sampler/_mh.py +47 -157
- cuqi/sampler/_pcn.py +65 -213
- cuqi/sampler/_rto.py +211 -142
- cuqi/sampler/_sampler.py +553 -136
- cuqi/samples/__init__.py +1 -1
- cuqi/samples/_samples.py +24 -18
- cuqi/solver/__init__.py +6 -4
- cuqi/solver/_solver.py +230 -26
- cuqi/testproblem/_testproblem.py +2 -3
- cuqi/utilities/__init__.py +6 -1
- cuqi/utilities/_get_python_variable_name.py +2 -2
- cuqi/utilities/_utilities.py +182 -2
- {CUQIpy-1.1.1.post0.dev36.dist-info → cuqipy-1.4.1.post0.dev124.dist-info}/METADATA +10 -6
- cuqipy-1.4.1.post0.dev124.dist-info/RECORD +101 -0
- {CUQIpy-1.1.1.post0.dev36.dist-info → cuqipy-1.4.1.post0.dev124.dist-info}/WHEEL +1 -1
- CUQIpy-1.1.1.post0.dev36.dist-info/RECORD +0 -92
- cuqi/experimental/mcmc/_conjugate.py +0 -197
- cuqi/experimental/mcmc/_conjugate_approx.py +0 -81
- cuqi/experimental/mcmc/_cwmh.py +0 -191
- cuqi/experimental/mcmc/_gibbs.py +0 -268
- cuqi/experimental/mcmc/_hmc.py +0 -470
- cuqi/experimental/mcmc/_laplace_approximation.py +0 -156
- cuqi/experimental/mcmc/_mh.py +0 -78
- cuqi/experimental/mcmc/_pcn.py +0 -89
- cuqi/experimental/mcmc/_sampler.py +0 -561
- cuqi/experimental/mcmc/_utilities.py +0 -17
- cuqi/implicitprior/_regularizedGaussian.py +0 -323
- {CUQIpy-1.1.1.post0.dev36.dist-info → cuqipy-1.4.1.post0.dev124.dist-info/licenses}/LICENSE +0 -0
- {CUQIpy-1.1.1.post0.dev36.dist-info → cuqipy-1.4.1.post0.dev124.dist-info}/top_level.txt +0 -0
cuqi/sampler/_gibbs.py
CHANGED
|
@@ -1,14 +1,23 @@
|
|
|
1
|
-
from cuqi.distribution import JointDistribution
|
|
1
|
+
from cuqi.distribution import JointDistribution, Posterior
|
|
2
2
|
from cuqi.sampler import Sampler
|
|
3
|
-
from cuqi.samples import Samples
|
|
4
|
-
from typing import Dict
|
|
3
|
+
from cuqi.samples import Samples, JointSamples
|
|
4
|
+
from typing import Dict
|
|
5
5
|
import numpy as np
|
|
6
|
-
import
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
6
|
+
import warnings
|
|
7
|
+
from cuqi import config
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
except ImportError:
|
|
12
|
+
def tqdm(iterable, **kwargs):
|
|
13
|
+
warnings.warn("Module mcmc: tqdm not found. Install tqdm to get sampling progress.")
|
|
14
|
+
return iterable
|
|
15
|
+
|
|
16
|
+
# Not subclassed from Sampler as Gibbs handles multiple samplers and samples multiple parameters
|
|
17
|
+
# Similar approach as for JointDistribution
|
|
18
|
+
class HybridGibbs:
|
|
10
19
|
"""
|
|
11
|
-
Gibbs sampler for sampling a joint distribution.
|
|
20
|
+
Hybrid Gibbs sampler for sampling a joint distribution.
|
|
12
21
|
|
|
13
22
|
Gibbs sampling samples the variables of the distribution sequentially,
|
|
14
23
|
one variable at a time. When a variable represents a random vector, the
|
|
@@ -17,7 +26,24 @@ class Gibbs:
|
|
|
17
26
|
The sampling of each variable is done by sampling from the conditional
|
|
18
27
|
distribution of that variable given the values of the other variables.
|
|
19
28
|
This is often a very efficient way of sampling from a joint distribution
|
|
20
|
-
if the conditional distributions are easy to sample from.
|
|
29
|
+
if the conditional distributions are easy to sample from.
|
|
30
|
+
|
|
31
|
+
Hybrid Gibbs sampler is a generalization of the Gibbs sampler where the
|
|
32
|
+
conditional distributions are sampled using different MCMC samplers.
|
|
33
|
+
|
|
34
|
+
When the conditionals are sampled exactly, the samples from the Gibbs
|
|
35
|
+
sampler converge to the joint distribution. See e.g.
|
|
36
|
+
Gelman et al. "Bayesian Data Analysis" (2014), Third Edition
|
|
37
|
+
for more details.
|
|
38
|
+
|
|
39
|
+
In each Gibbs step, the corresponding sampler state and history are stored,
|
|
40
|
+
then the sampler is reinitialized. After reinitialization, the sampler state
|
|
41
|
+
and history are set back to the stored values. This preserves the
|
|
42
|
+
statefulness of the samplers.
|
|
43
|
+
|
|
44
|
+
The order in which the conditionals are sampled is the order of the
|
|
45
|
+
variables in the sampling strategy, unless a different sampling order
|
|
46
|
+
is specified by the parameter `scan_order`.
|
|
21
47
|
|
|
22
48
|
Parameters
|
|
23
49
|
----------
|
|
@@ -25,10 +51,25 @@ class Gibbs:
|
|
|
25
51
|
Target distribution to sample from.
|
|
26
52
|
|
|
27
53
|
sampling_strategy : dict
|
|
28
|
-
Dictionary of sampling strategies for each
|
|
29
|
-
Keys are
|
|
54
|
+
Dictionary of sampling strategies for each variable.
|
|
55
|
+
Keys are variable names.
|
|
30
56
|
Values are sampler objects.
|
|
31
57
|
|
|
58
|
+
num_sampling_steps : dict, *optional*
|
|
59
|
+
Dictionary of number of sampling steps for each variable.
|
|
60
|
+
The sampling steps are defined as the number of times the sampler
|
|
61
|
+
will call its step method in each Gibbs step.
|
|
62
|
+
Default is 1 for all variables.
|
|
63
|
+
|
|
64
|
+
scan_order : list or str, *optional*
|
|
65
|
+
Order in which the conditional distributions are sampled.
|
|
66
|
+
If set to "random", use a random ordering at each step.
|
|
67
|
+
If not specified, it will be the order in the sampling_strategy.
|
|
68
|
+
|
|
69
|
+
callback : callable, *optional*
|
|
70
|
+
A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
|
|
71
|
+
The function should take three arguments: the sampler object, the index of the current sampling step, and the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
|
|
72
|
+
|
|
32
73
|
Example
|
|
33
74
|
-------
|
|
34
75
|
.. code-block:: python
|
|
@@ -37,7 +78,7 @@ class Gibbs:
|
|
|
37
78
|
import numpy as np
|
|
38
79
|
|
|
39
80
|
# Model and data
|
|
40
|
-
A, y_obs, probinfo = cuqi.testproblem.Deconvolution1D(phantom='
|
|
81
|
+
A, y_obs, probinfo = cuqi.testproblem.Deconvolution1D(phantom='sinc').get_components()
|
|
41
82
|
n = A.domain_dim
|
|
42
83
|
|
|
43
84
|
# Define distributions
|
|
@@ -52,15 +93,20 @@ class Gibbs:
|
|
|
52
93
|
|
|
53
94
|
# Define sampling strategy
|
|
54
95
|
sampling_strategy = {
|
|
55
|
-
'x': cuqi.sampler.LinearRTO,
|
|
56
|
-
|
|
96
|
+
'x': cuqi.sampler.LinearRTO(maxit=15),
|
|
97
|
+
'd': cuqi.sampler.Conjugate(),
|
|
98
|
+
'l': cuqi.sampler.Conjugate(),
|
|
57
99
|
}
|
|
58
100
|
|
|
59
101
|
# Define Gibbs sampler
|
|
60
|
-
sampler = cuqi.sampler.
|
|
102
|
+
sampler = cuqi.sampler.HybridGibbs(posterior, sampling_strategy)
|
|
61
103
|
|
|
62
104
|
# Run sampler
|
|
63
|
-
|
|
105
|
+
sampler.warmup(200)
|
|
106
|
+
sampler.sample(1000)
|
|
107
|
+
|
|
108
|
+
# Get samples removing burn-in
|
|
109
|
+
samples = sampler.get_samples().burnthin(200)
|
|
64
110
|
|
|
65
111
|
# Plot results
|
|
66
112
|
samples['x'].plot_ci(exact=probinfo.exactSolution)
|
|
@@ -69,159 +115,271 @@ class Gibbs:
|
|
|
69
115
|
|
|
70
116
|
"""
|
|
71
117
|
|
|
72
|
-
def __init__(self, target: JointDistribution, sampling_strategy: Dict[
|
|
118
|
+
def __init__(self, target: JointDistribution, sampling_strategy: Dict[str, Sampler], num_sampling_steps: Dict[str, int] = None, scan_order = None, callback=None):
|
|
73
119
|
|
|
74
120
|
# Store target and allow conditioning to reduce to a single density
|
|
75
121
|
self.target = target() # Create a copy of target distribution (to avoid modifying the original)
|
|
76
122
|
|
|
77
|
-
#
|
|
78
|
-
self.samplers =
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
self.samplers[par_name_] = sampling_strategy[par_name]
|
|
83
|
-
else:
|
|
84
|
-
self.samplers[par_name] = sampling_strategy[par_name]
|
|
123
|
+
# Store sampler instances (again as a copy to avoid modifying the original)
|
|
124
|
+
self.samplers = sampling_strategy.copy()
|
|
125
|
+
|
|
126
|
+
# Store number of sampling steps for each parameter
|
|
127
|
+
self.num_sampling_steps = num_sampling_steps
|
|
85
128
|
|
|
86
129
|
# Store parameter names
|
|
87
130
|
self.par_names = self.target.get_parameter_names()
|
|
88
131
|
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
132
|
+
# Store the scan order
|
|
133
|
+
self._scan_order = scan_order
|
|
134
|
+
|
|
135
|
+
# Check that the parameters of the target align with the sampling_strategy and scan_order
|
|
136
|
+
if set(self.par_names) != set(self.scan_order):
|
|
137
|
+
raise ValueError("Parameter names in JointDistribution do not equal the names in the scan order.")
|
|
138
|
+
|
|
139
|
+
# Initialize sampler (after target is set)
|
|
140
|
+
self._initialize()
|
|
141
|
+
|
|
142
|
+
# Set the callback function
|
|
143
|
+
self.callback = callback
|
|
144
|
+
|
|
145
|
+
def _initialize(self):
|
|
146
|
+
""" Initialize sampler """
|
|
92
147
|
|
|
93
148
|
# Initial points
|
|
94
|
-
current_samples = self._get_initial_points()
|
|
149
|
+
self.current_samples = self._get_initial_points()
|
|
95
150
|
|
|
96
|
-
#
|
|
97
|
-
|
|
98
|
-
at_Ns = self._Ns
|
|
151
|
+
# Initialize sampling steps
|
|
152
|
+
self._initialize_num_sampling_steps()
|
|
99
153
|
|
|
100
|
-
# Allocate
|
|
101
|
-
self.
|
|
102
|
-
self._allocate_samples(Ns)
|
|
154
|
+
# Allocate samples
|
|
155
|
+
self._allocate_samples()
|
|
103
156
|
|
|
104
|
-
#
|
|
105
|
-
|
|
106
|
-
current_samples = self.step_tune(current_samples)
|
|
107
|
-
self._store_samples(self.samples_warmup, current_samples, i)
|
|
108
|
-
self._print_progress(i+1+at_Nb, at_Nb+Nb, 'Warmup')
|
|
157
|
+
# Set targets
|
|
158
|
+
self._set_targets()
|
|
109
159
|
|
|
110
|
-
#
|
|
111
|
-
|
|
112
|
-
current_samples = self.step(current_samples)
|
|
113
|
-
self._store_samples(self.samples, current_samples, i)
|
|
114
|
-
self._print_progress(i+1, at_Ns+Ns, 'Sample')
|
|
160
|
+
# Initialize the samplers
|
|
161
|
+
self._initialize_samplers()
|
|
115
162
|
|
|
116
|
-
#
|
|
117
|
-
|
|
163
|
+
# Validate all targets for samplers.
|
|
164
|
+
self.validate_targets()
|
|
118
165
|
|
|
119
|
-
|
|
120
|
-
|
|
166
|
+
@property
|
|
167
|
+
def scan_order(self):
|
|
168
|
+
if self._scan_order is None:
|
|
169
|
+
return list(self.samplers.keys())
|
|
170
|
+
if self._scan_order == "random":
|
|
171
|
+
arr = list(self.samplers.keys())
|
|
172
|
+
np.random.shuffle(arr) # Shuffle works in-place
|
|
173
|
+
return arr
|
|
174
|
+
return self._scan_order
|
|
121
175
|
|
|
122
|
-
|
|
123
|
-
|
|
176
|
+
# ------------ Public methods ------------
|
|
177
|
+
def validate_targets(self):
|
|
178
|
+
""" Validate each of the conditional targets used in the Gibbs steps """
|
|
179
|
+
if not isinstance(self.target, (JointDistribution, Posterior)):
|
|
180
|
+
raise ValueError('Target distribution must be a JointDistribution or Posterior.')
|
|
181
|
+
for sampler in self.samplers.values():
|
|
182
|
+
sampler.validate_target()
|
|
183
|
+
|
|
184
|
+
def sample(self, Ns, Nt=1) -> 'HybridGibbs':
|
|
185
|
+
""" Sample from the joint distribution using Gibbs sampling
|
|
186
|
+
|
|
187
|
+
Parameters
|
|
188
|
+
----------
|
|
189
|
+
Ns : int
|
|
190
|
+
The number of samples to draw.
|
|
191
|
+
Nt : int, optional, default=1
|
|
192
|
+
The thinning interval. If Nt >= 1, every Nt'th sample is stored. The larger Nt, the fewer samples are stored.
|
|
193
|
+
|
|
194
|
+
"""
|
|
195
|
+
# Progress bar printing settings:
|
|
196
|
+
miniters = None if config.PROGRESS_BAR_DYNAMIC_UPDATE else Ns + 1
|
|
197
|
+
maxinterval = 10.0 if config.PROGRESS_BAR_DYNAMIC_UPDATE else float("inf")
|
|
124
198
|
|
|
125
|
-
|
|
126
|
-
|
|
199
|
+
for idx in tqdm(
|
|
200
|
+
range(Ns), "Sample: ", miniters=miniters, maxinterval=maxinterval
|
|
201
|
+
):
|
|
127
202
|
|
|
128
|
-
|
|
129
|
-
other_params = {par_name_: current_samples[par_name_] for par_name_ in par_names if par_name_ != par_name}
|
|
203
|
+
self.step()
|
|
130
204
|
|
|
131
|
-
|
|
132
|
-
|
|
205
|
+
if (Nt > 0) and (idx % Nt == 0):
|
|
206
|
+
self._store_samples()
|
|
133
207
|
|
|
134
|
-
#
|
|
135
|
-
|
|
208
|
+
# Call callback function if specified
|
|
209
|
+
self._call_callback(idx, Ns)
|
|
136
210
|
|
|
137
|
-
|
|
138
|
-
|
|
211
|
+
return self
|
|
212
|
+
|
|
213
|
+
def warmup(self, Nb, Nt=1, tune_freq=0.1) -> 'HybridGibbs':
|
|
214
|
+
""" Warmup (tune) the samplers in the Gibbs sampling scheme
|
|
215
|
+
|
|
216
|
+
Parameters
|
|
217
|
+
----------
|
|
218
|
+
Nb : int
|
|
219
|
+
The number of samples to draw during warmup.
|
|
139
220
|
|
|
140
|
-
|
|
221
|
+
Nt : int, optional, default=1
|
|
222
|
+
The thinning interval. If Nt >= 1, every Nt'th sample is stored. The larger Nt, the fewer samples are stored.
|
|
223
|
+
|
|
224
|
+
tune_freq : float, optional
|
|
225
|
+
Frequency of tuning the samplers. Tuning is performed every tune_freq*Nb steps.
|
|
226
|
+
|
|
227
|
+
"""
|
|
228
|
+
|
|
229
|
+
# Progress bar printing settings:
|
|
230
|
+
miniters = None if config.PROGRESS_BAR_DYNAMIC_UPDATE else Nb + 1
|
|
231
|
+
maxinterval = 10.0 if config.PROGRESS_BAR_DYNAMIC_UPDATE else float("inf")
|
|
232
|
+
|
|
233
|
+
tune_interval = max(int(tune_freq * Nb), 1)
|
|
234
|
+
|
|
235
|
+
for idx in tqdm(
|
|
236
|
+
range(Nb), "Warmup: ", miniters=miniters, maxinterval=maxinterval
|
|
237
|
+
):
|
|
238
|
+
|
|
239
|
+
self.step()
|
|
141
240
|
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
241
|
+
# Tune the sampler at tuning intervals (matching behavior of Sampler class)
|
|
242
|
+
if (idx + 1) % tune_interval == 0:
|
|
243
|
+
self.tune(tune_interval, idx // tune_interval)
|
|
244
|
+
|
|
245
|
+
if (Nt > 0) and (idx % Nt == 0):
|
|
246
|
+
self._store_samples()
|
|
247
|
+
|
|
248
|
+
# Call callback function if specified
|
|
249
|
+
self._call_callback(idx, Nb)
|
|
250
|
+
|
|
251
|
+
return self
|
|
252
|
+
|
|
253
|
+
def get_samples(self) -> Dict[str, Samples]:
|
|
254
|
+
samples_object = JointSamples()
|
|
255
|
+
for par_name in self.par_names:
|
|
256
|
+
samples_array = np.array(self.samples[par_name]).T
|
|
257
|
+
samples_object[par_name] = Samples(samples_array, self.target.get_density(par_name).geometry)
|
|
258
|
+
return samples_object
|
|
259
|
+
|
|
260
|
+
def step(self):
|
|
261
|
+
""" Sequentially go through all parameters and sample them conditionally on each other """
|
|
262
|
+
|
|
263
|
+
# Sample from each conditional distribution
|
|
264
|
+
for par_name in self.scan_order:
|
|
265
|
+
|
|
266
|
+
# Set target for current parameter
|
|
267
|
+
self._set_target(par_name)
|
|
268
|
+
|
|
269
|
+
# Get sampler
|
|
270
|
+
sampler = self.samplers[par_name]
|
|
271
|
+
|
|
272
|
+
# Instead of simply changing the target of the sampler, we reinitialize it.
|
|
273
|
+
# This is to ensure that all internal variables are set to match the new target.
|
|
274
|
+
# To return the sampler to the old state and history, we first extract the state and history
|
|
275
|
+
# before reinitializing the sampler and then set the state and history back to the sampler
|
|
276
|
+
|
|
277
|
+
# Extract state and history from sampler
|
|
278
|
+
sampler_state = sampler.get_state()
|
|
279
|
+
sampler_history = sampler.get_history()
|
|
280
|
+
|
|
281
|
+
# Reinitialize sampler
|
|
282
|
+
sampler.reinitialize()
|
|
283
|
+
|
|
284
|
+
# Set state and history back to sampler
|
|
285
|
+
sampler.set_state(sampler_state)
|
|
286
|
+
sampler.set_history(sampler_history)
|
|
287
|
+
|
|
288
|
+
# Allow for multiple sampling steps in each Gibbs step
|
|
289
|
+
for _ in range(self.num_sampling_steps[par_name]):
|
|
290
|
+
# Sampling step
|
|
291
|
+
acc = sampler.step()
|
|
292
|
+
|
|
293
|
+
# Store acceptance rate in sampler (matching behavior of Sampler class Sample method)
|
|
294
|
+
sampler._acc.append(acc)
|
|
295
|
+
|
|
296
|
+
# Extract samples (Ensure even 1-dimensional samples are 1D arrays)
|
|
297
|
+
if isinstance(sampler.current_point, np.ndarray):
|
|
298
|
+
self.current_samples[par_name] = sampler.current_point.reshape(-1)
|
|
299
|
+
else:
|
|
300
|
+
self.current_samples[par_name] = sampler.current_point
|
|
301
|
+
|
|
302
|
+
def tune(self, skip_len, update_count):
|
|
303
|
+
""" Run a single tuning step on each of the samplers in the Gibbs sampling scheme
|
|
304
|
+
|
|
305
|
+
Parameters
|
|
306
|
+
----------
|
|
307
|
+
skip_len : int
|
|
308
|
+
Defines the number of steps in between tuning (i.e. the tuning interval).
|
|
309
|
+
|
|
310
|
+
update_count : int
|
|
311
|
+
The number of times tuning has been performed. Can be used for internal bookkeeping.
|
|
312
|
+
|
|
313
|
+
"""
|
|
314
|
+
for par_name in self.par_names:
|
|
315
|
+
self.samplers[par_name].tune(skip_len=skip_len, update_count=update_count)
|
|
146
316
|
|
|
147
317
|
# ------------ Private methods ------------
|
|
148
|
-
def
|
|
149
|
-
"""
|
|
150
|
-
|
|
151
|
-
|
|
318
|
+
def _call_callback(self, sample_index, num_of_samples):
|
|
319
|
+
""" Calls the callback function. Assumes input is sampler, sample index, and total number of samples """
|
|
320
|
+
if self.callback is not None:
|
|
321
|
+
self.callback(self, sample_index, num_of_samples)
|
|
322
|
+
|
|
323
|
+
def _initialize_samplers(self):
|
|
324
|
+
""" Initialize samplers """
|
|
325
|
+
for sampler in self.samplers.values():
|
|
326
|
+
sampler.initialize()
|
|
327
|
+
|
|
328
|
+
def _initialize_num_sampling_steps(self):
|
|
329
|
+
""" Initialize the number of sampling steps for each sampler. Defaults to 1 if not set by user """
|
|
330
|
+
|
|
331
|
+
if self.num_sampling_steps is None:
|
|
332
|
+
self.num_sampling_steps = {par_name: 1 for par_name in self.par_names}
|
|
333
|
+
|
|
152
334
|
for par_name in self.par_names:
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
# Store samples in self
|
|
156
|
-
if hasattr(self, 'samples'):
|
|
157
|
-
# Append to existing samples (This makes a copy)
|
|
158
|
-
for par_name in self.par_names:
|
|
159
|
-
samples[par_name] = np.hstack((self.samples[par_name], samples[par_name]))
|
|
160
|
-
self.samples = samples
|
|
335
|
+
if par_name not in self.num_sampling_steps:
|
|
336
|
+
self.num_sampling_steps[par_name] = 1
|
|
161
337
|
|
|
162
|
-
def
|
|
163
|
-
"""
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
338
|
+
def _set_targets(self):
|
|
339
|
+
""" Set targets for all samplers using the current samples """
|
|
340
|
+
par_names = self.par_names
|
|
341
|
+
for par_name in par_names:
|
|
342
|
+
self._set_target(par_name)
|
|
343
|
+
|
|
344
|
+
def _set_target(self, par_name):
|
|
345
|
+
""" Set target conditional distribution for a single parameter using the current samples """
|
|
346
|
+
# Get all other conditional parameters other than the current parameter and update the target
|
|
347
|
+
# This defines - from a joint p(x,y,z) - the conditional distribution p(x|y,z) or p(y|x,z) or p(z|x,y)
|
|
348
|
+
conditional_params = {par_name_: self.current_samples[par_name_] for par_name_ in self.par_names if par_name_ != par_name}
|
|
349
|
+
self.samplers[par_name].target = self.target(**conditional_params)
|
|
168
350
|
|
|
169
|
-
|
|
351
|
+
def _allocate_samples(self):
|
|
352
|
+
""" Allocate memory for samples """
|
|
170
353
|
samples = {}
|
|
171
354
|
for par_name in self.par_names:
|
|
172
|
-
samples[par_name] =
|
|
173
|
-
self.
|
|
355
|
+
samples[par_name] = []
|
|
356
|
+
self.samples = samples
|
|
174
357
|
|
|
175
358
|
def _get_initial_points(self):
|
|
176
359
|
""" Get initial points for each parameter """
|
|
177
360
|
initial_points = {}
|
|
178
361
|
for par_name in self.par_names:
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
initial_points[par_name] = self.target.get_density(par_name).init_point
|
|
185
|
-
else:
|
|
186
|
-
initial_points[par_name] = np.ones(self.target.get_density(par_name).dim)
|
|
362
|
+
sampler = self.samplers[par_name]
|
|
363
|
+
if sampler.initial_point is None:
|
|
364
|
+
sampler.initial_point = sampler._get_default_initial_point(self.target.get_density(par_name).dim)
|
|
365
|
+
initial_points[par_name] = sampler.initial_point
|
|
366
|
+
|
|
187
367
|
return initial_points
|
|
188
368
|
|
|
189
|
-
def _store_samples(self
|
|
369
|
+
def _store_samples(self):
|
|
190
370
|
""" Store current samples at index i of samples dict """
|
|
191
371
|
for par_name in self.par_names:
|
|
192
|
-
samples[par_name]
|
|
372
|
+
self.samples[par_name].append(self.current_samples[par_name])
|
|
193
373
|
|
|
194
|
-
def
|
|
195
|
-
"""
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
return samples_object
|
|
200
|
-
|
|
201
|
-
def _print_progress(self, s, Ns, phase):
|
|
202
|
-
"""Prints sampling progress"""
|
|
203
|
-
if Ns < 2: # Don't print progress if only one sample
|
|
204
|
-
return
|
|
205
|
-
if (s % (max(Ns//100,1))) == 0:
|
|
206
|
-
msg = f'{phase} {s} / {Ns}'
|
|
207
|
-
sys.stdout.write('\r'+msg)
|
|
208
|
-
if s==Ns:
|
|
209
|
-
msg = f'{phase} {s} / {Ns}'
|
|
210
|
-
sys.stdout.write('\r'+msg+'\n')
|
|
211
|
-
|
|
212
|
-
# ------------ Private properties ------------
|
|
213
|
-
@property
|
|
214
|
-
def _Ns(self):
|
|
215
|
-
""" Number of samples already taken """
|
|
216
|
-
if hasattr(self, 'samples'):
|
|
217
|
-
return self.samples[self.par_names[0]].shape[-1]
|
|
374
|
+
def __repr__(self):
|
|
375
|
+
""" Return a string representation of the sampler. """
|
|
376
|
+
msg = f"Sampler: {self.__class__.__name__} \n"
|
|
377
|
+
if self.target is None:
|
|
378
|
+
msg += f" Target: None \n"
|
|
218
379
|
else:
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
return self.samples_warmup[self.par_names[0]].shape[-1]
|
|
226
|
-
else:
|
|
227
|
-
return 0
|
|
380
|
+
msg += f" Target: \n \t {self.target} \n\n"
|
|
381
|
+
|
|
382
|
+
for key, value in zip(self.samplers.keys(), self.samplers.values()):
|
|
383
|
+
msg += f" Variable '{key}' with {value} \n"
|
|
384
|
+
|
|
385
|
+
return msg
|