CUQIpy 1.1.1.post0.dev36__py3-none-any.whl → 1.4.1.post0.dev124__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of CUQIpy might be problematic. Click here for more details.

Files changed (92) hide show
  1. cuqi/__init__.py +2 -0
  2. cuqi/_version.py +3 -3
  3. cuqi/algebra/__init__.py +2 -0
  4. cuqi/algebra/_abstract_syntax_tree.py +358 -0
  5. cuqi/algebra/_ordered_set.py +82 -0
  6. cuqi/algebra/_random_variable.py +457 -0
  7. cuqi/array/_array.py +4 -13
  8. cuqi/config.py +7 -0
  9. cuqi/density/_density.py +9 -1
  10. cuqi/distribution/__init__.py +3 -2
  11. cuqi/distribution/_beta.py +7 -11
  12. cuqi/distribution/_cauchy.py +2 -2
  13. cuqi/distribution/_custom.py +0 -6
  14. cuqi/distribution/_distribution.py +31 -45
  15. cuqi/distribution/_gamma.py +7 -3
  16. cuqi/distribution/_gaussian.py +2 -12
  17. cuqi/distribution/_inverse_gamma.py +4 -10
  18. cuqi/distribution/_joint_distribution.py +112 -15
  19. cuqi/distribution/_lognormal.py +0 -7
  20. cuqi/distribution/{_modifiedhalfnormal.py → _modified_half_normal.py} +23 -23
  21. cuqi/distribution/_normal.py +34 -7
  22. cuqi/distribution/_posterior.py +9 -0
  23. cuqi/distribution/_truncated_normal.py +129 -0
  24. cuqi/distribution/_uniform.py +47 -1
  25. cuqi/experimental/__init__.py +2 -2
  26. cuqi/experimental/_recommender.py +216 -0
  27. cuqi/geometry/__init__.py +2 -0
  28. cuqi/geometry/_geometry.py +15 -1
  29. cuqi/geometry/_product_geometry.py +181 -0
  30. cuqi/implicitprior/__init__.py +5 -3
  31. cuqi/implicitprior/_regularized_gaussian.py +483 -0
  32. cuqi/implicitprior/{_regularizedGMRF.py → _regularized_gmrf.py} +4 -2
  33. cuqi/implicitprior/{_regularizedUnboundedUniform.py → _regularized_unbounded_uniform.py} +3 -2
  34. cuqi/implicitprior/_restorator.py +269 -0
  35. cuqi/legacy/__init__.py +2 -0
  36. cuqi/{experimental/mcmc → legacy/sampler}/__init__.py +7 -11
  37. cuqi/legacy/sampler/_conjugate.py +55 -0
  38. cuqi/legacy/sampler/_conjugate_approx.py +52 -0
  39. cuqi/legacy/sampler/_cwmh.py +196 -0
  40. cuqi/legacy/sampler/_gibbs.py +231 -0
  41. cuqi/legacy/sampler/_hmc.py +335 -0
  42. cuqi/{experimental/mcmc → legacy/sampler}/_langevin_algorithm.py +82 -111
  43. cuqi/legacy/sampler/_laplace_approximation.py +184 -0
  44. cuqi/legacy/sampler/_mh.py +190 -0
  45. cuqi/legacy/sampler/_pcn.py +244 -0
  46. cuqi/{experimental/mcmc → legacy/sampler}/_rto.py +132 -90
  47. cuqi/legacy/sampler/_sampler.py +182 -0
  48. cuqi/likelihood/_likelihood.py +9 -1
  49. cuqi/model/__init__.py +1 -1
  50. cuqi/model/_model.py +1361 -359
  51. cuqi/pde/__init__.py +4 -0
  52. cuqi/pde/_observation_map.py +36 -0
  53. cuqi/pde/_pde.py +134 -33
  54. cuqi/problem/_problem.py +93 -87
  55. cuqi/sampler/__init__.py +120 -8
  56. cuqi/sampler/_conjugate.py +376 -35
  57. cuqi/sampler/_conjugate_approx.py +40 -16
  58. cuqi/sampler/_cwmh.py +132 -138
  59. cuqi/{experimental/mcmc → sampler}/_direct.py +1 -1
  60. cuqi/sampler/_gibbs.py +288 -130
  61. cuqi/sampler/_hmc.py +328 -201
  62. cuqi/sampler/_langevin_algorithm.py +284 -100
  63. cuqi/sampler/_laplace_approximation.py +87 -117
  64. cuqi/sampler/_mh.py +47 -157
  65. cuqi/sampler/_pcn.py +65 -213
  66. cuqi/sampler/_rto.py +211 -142
  67. cuqi/sampler/_sampler.py +553 -136
  68. cuqi/samples/__init__.py +1 -1
  69. cuqi/samples/_samples.py +24 -18
  70. cuqi/solver/__init__.py +6 -4
  71. cuqi/solver/_solver.py +230 -26
  72. cuqi/testproblem/_testproblem.py +2 -3
  73. cuqi/utilities/__init__.py +6 -1
  74. cuqi/utilities/_get_python_variable_name.py +2 -2
  75. cuqi/utilities/_utilities.py +182 -2
  76. {CUQIpy-1.1.1.post0.dev36.dist-info → cuqipy-1.4.1.post0.dev124.dist-info}/METADATA +10 -6
  77. cuqipy-1.4.1.post0.dev124.dist-info/RECORD +101 -0
  78. {CUQIpy-1.1.1.post0.dev36.dist-info → cuqipy-1.4.1.post0.dev124.dist-info}/WHEEL +1 -1
  79. CUQIpy-1.1.1.post0.dev36.dist-info/RECORD +0 -92
  80. cuqi/experimental/mcmc/_conjugate.py +0 -197
  81. cuqi/experimental/mcmc/_conjugate_approx.py +0 -81
  82. cuqi/experimental/mcmc/_cwmh.py +0 -191
  83. cuqi/experimental/mcmc/_gibbs.py +0 -268
  84. cuqi/experimental/mcmc/_hmc.py +0 -470
  85. cuqi/experimental/mcmc/_laplace_approximation.py +0 -156
  86. cuqi/experimental/mcmc/_mh.py +0 -78
  87. cuqi/experimental/mcmc/_pcn.py +0 -89
  88. cuqi/experimental/mcmc/_sampler.py +0 -561
  89. cuqi/experimental/mcmc/_utilities.py +0 -17
  90. cuqi/implicitprior/_regularizedGaussian.py +0 -323
  91. {CUQIpy-1.1.1.post0.dev36.dist-info → cuqipy-1.4.1.post0.dev124.dist-info/licenses}/LICENSE +0 -0
  92. {CUQIpy-1.1.1.post0.dev36.dist-info → cuqipy-1.4.1.post0.dev124.dist-info}/top_level.txt +0 -0
cuqi/sampler/_gibbs.py CHANGED
@@ -1,14 +1,23 @@
1
- from cuqi.distribution import JointDistribution
1
+ from cuqi.distribution import JointDistribution, Posterior
2
2
  from cuqi.sampler import Sampler
3
- from cuqi.samples import Samples
4
- from typing import Dict, Union
3
+ from cuqi.samples import Samples, JointSamples
4
+ from typing import Dict
5
5
  import numpy as np
6
- import sys
7
-
8
-
9
- class Gibbs:
6
+ import warnings
7
+ from cuqi import config
8
+
9
+ try:
10
+ from tqdm import tqdm
11
+ except ImportError:
12
+ def tqdm(iterable, **kwargs):
13
+ warnings.warn("Module mcmc: tqdm not found. Install tqdm to get sampling progress.")
14
+ return iterable
15
+
16
+ # Not subclassed from Sampler as Gibbs handles multiple samplers and samples multiple parameters
17
+ # Similar approach as for JointDistribution
18
+ class HybridGibbs:
10
19
  """
11
- Gibbs sampler for sampling a joint distribution.
20
+ Hybrid Gibbs sampler for sampling a joint distribution.
12
21
 
13
22
  Gibbs sampling samples the variables of the distribution sequentially,
14
23
  one variable at a time. When a variable represents a random vector, the
@@ -17,7 +26,24 @@ class Gibbs:
17
26
  The sampling of each variable is done by sampling from the conditional
18
27
  distribution of that variable given the values of the other variables.
19
28
  This is often a very efficient way of sampling from a joint distribution
20
- if the conditional distributions are easy to sample from.
29
+ if the conditional distributions are easy to sample from.
30
+
31
+ Hybrid Gibbs sampler is a generalization of the Gibbs sampler where the
32
+ conditional distributions are sampled using different MCMC samplers.
33
+
34
+ When the conditionals are sampled exactly, the samples from the Gibbs
35
+ sampler converge to the joint distribution. See e.g.
36
+ Gelman et al. "Bayesian Data Analysis" (2014), Third Edition
37
+ for more details.
38
+
39
+ In each Gibbs step, the corresponding sampler state and history are stored,
40
+ then the sampler is reinitialized. After reinitialization, the sampler state
41
+ and history are set back to the stored values. This ensures preserving the
42
+ statefulness of the samplers.
43
+
44
+ The order in which the conditionals are sampled is the order of the
45
+ variables in the sampling strategy, unless a different sampling order
46
+ is specified by the parameter `scan_order`
21
47
 
22
48
  Parameters
23
49
  ----------
@@ -25,10 +51,25 @@ class Gibbs:
25
51
  Target distribution to sample from.
26
52
 
27
53
  sampling_strategy : dict
28
- Dictionary of sampling strategies for each parameter.
29
- Keys are parameter names.
54
+ Dictionary of sampling strategies for each variable.
55
+ Keys are variable names.
30
56
  Values are sampler objects.
31
57
 
58
+ num_sampling_steps : dict, *optional*
59
+ Dictionary of number of sampling steps for each variable.
60
+ The sampling steps are defined as the number of times the sampler
61
+ will call its step method in each Gibbs step.
62
+ Default is 1 for all variables.
63
+
64
+ scan_order : list or str, *optional*
65
+ Order in which the conditional distributions are sampled.
66
+ If set to "random", use a random ordering at each step.
67
+ If not specified, it will be the order in the sampling_strategy.
68
+
69
+ callback : callable, optional
70
+ A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
71
+ The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
72
+
32
73
  Example
33
74
  -------
34
75
  .. code-block:: python
@@ -37,7 +78,7 @@ class Gibbs:
37
78
  import numpy as np
38
79
 
39
80
  # Model and data
40
- A, y_obs, probinfo = cuqi.testproblem.Deconvolution1D(phantom='square').get_components()
81
+ A, y_obs, probinfo = cuqi.testproblem.Deconvolution1D(phantom='sinc').get_components()
41
82
  n = A.domain_dim
42
83
 
43
84
  # Define distributions
@@ -52,15 +93,20 @@ class Gibbs:
52
93
 
53
94
  # Define sampling strategy
54
95
  sampling_strategy = {
55
- 'x': cuqi.sampler.LinearRTO,
56
- ('d', 'l'): cuqi.sampler.Conjugate,
96
+ 'x': cuqi.sampler.LinearRTO(maxit=15),
97
+ 'd': cuqi.sampler.Conjugate(),
98
+ 'l': cuqi.sampler.Conjugate(),
57
99
  }
58
100
 
59
101
  # Define Gibbs sampler
60
- sampler = cuqi.sampler.Gibbs(posterior, sampling_strategy)
102
+ sampler = cuqi.sampler.HybridGibbs(posterior, sampling_strategy)
61
103
 
62
104
  # Run sampler
63
- samples = sampler.sample(Ns=1000, Nb=200)
105
+ sampler.warmup(200)
106
+ sampler.sample(1000)
107
+
108
+ # Get samples removing burn-in
109
+ samples = sampler.get_samples().burnthin(200)
64
110
 
65
111
  # Plot results
66
112
  samples['x'].plot_ci(exact=probinfo.exactSolution)
@@ -69,159 +115,271 @@ class Gibbs:
69
115
 
70
116
  """
71
117
 
72
- def __init__(self, target: JointDistribution, sampling_strategy: Dict[Union[str,tuple], Sampler]):
118
+ def __init__(self, target: JointDistribution, sampling_strategy: Dict[str, Sampler], num_sampling_steps: Dict[str, int] = None, scan_order = None, callback=None):
73
119
 
74
120
  # Store target and allow conditioning to reduce to a single density
75
121
  self.target = target() # Create a copy of target distribution (to avoid modifying the original)
76
122
 
77
- # Parse samplers and split any keys that are tuple into separate keys
78
- self.samplers = {}
79
- for par_name in sampling_strategy.keys():
80
- if isinstance(par_name, tuple):
81
- for par_name_ in par_name:
82
- self.samplers[par_name_] = sampling_strategy[par_name]
83
- else:
84
- self.samplers[par_name] = sampling_strategy[par_name]
123
+ # Store sampler instances (again as a copy to avoid modifying the original)
124
+ self.samplers = sampling_strategy.copy()
125
+
126
+ # Store number of sampling steps for each parameter
127
+ self.num_sampling_steps = num_sampling_steps
85
128
 
86
129
  # Store parameter names
87
130
  self.par_names = self.target.get_parameter_names()
88
131
 
89
- # ------------ Public methods ------------
90
- def sample(self, Ns, Nb=0):
91
- """ Sample from target distribution """
132
+ # Store the scan order
133
+ self._scan_order = scan_order
134
+
135
+ # Check that the parameters of the target align with the sampling_strategy and scan_order
136
+ if set(self.par_names) != set(self.scan_order):
137
+ raise ValueError("Parameter names in JointDistribution do not equal the names in the scan order.")
138
+
139
+ # Initialize sampler (after target is set)
140
+ self._initialize()
141
+
142
+ # Set the callback function
143
+ self.callback = callback
144
+
145
+ def _initialize(self):
146
+ """ Initialize sampler """
92
147
 
93
148
  # Initial points
94
- current_samples = self._get_initial_points()
149
+ self.current_samples = self._get_initial_points()
95
150
 
96
- # Compute how many samples were already taken previously
97
- at_Nb = self._Nb
98
- at_Ns = self._Ns
151
+ # Initialize sampling steps
152
+ self._initialize_num_sampling_steps()
99
153
 
100
- # Allocate memory for samples
101
- self._allocate_samples_warmup(Nb)
102
- self._allocate_samples(Ns)
154
+ # Allocate samples
155
+ self._allocate_samples()
103
156
 
104
- # Sample tuning phase
105
- for i in range(at_Nb, at_Nb+Nb):
106
- current_samples = self.step_tune(current_samples)
107
- self._store_samples(self.samples_warmup, current_samples, i)
108
- self._print_progress(i+1+at_Nb, at_Nb+Nb, 'Warmup')
157
+ # Set targets
158
+ self._set_targets()
109
159
 
110
- # Sample phase
111
- for i in range(at_Ns, at_Ns+Ns):
112
- current_samples = self.step(current_samples)
113
- self._store_samples(self.samples, current_samples, i)
114
- self._print_progress(i+1, at_Ns+Ns, 'Sample')
160
+ # Initialize the samplers
161
+ self._initialize_samplers()
115
162
 
116
- # Convert to samples objects and return
117
- return self._convert_to_Samples(self.samples)
163
+ # Validate all targets for samplers.
164
+ self.validate_targets()
118
165
 
119
- def step(self, current_samples):
120
- """ Sequentially go through all parameters and sample them conditionally on each other """
166
+ @property
167
+ def scan_order(self):
168
+ if self._scan_order is None:
169
+ return list(self.samplers.keys())
170
+ if self._scan_order == "random":
171
+ arr = list(self.samplers.keys())
172
+ np.random.shuffle(arr) # Shuffle works in-place
173
+ return arr
174
+ return self._scan_order
121
175
 
122
- # Extract par names
123
- par_names = self.par_names
176
+ # ------------ Public methods ------------
177
+ def validate_targets(self):
178
+ """ Validate each of the conditional targets used in the Gibbs steps """
179
+ if not isinstance(self.target, (JointDistribution, Posterior)):
180
+ raise ValueError('Target distribution must be a JointDistribution or Posterior.')
181
+ for sampler in self.samplers.values():
182
+ sampler.validate_target()
183
+
184
+ def sample(self, Ns, Nt=1) -> 'HybridGibbs':
185
+ """ Sample from the joint distribution using Gibbs sampling
186
+
187
+ Parameters
188
+ ----------
189
+ Ns : int
190
+ The number of samples to draw.
191
+ Nt : int, optional, default=1
192
+ The thinning interval. If Nt >= 1, every Nt'th sample is stored. The larger Nt, the fewer samples are stored.
193
+
194
+ """
195
+ # Progress bar printing settings:
196
+ miniters = None if config.PROGRESS_BAR_DYNAMIC_UPDATE else Ns + 1
197
+ maxinterval = 10.0 if config.PROGRESS_BAR_DYNAMIC_UPDATE else float("inf")
124
198
 
125
- # Sample from each conditional distribution
126
- for par_name in par_names:
199
+ for idx in tqdm(
200
+ range(Ns), "Sample: ", miniters=miniters, maxinterval=maxinterval
201
+ ):
127
202
 
128
- # Dict of all other parameters to condition on
129
- other_params = {par_name_: current_samples[par_name_] for par_name_ in par_names if par_name_ != par_name}
203
+ self.step()
130
204
 
131
- # Set up sampler for current conditional distribution
132
- sampler = self.samplers[par_name](self.target(**other_params))
205
+ if (Nt > 0) and (idx % Nt == 0):
206
+ self._store_samples()
133
207
 
134
- # Take a MCMC step
135
- current_samples[par_name] = sampler.step(current_samples[par_name])
208
+ # Call callback function if specified
209
+ self._call_callback(idx, Ns)
136
210
 
137
- # Ensure even 1-dimensional samples are 1D arrays
138
- current_samples[par_name] = current_samples[par_name].reshape(-1)
211
+ return self
212
+
213
+ def warmup(self, Nb, Nt=1, tune_freq=0.1) -> 'HybridGibbs':
214
+ """ Warmup (tune) the samplers in the Gibbs sampling scheme
215
+
216
+ Parameters
217
+ ----------
218
+ Nb : int
219
+ The number of samples to draw during warmup.
139
220
 
140
- return current_samples
221
+ Nt : int, optional, default=1
222
+ The thinning interval. If Nt >= 1, every Nt'th sample is stored. The larger Nt, the fewer samples are stored.
223
+
224
+ tune_freq : float, optional
225
+ Frequency of tuning the samplers. Tuning is performed every tune_freq*Nb steps.
226
+
227
+ """
228
+
229
+ # Progress bar printing settings:
230
+ miniters = None if config.PROGRESS_BAR_DYNAMIC_UPDATE else Nb + 1
231
+ maxinterval = 10.0 if config.PROGRESS_BAR_DYNAMIC_UPDATE else float("inf")
232
+
233
+ tune_interval = max(int(tune_freq * Nb), 1)
234
+
235
+ for idx in tqdm(
236
+ range(Nb), "Warmup: ", miniters=miniters, maxinterval=maxinterval
237
+ ):
238
+
239
+ self.step()
141
240
 
142
- def step_tune(self, current_samples):
143
- """ Perform a single MCMC step for each parameter and tune the sampler """
144
- # Not implemented. No tuning happening here yet. Requires samplers to be able to be modified after initialization.
145
- return self.step(current_samples)
241
+ # Tune the sampler at tuning intervals (matching behavior of Sampler class)
242
+ if (idx + 1) % tune_interval == 0:
243
+ self.tune(tune_interval, idx // tune_interval)
244
+
245
+ if (Nt > 0) and (idx % Nt == 0):
246
+ self._store_samples()
247
+
248
+ # Call callback function if specified
249
+ self._call_callback(idx, Nb)
250
+
251
+ return self
252
+
253
+ def get_samples(self) -> Dict[str, Samples]:
254
+ samples_object = JointSamples()
255
+ for par_name in self.par_names:
256
+ samples_array = np.array(self.samples[par_name]).T
257
+ samples_object[par_name] = Samples(samples_array, self.target.get_density(par_name).geometry)
258
+ return samples_object
259
+
260
+ def step(self):
261
+ """ Sequentially go through all parameters and sample them conditionally on each other """
262
+
263
+ # Sample from each conditional distribution
264
+ for par_name in self.scan_order:
265
+
266
+ # Set target for current parameter
267
+ self._set_target(par_name)
268
+
269
+ # Get sampler
270
+ sampler = self.samplers[par_name]
271
+
272
+ # Instead of simply changing the target of the sampler, we reinitialize it.
273
+ # This is to ensure that all internal variables are set to match the new target.
274
+ # To return the sampler to the old state and history, we first extract the state and history
275
+ # before reinitializing the sampler and then set the state and history back to the sampler
276
+
277
+ # Extract state and history from sampler
278
+ sampler_state = sampler.get_state()
279
+ sampler_history = sampler.get_history()
280
+
281
+ # Reinitialize sampler
282
+ sampler.reinitialize()
283
+
284
+ # Set state and history back to sampler
285
+ sampler.set_state(sampler_state)
286
+ sampler.set_history(sampler_history)
287
+
288
+ # Allow for multiple sampling steps in each Gibbs step
289
+ for _ in range(self.num_sampling_steps[par_name]):
290
+ # Sampling step
291
+ acc = sampler.step()
292
+
293
+ # Store acceptance rate in sampler (matching behavior of Sampler class Sample method)
294
+ sampler._acc.append(acc)
295
+
296
+ # Extract samples (Ensure even 1-dimensional samples are 1D arrays)
297
+ if isinstance(sampler.current_point, np.ndarray):
298
+ self.current_samples[par_name] = sampler.current_point.reshape(-1)
299
+ else:
300
+ self.current_samples[par_name] = sampler.current_point
301
+
302
+ def tune(self, skip_len, update_count):
303
+ """ Run a single tuning step on each of the samplers in the Gibbs sampling scheme
304
+
305
+ Parameters
306
+ ----------
307
+ skip_len : int
308
+ Defines the number of steps in between tuning (i.e. the tuning interval).
309
+
310
+ update_count : int
311
+ The number of times tuning has been performed. Can be used for internal bookkeeping.
312
+
313
+ """
314
+ for par_name in self.par_names:
315
+ self.samplers[par_name].tune(skip_len=skip_len, update_count=update_count)
146
316
 
147
317
  # ------------ Private methods ------------
148
- def _allocate_samples(self, Ns):
149
- """ Allocate memory for samples """
150
- # Allocate memory for samples
151
- samples = {}
318
+ def _call_callback(self, sample_index, num_of_samples):
319
+ """ Calls the callback function. Assumes input is sampler, sample index, and total number of samples """
320
+ if self.callback is not None:
321
+ self.callback(self, sample_index, num_of_samples)
322
+
323
+ def _initialize_samplers(self):
324
+ """ Initialize samplers """
325
+ for sampler in self.samplers.values():
326
+ sampler.initialize()
327
+
328
+ def _initialize_num_sampling_steps(self):
329
+ """ Initialize the number of sampling steps for each sampler. Defaults to 1 if not set by user """
330
+
331
+ if self.num_sampling_steps is None:
332
+ self.num_sampling_steps = {par_name: 1 for par_name in self.par_names}
333
+
152
334
  for par_name in self.par_names:
153
- samples[par_name] = np.zeros((self.target.get_density(par_name).dim, Ns))
154
-
155
- # Store samples in self
156
- if hasattr(self, 'samples'):
157
- # Append to existing samples (This makes a copy)
158
- for par_name in self.par_names:
159
- samples[par_name] = np.hstack((self.samples[par_name], samples[par_name]))
160
- self.samples = samples
335
+ if par_name not in self.num_sampling_steps:
336
+ self.num_sampling_steps[par_name] = 1
161
337
 
162
- def _allocate_samples_warmup(self, Nb):
163
- """ Allocate memory for samples """
164
-
165
- # If we already have warmup samples and more are requested raise error
166
- if hasattr(self, 'samples_warmup') and Nb != 0:
167
- raise ValueError('Sampler already has run warmup phase. Cannot run warmup phase again.')
338
+ def _set_targets(self):
339
+ """ Set targets for all samplers using the current samples """
340
+ par_names = self.par_names
341
+ for par_name in par_names:
342
+ self._set_target(par_name)
343
+
344
+ def _set_target(self, par_name):
345
+ """ Set target conditional distribution for a single parameter using the current samples """
346
+ # Get all other conditional parameters other than the current parameter and update the target
347
+ # This defines - from a joint p(x,y,z) - the conditional distribution p(x|y,z) or p(y|x,z) or p(z|x,y)
348
+ conditional_params = {par_name_: self.current_samples[par_name_] for par_name_ in self.par_names if par_name_ != par_name}
349
+ self.samplers[par_name].target = self.target(**conditional_params)
168
350
 
169
- # Allocate memory for samples
351
+ def _allocate_samples(self):
352
+ """ Allocate memory for samples """
170
353
  samples = {}
171
354
  for par_name in self.par_names:
172
- samples[par_name] = np.zeros((self.target.get_density(par_name).dim, Nb))
173
- self.samples_warmup = samples
355
+ samples[par_name] = []
356
+ self.samples = samples
174
357
 
175
358
  def _get_initial_points(self):
176
359
  """ Get initial points for each parameter """
177
360
  initial_points = {}
178
361
  for par_name in self.par_names:
179
- if hasattr(self, 'samples'):
180
- initial_points[par_name] = self.samples[par_name][:, -1]
181
- elif hasattr(self, 'samples_warmup'):
182
- initial_points[par_name] = self.samples_warmup[par_name][:, -1]
183
- elif hasattr(self.target.get_density(par_name), 'init_point'):
184
- initial_points[par_name] = self.target.get_density(par_name).init_point
185
- else:
186
- initial_points[par_name] = np.ones(self.target.get_density(par_name).dim)
362
+ sampler = self.samplers[par_name]
363
+ if sampler.initial_point is None:
364
+ sampler.initial_point = sampler._get_default_initial_point(self.target.get_density(par_name).dim)
365
+ initial_points[par_name] = sampler.initial_point
366
+
187
367
  return initial_points
188
368
 
189
- def _store_samples(self, samples, current_samples, i):
369
+ def _store_samples(self):
190
370
  """ Store current samples at index i of samples dict """
191
371
  for par_name in self.par_names:
192
- samples[par_name][:, i] = current_samples[par_name]
372
+ self.samples[par_name].append(self.current_samples[par_name])
193
373
 
194
- def _convert_to_Samples(self, samples):
195
- """ Convert each parameter in samples dict to cuqi.samples.Samples object with correct geometry """
196
- samples_object = {}
197
- for par_name in self.par_names:
198
- samples_object[par_name] = Samples(samples[par_name], self.target.get_density(par_name).geometry)
199
- return samples_object
200
-
201
- def _print_progress(self, s, Ns, phase):
202
- """Prints sampling progress"""
203
- if Ns < 2: # Don't print progress if only one sample
204
- return
205
- if (s % (max(Ns//100,1))) == 0:
206
- msg = f'{phase} {s} / {Ns}'
207
- sys.stdout.write('\r'+msg)
208
- if s==Ns:
209
- msg = f'{phase} {s} / {Ns}'
210
- sys.stdout.write('\r'+msg+'\n')
211
-
212
- # ------------ Private properties ------------
213
- @property
214
- def _Ns(self):
215
- """ Number of samples already taken """
216
- if hasattr(self, 'samples'):
217
- return self.samples[self.par_names[0]].shape[-1]
374
+ def __repr__(self):
375
+ """ Return a string representation of the sampler. """
376
+ msg = f"Sampler: {self.__class__.__name__} \n"
377
+ if self.target is None:
378
+ msg += f" Target: None \n"
218
379
  else:
219
- return 0
220
-
221
- @property
222
- def _Nb(self):
223
- """ Number of samples already taken in warmup phase """
224
- if hasattr(self, 'samples_warmup'):
225
- return self.samples_warmup[self.par_names[0]].shape[-1]
226
- else:
227
- return 0
380
+ msg += f" Target: \n \t {self.target} \n\n"
381
+
382
+ for key, value in zip(self.samplers.keys(), self.samplers.values()):
383
+ msg += f" Variable '{key}' with {value} \n"
384
+
385
+ return msg