CUQIpy 1.3.0.post0.dev298__py3-none-any.whl → 1.4.0.post0.dev61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. cuqi/__init__.py +1 -0
  2. cuqi/_version.py +3 -3
  3. cuqi/density/_density.py +9 -1
  4. cuqi/distribution/_distribution.py +24 -15
  5. cuqi/distribution/_joint_distribution.py +96 -11
  6. cuqi/distribution/_posterior.py +9 -0
  7. cuqi/experimental/__init__.py +1 -2
  8. cuqi/experimental/_recommender.py +4 -4
  9. cuqi/implicitprior/__init__.py +1 -1
  10. cuqi/implicitprior/_restorator.py +35 -1
  11. cuqi/legacy/__init__.py +2 -0
  12. cuqi/legacy/sampler/__init__.py +11 -0
  13. cuqi/legacy/sampler/_conjugate.py +55 -0
  14. cuqi/legacy/sampler/_conjugate_approx.py +52 -0
  15. cuqi/legacy/sampler/_cwmh.py +196 -0
  16. cuqi/legacy/sampler/_gibbs.py +231 -0
  17. cuqi/legacy/sampler/_hmc.py +335 -0
  18. cuqi/legacy/sampler/_langevin_algorithm.py +198 -0
  19. cuqi/legacy/sampler/_laplace_approximation.py +184 -0
  20. cuqi/legacy/sampler/_mh.py +190 -0
  21. cuqi/legacy/sampler/_pcn.py +244 -0
  22. cuqi/legacy/sampler/_rto.py +284 -0
  23. cuqi/legacy/sampler/_sampler.py +182 -0
  24. cuqi/likelihood/_likelihood.py +1 -1
  25. cuqi/model/_model.py +212 -77
  26. cuqi/pde/__init__.py +4 -0
  27. cuqi/pde/_observation_map.py +36 -0
  28. cuqi/pde/_pde.py +52 -21
  29. cuqi/problem/_problem.py +87 -80
  30. cuqi/sampler/__init__.py +120 -8
  31. cuqi/sampler/_conjugate.py +376 -35
  32. cuqi/sampler/_conjugate_approx.py +40 -16
  33. cuqi/sampler/_cwmh.py +132 -138
  34. cuqi/{experimental/mcmc → sampler}/_direct.py +1 -1
  35. cuqi/sampler/_gibbs.py +269 -130
  36. cuqi/sampler/_hmc.py +328 -201
  37. cuqi/sampler/_langevin_algorithm.py +282 -98
  38. cuqi/sampler/_laplace_approximation.py +87 -117
  39. cuqi/sampler/_mh.py +47 -157
  40. cuqi/sampler/_pcn.py +56 -211
  41. cuqi/sampler/_rto.py +206 -140
  42. cuqi/sampler/_sampler.py +540 -135
  43. {cuqipy-1.3.0.post0.dev298.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/METADATA +1 -1
  44. {cuqipy-1.3.0.post0.dev298.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/RECORD +47 -45
  45. cuqi/experimental/mcmc/__init__.py +0 -122
  46. cuqi/experimental/mcmc/_conjugate.py +0 -396
  47. cuqi/experimental/mcmc/_conjugate_approx.py +0 -76
  48. cuqi/experimental/mcmc/_cwmh.py +0 -190
  49. cuqi/experimental/mcmc/_gibbs.py +0 -374
  50. cuqi/experimental/mcmc/_hmc.py +0 -460
  51. cuqi/experimental/mcmc/_langevin_algorithm.py +0 -382
  52. cuqi/experimental/mcmc/_laplace_approximation.py +0 -154
  53. cuqi/experimental/mcmc/_mh.py +0 -80
  54. cuqi/experimental/mcmc/_pcn.py +0 -89
  55. cuqi/experimental/mcmc/_rto.py +0 -306
  56. cuqi/experimental/mcmc/_sampler.py +0 -564
  57. {cuqipy-1.3.0.post0.dev298.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/WHEEL +0 -0
  58. {cuqipy-1.3.0.post0.dev298.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/licenses/LICENSE +0 -0
  59. {cuqipy-1.3.0.post0.dev298.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/top_level.txt +0 -0
@@ -1,374 +0,0 @@
1
- from cuqi.distribution import JointDistribution
2
- from cuqi.experimental.mcmc import Sampler
3
- from cuqi.samples import Samples, JointSamples
4
- from cuqi.experimental.mcmc import NUTS
5
- from typing import Dict
6
- import numpy as np
7
- import warnings
8
-
9
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable, *args, **kwargs):
        """ Fallback used when tqdm is not installed: warn and return the iterable unchanged.

        Extra positional and keyword arguments are accepted and ignored so that calls
        written for tqdm, such as ``tqdm(range(Ns), "Sample: ")`` with a positional
        description, keep working without the progress bar.
        """
        warnings.warn("Module mcmc: tqdm not found. Install tqdm to get sampling progress.")
        return iterable
15
-
16
# Not subclassed from Sampler as Gibbs handles multiple samplers and samples multiple parameters
# Similar approach as for JointDistribution
class HybridGibbs:
    """
    Hybrid Gibbs sampler for sampling a joint distribution.

    Gibbs sampling samples the variables of the distribution sequentially,
    one variable at a time. When a variable represents a random vector, the
    whole vector is sampled simultaneously.

    The sampling of each variable is done by sampling from the conditional
    distribution of that variable given the values of the other variables.
    This is often a very efficient way of sampling from a joint distribution
    if the conditional distributions are easy to sample from.

    Hybrid Gibbs sampler is a generalization of the Gibbs sampler where the
    conditional distributions are sampled using different MCMC samplers.

    When the conditionals are sampled exactly, the samples from the Gibbs
    sampler converge to the joint distribution. See e.g.
    Gelman et al. "Bayesian Data Analysis" (2014), Third Edition
    for more details.

    In each Gibbs step, the corresponding sampler has the initial_point
    and initial_scale (if applicable) set to the value of the previous step
    and the sampler is reinitialized. This means that the sampling is not
    fully stateful at this point. This means samplers like NUTS will lose
    their internal state between Gibbs steps.

    The order in which the conditionals are sampled is the order of the
    variables in the sampling strategy, unless a different sampling order
    is specified by the parameter `scan_order`.

    Note that the sampler objects given in `sampling_strategy` are used
    directly (not copied): their targets, initial points and internal state
    are updated during Gibbs sampling.

    Parameters
    ----------
    target : cuqi.distribution.JointDistribution
        Target distribution to sample from.

    sampling_strategy : dict
        Dictionary of sampling strategies for each variable.
        Keys are variable names.
        Values are sampler objects.

    num_sampling_steps : dict, *optional*
        Dictionary of number of sampling steps for each variable.
        The sampling steps are defined as the number of times the sampler
        will call its step method in each Gibbs step.
        Default is 1 for all variables. The given dictionary is not modified.

    scan_order : list or str, *optional*
        Order in which the conditional distributions are sampled.
        If set to "random", use a random ordering at each step.
        If not specified, it will be the order in the sampling_strategy.

    callback : callable, optional
        A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
        The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.

    Raises
    ------
    ValueError
        If the parameter names of the target do not match the scan order, or
        if `sampling_strategy` has no sampler for one of the parameters.

    Example
    -------
    .. code-block:: python

        import cuqi
        import numpy as np

        # Model and data
        A, y_obs, probinfo = cuqi.testproblem.Deconvolution1D(phantom='sinc').get_components()
        n = A.domain_dim

        # Define distributions
        d = cuqi.distribution.Gamma(1, 1e-4)
        l = cuqi.distribution.Gamma(1, 1e-4)
        x = cuqi.distribution.GMRF(np.zeros(n), lambda d: d)
        y = cuqi.distribution.Gaussian(A, lambda l: 1/l)

        # Combine into a joint distribution and create posterior
        joint = cuqi.distribution.JointDistribution(d, l, x, y)
        posterior = joint(y=y_obs)

        # Define sampling strategy
        sampling_strategy = {
            'x': cuqi.experimental.mcmc.LinearRTO(maxit=15),
            'd': cuqi.experimental.mcmc.Conjugate(),
            'l': cuqi.experimental.mcmc.Conjugate(),
        }

        # Define Gibbs sampler
        sampler = cuqi.experimental.mcmc.HybridGibbs(posterior, sampling_strategy)

        # Run sampler
        sampler.warmup(200)
        sampler.sample(1000)

        # Get samples removing burn-in
        samples = sampler.get_samples().burnthin(200)

        # Plot results
        samples['x'].plot_ci(exact=probinfo.exactSolution)
        samples['d'].plot_trace(figsize=(8,2))
        samples['l'].plot_trace(figsize=(8,2))

    """

    def __init__(self, target: JointDistribution, sampling_strategy: Dict[str, Sampler], num_sampling_steps: Dict[str, int] = None, scan_order = None, callback=None):

        # Store target and allow conditioning to reduce to a single density
        self.target = target() # Create a copy of target distribution (to avoid modifying the original)

        # Store the samplers in a new dict. This is a shallow copy: the sampler
        # objects themselves are shared with the caller and are updated during sampling.
        self.samplers = sampling_strategy.copy()

        # Store number of sampling steps for each parameter
        # (completed with defaults, in a new dict, by _initialize_num_sampling_steps)
        self.num_sampling_steps = num_sampling_steps

        # Store parameter names
        self.par_names = self.target.get_parameter_names()

        # Store the scan order
        self._scan_order = scan_order

        # Check that the parameters of the target align with the scan_order
        # (which defaults to the keys of the sampling_strategy)
        if set(self.par_names) != set(self.scan_order):
            raise ValueError("Parameter names in JointDistribution do not equal the names in the scan order.")

        # With a user-given scan_order the check above does not cover the
        # sampling_strategy, so fail early with a clear message instead of a KeyError later.
        missing_samplers = [par_name for par_name in self.par_names if par_name not in self.samplers]
        if missing_samplers:
            raise ValueError(f"No sampler given in sampling_strategy for parameter(s): {missing_samplers}.")

        # Initialize sampler (after target is set)
        self._initialize()

        # Set the callback function
        self.callback = callback

    def _initialize(self):
        """ Initialize sampler """

        # Initial points
        self.current_samples = self._get_initial_points()

        # Initialize sampling steps
        self._initialize_num_sampling_steps()

        # Allocate samples
        self._allocate_samples()

        # Set targets
        self._set_targets()

        # Initialize the samplers
        self._initialize_samplers()

        # Validate all targets for samplers.
        self.validate_targets()

    @property
    def scan_order(self):
        """ Order in which the parameters are visited in a Gibbs step.

        Returns the keys of the sampling strategy if no scan order was given,
        a newly shuffled list of those keys on every access if the scan order
        is "random", and the user-given scan order otherwise.
        """
        if self._scan_order is None:
            return list(self.samplers.keys())
        # The str check avoids an elementwise (ambiguous) comparison for array-like orders
        if isinstance(self._scan_order, str) and self._scan_order == "random":
            arr = list(self.samplers.keys())
            np.random.shuffle(arr) # Shuffle works in-place
            return arr
        return self._scan_order

    # ------------ Public methods ------------
    def validate_targets(self):
        """ Validate each of the conditional targets used in the Gibbs steps """
        if not isinstance(self.target, JointDistribution):
            raise ValueError('Target distribution must be a JointDistribution.')
        for sampler in self.samplers.values():
            sampler.validate_target()

    def sample(self, Ns) -> 'HybridGibbs':
        """ Sample from the joint distribution using Gibbs sampling

        Parameters
        ----------
        Ns : int
            The number of samples to draw.

        Returns
        -------
        HybridGibbs
            The sampler itself, to allow method chaining.

        """
        for idx in tqdm(range(Ns), "Sample: "):

            self.step()

            self._store_samples()

            # Call callback function if specified
            self._call_callback(idx, Ns)

        return self

    def warmup(self, Nb, tune_freq=0.1) -> 'HybridGibbs':
        """ Warmup (tune) the samplers in the Gibbs sampling scheme

        Warmup samples are stored together with the regular samples, so they
        should be removed afterwards (e.g. with ``burnthin``).

        Parameters
        ----------
        Nb : int
            The number of samples to draw during warmup.

        tune_freq : float, optional
            Frequency of tuning the samplers. Tuning is performed every
            ``max(int(tune_freq*Nb), 1)`` steps.

        Returns
        -------
        HybridGibbs
            The sampler itself, to allow method chaining.

        """

        tune_interval = max(int(tune_freq * Nb), 1)

        for idx in tqdm(range(Nb), "Warmup: "):

            self.step()

            # Tune the sampler at tuning intervals (matching behavior of Sampler class)
            if (idx + 1) % tune_interval == 0:
                self.tune(tune_interval, idx // tune_interval)

            self._store_samples()

            # Call callback function if specified
            self._call_callback(idx, Nb)

        return self

    def get_samples(self) -> JointSamples:
        """ Return all stored samples (warmup and regular) as a JointSamples object.

        Each parameter gets a Samples entry built from its stored values (one
        column per sample) and the geometry of that parameter's density in the target.
        """
        samples_object = JointSamples()
        for par_name in self.par_names:
            samples_array = np.array(self.samples[par_name]).T
            samples_object[par_name] = Samples(samples_array, self.target.get_density(par_name).geometry)
        return samples_object

    def step(self):
        """ Sequentially go through all parameters and sample them conditionally on each other """

        # Sample from each conditional distribution
        for par_name in self.scan_order:

            # Set target for current parameter
            self._set_target(par_name)

            # Get sampler
            sampler = self.samplers[par_name]

            # Instead of simply changing the target of the sampler, we reinitialize it.
            # This is to ensure that all internal variables are set to match the new target.
            # To return the sampler to the old state and history, we first extract the state and history
            # before reinitializing the sampler and then set the state and history back to the sampler

            # Extract state and history from sampler
            if isinstance(sampler, NUTS): # Special case for NUTS as it is not playing nice with get_state and get_history
                sampler.initial_point = sampler.current_point
            else:
                sampler_state = sampler.get_state()
                sampler_history = sampler.get_history()

            # Reinitialize sampler
            sampler.reinitialize()

            # Set state and history back to sampler
            if not isinstance(sampler, NUTS): # Again, special case for NUTS.
                sampler.set_state(sampler_state)
                sampler.set_history(sampler_history)

            # Allow for multiple sampling steps in each Gibbs step
            for _ in range(self.num_sampling_steps[par_name]):
                # Sampling step
                acc = sampler.step()

                # Store acceptance rate in sampler (matching behavior of Sampler class Sample method)
                sampler._acc.append(acc)

            # Extract samples (Ensure even 1-dimensional samples are 1D arrays)
            if isinstance(sampler.current_point, np.ndarray):
                self.current_samples[par_name] = sampler.current_point.reshape(-1)
            else:
                self.current_samples[par_name] = sampler.current_point

    def tune(self, skip_len, update_count):
        """ Run a single tuning step on each of the samplers in the Gibbs sampling scheme

        Parameters
        ----------
        skip_len : int
            Defines the number of steps in between tuning (i.e. the tuning interval).

        update_count : int
            The number of times tuning has been performed. Can be used for internal bookkeeping.

        """
        for par_name in self.par_names:
            self.samplers[par_name].tune(skip_len=skip_len, update_count=update_count)

    # ------------ Private methods ------------
    def _call_callback(self, sample_index, num_of_samples):
        """ Calls the callback function. Assumes input is sampler, sample index, and total number of samples """
        if self.callback is not None:
            self.callback(self, sample_index, num_of_samples)

    def _initialize_samplers(self):
        """ Initialize samplers, warning about samplers that cannot keep their state between Gibbs steps """
        for sampler in self.samplers.values():
            if isinstance(sampler, NUTS):
                # Use the warnings module (as elsewhere in this module) so users can filter it
                warnings.warn('NUTS sampler is not fully stateful in HybridGibbs. Sampler will be reinitialized in each Gibbs step.')
            sampler.initialize()

    def _initialize_num_sampling_steps(self):
        """ Initialize the number of sampling steps for each sampler. Defaults to 1 if not set by user.

        A new dictionary is built, so a dictionary passed in by the user is not modified.
        """
        user_steps = self.num_sampling_steps if self.num_sampling_steps is not None else {}
        num_steps = {par_name: 1 for par_name in self.par_names}
        num_steps.update(user_steps)  # user-given values take precedence over the default of 1
        self.num_sampling_steps = num_steps

    def _set_targets(self):
        """ Set targets for all samplers using the current samples """
        for par_name in self.par_names:
            self._set_target(par_name)

    def _set_target(self, par_name):
        """ Set target conditional distribution for a single parameter using the current samples """
        # Get all other conditional parameters other than the current parameter and update the target
        # This defines - from a joint p(x,y,z) - the conditional distribution p(x|y,z) or p(y|x,z) or p(z|x,y)
        conditional_params = {par_name_: self.current_samples[par_name_] for par_name_ in self.par_names if par_name_ != par_name}
        self.samplers[par_name].target = self.target(**conditional_params)

    def _allocate_samples(self):
        """ Create an empty list of stored samples for each parameter """
        self.samples = {par_name: [] for par_name in self.par_names}

    def _get_initial_points(self):
        """ Get initial points for each parameter.

        If a sampler has no initial point, it is set to the sampler's default
        initial point for the dimension of the parameter's density.
        """
        initial_points = {}
        for par_name in self.par_names:
            sampler = self.samplers[par_name]
            if sampler.initial_point is None:
                sampler.initial_point = sampler._get_default_initial_point(self.target.get_density(par_name).dim)
            initial_points[par_name] = sampler.initial_point

        return initial_points

    def _store_samples(self):
        """ Append the current value of each parameter to its list of stored samples """
        for par_name in self.par_names:
            self.samples[par_name].append(self.current_samples[par_name])

    def __repr__(self):
        """ Return a string representation of the sampler. """
        msg = f"Sampler: {self.__class__.__name__} \n"
        if self.target is None:
            msg += f" Target: None \n"
        else:
            msg += f" Target: \n \t {self.target} \n\n"

        for key, value in self.samplers.items():
            msg += f" Variable '{key}' with {value} \n"

        return msg