CUQIpy 1.3.0__py3-none-any.whl → 1.4.0.post0.dev61__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. cuqi/__init__.py +1 -0
  2. cuqi/_version.py +3 -3
  3. cuqi/density/_density.py +9 -1
  4. cuqi/distribution/__init__.py +1 -1
  5. cuqi/distribution/_beta.py +1 -1
  6. cuqi/distribution/_cauchy.py +2 -2
  7. cuqi/distribution/_distribution.py +24 -15
  8. cuqi/distribution/_joint_distribution.py +97 -12
  9. cuqi/distribution/_posterior.py +9 -0
  10. cuqi/distribution/_truncated_normal.py +3 -3
  11. cuqi/distribution/_uniform.py +36 -2
  12. cuqi/experimental/__init__.py +1 -1
  13. cuqi/experimental/_recommender.py +216 -0
  14. cuqi/experimental/geometry/_productgeometry.py +3 -3
  15. cuqi/geometry/_geometry.py +12 -1
  16. cuqi/implicitprior/__init__.py +1 -1
  17. cuqi/implicitprior/_regularizedGaussian.py +40 -4
  18. cuqi/implicitprior/_restorator.py +35 -1
  19. cuqi/legacy/__init__.py +2 -0
  20. cuqi/legacy/sampler/__init__.py +11 -0
  21. cuqi/legacy/sampler/_conjugate.py +55 -0
  22. cuqi/legacy/sampler/_conjugate_approx.py +52 -0
  23. cuqi/legacy/sampler/_cwmh.py +196 -0
  24. cuqi/legacy/sampler/_gibbs.py +231 -0
  25. cuqi/legacy/sampler/_hmc.py +335 -0
  26. cuqi/legacy/sampler/_langevin_algorithm.py +198 -0
  27. cuqi/legacy/sampler/_laplace_approximation.py +184 -0
  28. cuqi/legacy/sampler/_mh.py +190 -0
  29. cuqi/legacy/sampler/_pcn.py +244 -0
  30. cuqi/{experimental/mcmc → legacy/sampler}/_rto.py +134 -152
  31. cuqi/legacy/sampler/_sampler.py +182 -0
  32. cuqi/likelihood/_likelihood.py +1 -1
  33. cuqi/model/_model.py +1248 -357
  34. cuqi/pde/__init__.py +4 -0
  35. cuqi/pde/_observation_map.py +36 -0
  36. cuqi/pde/_pde.py +133 -32
  37. cuqi/problem/_problem.py +88 -82
  38. cuqi/sampler/__init__.py +120 -8
  39. cuqi/sampler/_conjugate.py +376 -35
  40. cuqi/sampler/_conjugate_approx.py +40 -16
  41. cuqi/sampler/_cwmh.py +132 -138
  42. cuqi/{experimental/mcmc → sampler}/_direct.py +1 -1
  43. cuqi/sampler/_gibbs.py +269 -130
  44. cuqi/sampler/_hmc.py +328 -201
  45. cuqi/sampler/_langevin_algorithm.py +282 -98
  46. cuqi/sampler/_laplace_approximation.py +87 -117
  47. cuqi/sampler/_mh.py +47 -157
  48. cuqi/sampler/_pcn.py +56 -211
  49. cuqi/sampler/_rto.py +206 -140
  50. cuqi/sampler/_sampler.py +540 -135
  51. cuqi/solver/_solver.py +6 -2
  52. cuqi/testproblem/_testproblem.py +2 -3
  53. cuqi/utilities/__init__.py +3 -1
  54. cuqi/utilities/_utilities.py +94 -12
  55. {CUQIpy-1.3.0.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/METADATA +6 -4
  56. cuqipy-1.4.0.post0.dev61.dist-info/RECORD +102 -0
  57. {CUQIpy-1.3.0.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/WHEEL +1 -1
  58. CUQIpy-1.3.0.dist-info/RECORD +0 -100
  59. cuqi/experimental/mcmc/__init__.py +0 -123
  60. cuqi/experimental/mcmc/_conjugate.py +0 -345
  61. cuqi/experimental/mcmc/_conjugate_approx.py +0 -76
  62. cuqi/experimental/mcmc/_cwmh.py +0 -193
  63. cuqi/experimental/mcmc/_gibbs.py +0 -318
  64. cuqi/experimental/mcmc/_hmc.py +0 -464
  65. cuqi/experimental/mcmc/_langevin_algorithm.py +0 -392
  66. cuqi/experimental/mcmc/_laplace_approximation.py +0 -156
  67. cuqi/experimental/mcmc/_mh.py +0 -80
  68. cuqi/experimental/mcmc/_pcn.py +0 -89
  69. cuqi/experimental/mcmc/_sampler.py +0 -566
  70. cuqi/experimental/mcmc/_utilities.py +0 -17
  71. {CUQIpy-1.3.0.dist-info → cuqipy-1.4.0.post0.dev61.dist-info/licenses}/LICENSE +0 -0
  72. {CUQIpy-1.3.0.dist-info → cuqipy-1.4.0.post0.dev61.dist-info}/top_level.txt +0 -0
@@ -1,89 +0,0 @@
1
- import numpy as np
2
- import cuqi
3
- from cuqi.experimental.mcmc import Sampler
4
- from cuqi.array import CUQIarray
5
-
6
class PCN(Sampler):
    """ Preconditioned Crank-Nicolson (pCN) sampler.

    The target must be a :class:`cuqi.distribution.Posterior` whose prior is a
    :class:`cuqi.distribution.Gaussian` or :class:`cuqi.distribution.Normal`.

    Each step draws ``xi`` from the prior and proposes
    ``sqrt(1 - scale**2) * current_point + scale * xi``. The candidate is
    accepted when ``log(u) <= min(0, loglike(candidate) - loglike(current))``
    with ``u ~ U(0, 1)``. Only the likelihood enters the acceptance ratio.

    NOTE(review): the proposal combines the current point with a raw prior
    sample, which matches the textbook pCN move only for a zero-mean prior;
    confirm this assumption holds for intended use.

    Parameters
    ----------
    target : cuqi.distribution.Posterior, optional
        The posterior to sample from.

    scale : float, optional
        Initial value of the step-size parameter (default 1.0). It is adapted
        by :meth:`tune` during warmup and capped at 1 there.

    **kwargs
        Forwarded to the :class:`Sampler` initializer.
    """
    # TODO: consider refactoring into a proposal-based sampler.

    _STATE_KEYS = Sampler._STATE_KEYS.union({'scale', 'current_likelihood_logd', 'lambd'})

    def __init__(self, target=None, scale=1.0, **kwargs):
        super().__init__(target, **kwargs)
        # Kept apart from `scale` so that initialization can restore the user's value.
        self.initial_scale = scale

    # ------------ Convenience accessors ------------
    @property
    def prior(self):
        """ Prior distribution of the posterior target. """
        return self.target.prior

    @property
    def likelihood(self):
        """ Likelihood of the posterior target. """
        return self.target.likelihood

    def _loglikelihood(self, x):
        """ Log-likelihood of the target evaluated at ``x``. """
        return self.likelihood.logd(x)

    @property
    def dim(self):  # TODO. Check if we need this. Implemented in base class
        """ Dimension of the target; also cached on ``self._dim``. """
        if hasattr(self, 'target'):
            target = self.target
            if hasattr(target, 'dim'):
                self._dim = target.dim
            elif isinstance(target, tuple) and len(target) == 2:
                self._dim = target[0].dim
        return self._dim

    # ------------ Sampler interface ------------
    def _initialize(self):
        self.scale = self.initial_scale
        self.current_likelihood_logd = self._loglikelihood(self.current_point)

        # Robbins-Monro tuning state used by `tune`: `lambd` is the unclipped
        # scale and `star_acc` is the acceptance rate the tuning aims for.
        self.lambd = self.scale
        self.star_acc = 0.44  # TODO: 0.234

    def validate_target(self):
        target = self.target
        if not isinstance(target, cuqi.distribution.Posterior):
            raise ValueError(f"To initialize an object of type {self.__class__}, 'target' need to be of type 'cuqi.distribution.Posterior'.")
        if not isinstance(self.prior, (cuqi.distribution.Gaussian, cuqi.distribution.Normal)):
            raise ValueError("The prior distribution of the target need to be Gaussian")

    def step(self):
        """ Perform one pCN step; returns 1 if the candidate was accepted, else 0. """
        noise = self.prior.sample(1).flatten()
        beta = self.scale
        candidate = np.sqrt(1-beta**2)*self.current_point + beta*noise

        candidate_loglike = self._loglikelihood(candidate)

        # The prior terms cancel for the pCN proposal, so only the
        # likelihood difference is needed.
        log_alpha = min(0, candidate_loglike - self.current_likelihood_logd)

        if np.log(np.random.rand()) <= log_alpha:
            self.current_point = candidate
            self.current_likelihood_logd = candidate_loglike
            return 1
        return 0

    def tune(self, skip_len, update_count):
        """
        Tune the scale parameter of the PCN sampler.
        The tuning is based on algorithm 4 in Andrieu, Christophe, and Johannes Thoms.
        "A tutorial on adaptive MCMC." Statistics and computing 18 (2008): 343-373.
        Note: the tuning algorithm here is the same as the one used in MH sampler.
        """
        recent_acc_rate = np.mean(self._acc[-skip_len:])
        step_size = 1/np.sqrt(update_count+1)

        # Robbins-Monro update on the log scale; the step size decays so the
        # adaptation vanishes over time.
        self.lambd = np.exp(np.log(self.lambd) + step_size*(recent_acc_rate-self.star_acc))

        self.scale = min(self.lambd, 1)
@@ -1,566 +0,0 @@
1
- from abc import ABC, abstractmethod
2
- import os
3
- import numpy as np
4
- import pickle as pkl
5
- import warnings
6
- import cuqi
7
- from cuqi.samples import Samples
8
-
9
try:
    from tqdm import tqdm
except ImportError:
    class _NoProgressBar:
        """ Minimal stand-in for a ``tqdm`` progress bar when tqdm is not installed.

        Iterates over the wrapped iterable and accepts (and ignores) the
        ``set_postfix_str`` calls the samplers make on the progress bar.
        """

        def __init__(self, iterable):
            self._iterable = iterable

        def __iter__(self):
            return iter(self._iterable)

        def set_postfix_str(self, *args, **kwargs):
            """ No-op replacement for ``tqdm.set_postfix_str``. """

    def tqdm(iterable, *args, **kwargs):
        """ Fallback for ``tqdm.tqdm``.

        Accepts extra positional arguments (callers pass the description
        positionally). Returns an object that also supports ``set_postfix_str``.
        """
        warnings.warn("Module mcmc: tqdm not found. Install tqdm to get sampling progress.")
        return _NoProgressBar(iterable)
15
-
16
class Sampler(ABC):
    """ Abstract base class for all samplers.

    Provides a common interface for all samplers. The interface includes methods for sampling, warmup and getting the samples in an object oriented way.

    Samples are stored in a list to allow for dynamic growth of the sample set. Returning samples is done by creating a new Samples object from the list of samples.

    The sampler maintains sets of state and history keys, which are used for features like checkpointing and resuming sampling.

    The state of the sampler represents all variables that are updated (replaced) in a Markov Monte Carlo step, e.g. the current point of the sampler.

    The history of the sampler represents all variables that are updated (appended) in a Markov Monte Carlo step, e.g. the samples and acceptance rates.

    Subclasses should ensure that any new variables that are updated in a Markov Monte Carlo step are added to the state or history keys.

    Saving and loading checkpoints saves and loads the state of the sampler (not the history).

    Batching via the batch_size parameter of :meth:`sample` additionally writes the drawn samples to disk
    in batches of the specified size. The samples are still kept in memory as well.

    Any other attribute stored as part of the sampler (e.g. target, initial_point) is not supposed to be updated
    during sampling and should not be part of the state or history.

    """

    _STATE_KEYS = {'current_point'}
    """ Set of keys for the state dictionary. """

    _HISTORY_KEYS = {'_samples', '_acc'}
    """ Set of keys for the history dictionary. """

    def __init__(self, target:cuqi.density.Density=None, initial_point=None, callback=None):
        """ Initializer for abstract base class for all samplers.

        Any subclassing samplers should simply store input parameters as part of the __init__ method.

        The actual initialization of the sampler should be done in the _initialize method.

        Parameters
        ----------
        target : cuqi.density.Density
            The target density.

        initial_point : array-like, optional
            The initial point for the sampler. If not given, the sampler will choose an initial point.

        callback : callable, optional
            A function that will be called after each sample is drawn. The function should take two arguments: the sample and the index of the sample.
            The sample is a 1D numpy array and the index is an integer. The callback function is useful for monitoring the sampler during sampling.

        """

        self.target = target
        self.initial_point = initial_point
        self.callback = callback
        self._is_initialized = False

    def initialize(self):
        """ Initialize the sampler by setting and allocating the state and history before sampling starts.

        Raises
        ------
        ValueError
            If the sampler is already initialized, has no target, or a state/history key is unset afterwards.
        """

        if self._is_initialized:
            raise ValueError("Sampler is already initialized.")

        if self.target is None:
            raise ValueError("Cannot initialize sampler without a target density.")

        # Default values
        if self.initial_point is None:
            self.initial_point = self._get_default_initial_point(self.dim)

        # State variables
        self.current_point = self.initial_point

        # History variables
        self._samples = []
        self._acc = [ 1 ] # TODO. Check if we need to put 1 here.

        self._initialize() # Subclass specific initialization

        self._validate_initialization()

        self._is_initialized = True

    # ------------ Abstract methods to be implemented by subclasses ------------
    @abstractmethod
    def step(self):
        """ Perform one step of the sampler by transitioning the current point to a new point according to the sampler's transition kernel. """
        pass

    @abstractmethod
    def tune(self, skip_len, update_count):
        """ Tune the parameters of the sampler. This method is called at the end of each tuning interval of the warmup phase.

        Parameters
        ----------
        skip_len : int
            Defines the number of steps in between tuning (i.e. the tuning interval).

        update_count : int
            The number of times tuning has been performed. Can be used for internal bookkeeping.

        """
        pass

    @abstractmethod
    def validate_target(self):
        """ Validate the target is compatible with the sampler. Called when the target is set. Should raise an error if the target is not compatible. """
        pass

    @abstractmethod
    def _initialize(self):
        """ Subclass specific sampler initialization. Called during the initialization of the sampler which is done before sampling starts. """
        pass

    # ------------ Public attributes ------------
    @property
    def dim(self) -> int:
        """ Dimension of the target density. """
        return self.target.dim

    @property
    def geometry(self) -> cuqi.geometry.Geometry:
        """ Geometry of the target density. """
        return self.target.geometry

    @property
    def target(self) -> cuqi.density.Density:
        """ Return the target density. """
        return self._target

    @target.setter
    def target(self, value):
        """ Set the target density. Runs validation of the target. """
        self._target = value
        if self._target is not None:
            self.validate_target()

    # ------------ Public methods ------------
    def get_samples(self) -> Samples:
        """ Return the samples. The internal data-structure for the samples is a dynamic list so this creates a copy. """
        return Samples(np.array(self._samples).T, self.target.geometry)

    def reinitialize(self):
        """ Re-initialize the sampler. This clears the state and history and initializes the sampler again by setting state and history to their original values. """

        # Loop over state and reset to None
        for key in self._STATE_KEYS:
            setattr(self, key, None)

        # Loop over history and reset to None
        for key in self._HISTORY_KEYS:
            setattr(self, key, None)

        self._is_initialized = False

        self.initialize()

    def save_checkpoint(self, path):
        """ Save the state of the sampler to a file (pickled). """

        self._ensure_initialized()

        state = self.get_state()

        # Convert all CUQIarrays to numpy arrays since CUQIarrays do not get pickled correctly
        for key, value in state['state'].items():
            if isinstance(value, cuqi.array.CUQIarray):
                state['state'][key] = value.to_numpy()

        with open(path, 'wb') as handle:
            pkl.dump(state, handle, protocol=pkl.HIGHEST_PROTOCOL)

    def load_checkpoint(self, path):
        """ Load the state of the sampler from a file written by :meth:`save_checkpoint`.

        Security note: the file is read with ``pickle.load``, which can execute
        arbitrary code. Only load checkpoint files from trusted sources.
        """

        self._ensure_initialized()

        with open(path, 'rb') as handle:
            state = pkl.load(handle)

        self.set_state(state)

    def sample(self, Ns, batch_size=0, sample_path='./CUQI_samples/') -> 'Sampler':
        """ Sample Ns samples from the target density.

        Parameters
        ----------
        Ns : int
            The number of samples to draw.

        batch_size : int, optional
            The batch size for saving samples to disk. If 0 (or negative), no batching is used.
            If positive, samples are saved to disk in batches of the specified size. The final,
            possibly smaller, batch is written when sampling stops.

        sample_path : str, optional
            The path to save the samples. If not specified, the samples are saved to the current working directory under a folder called 'CUQI_samples'.

        Returns
        -------
        Sampler
            The sampler itself, to allow chaining calls.

        """

        self._ensure_initialized()

        # Only create the batch handler (which also creates the output directory) when batching is requested
        batch_handler = _BatchHandler(batch_size, sample_path) if batch_size > 0 else None

        # Draw samples
        pbar = tqdm(range(Ns), "Sample: ")
        try:
            for idx in pbar:

                # Perform one step of the sampler
                acc = self.step()

                # Store samples
                self._acc.append(acc)
                self._samples.append(self.current_point)

                # display acc rate at progress bar
                pbar.set_postfix_str(f"acc rate: {np.mean(self._acc[-1-idx:]):.2%}")

                # Add sample to batch
                if batch_handler is not None:
                    batch_handler.add_sample(self.current_point)

                # Call callback function if specified
                self._call_callback(self.current_point, len(self._samples)-1)
        finally:
            # Write the last, possibly partial, batch. Without this the trailing
            # Ns % batch_size samples would never reach disk.
            if batch_handler is not None:
                batch_handler.finalize()

        return self


    def warmup(self, Nb, tune_freq=0.1) -> 'Sampler':
        """ Warmup the sampler by drawing Nb samples.

        Parameters
        ----------
        Nb : int
            The number of samples to draw during warmup.

        tune_freq : float, optional
            The frequency of tuning. Tuning is performed every max(int(tune_freq*Nb), 1) samples.

        Returns
        -------
        Sampler
            The sampler itself, to allow chaining calls.

        """

        self._ensure_initialized()

        tune_interval = max(int(tune_freq * Nb), 1)

        # Draw warmup samples with tuning
        pbar = tqdm(range(Nb), "Warmup: ")
        for idx in pbar:

            # Perform one step of the sampler
            acc = self.step()

            # Record the acceptance before tuning so that the tuning window
            # self._acc[-tune_interval:] covers exactly the steps of the
            # interval that just finished (including this one).
            self._acc.append(acc)

            # Tune the sampler at tuning intervals
            if (idx + 1) % tune_interval == 0:
                self.tune(tune_interval, idx // tune_interval)

            # Store samples
            self._samples.append(self.current_point)

            # display acc rate at progress bar
            pbar.set_postfix_str(f"acc rate: {np.mean(self._acc[-1-idx:]):.2%}")

            # Call callback function if specified
            self._call_callback(self.current_point, len(self._samples)-1)

        return self

    def get_state(self) -> dict:
        """ Return the state of the sampler.

        The state is used when checkpointing the sampler.

        The state of the sampler is a dictionary with keys 'metadata' and 'state'.
        The 'metadata' key contains information about the sampler type.
        The 'state' key contains the state of the sampler.

        For example, the state of a "MH" sampler could be:

        state = {
            'metadata': {
                'sampler_type': 'MH'
            },
            'state': {
                'current_point': np.array([...]),
                'current_target_logd': -123.45,
                'scale': 1.0,
                ...
            }
        }
        """
        state = {
            'metadata': {
                'sampler_type': self.__class__.__name__
            },
            'state': {
                key: getattr(self, key) for key in self._STATE_KEYS
            }
        }
        return state

    def set_state(self, state: dict):
        """ Set the state of the sampler.

        The state is used when loading the sampler from a checkpoint.

        The state of the sampler is a dictionary with keys 'metadata' and 'state'.

        For example, the state of a "MH" sampler could be:

        state = {
            'metadata': {
                'sampler_type': 'MH'
            },
            'state': {
                'current_point': np.array([...]),
                'current_target_logd': -123.45,
                'scale': 1.0,
                ...
            }
        }

        Raises
        ------
        ValueError
            If the sampler type does not match or the state contains an unknown key.
        """
        if state['metadata']['sampler_type'] != self.__class__.__name__:
            raise ValueError(f"Sampler type in state dictionary ({state['metadata']['sampler_type']}) does not match the type of the sampler ({self.__class__.__name__}).")

        for key, value in state['state'].items():
            if key in self._STATE_KEYS:
                setattr(self, key, value)
            else:
                raise ValueError(f"Key {key} not recognized in state dictionary of sampler {self.__class__.__name__}.")

    def get_history(self) -> dict:
        """ Return the history of the sampler as a dict with keys 'metadata' and 'history'. """
        history = {
            'metadata': {
                'sampler_type': self.__class__.__name__
            },
            'history': {
                key: getattr(self, key) for key in self._HISTORY_KEYS
            }
        }
        return history

    def set_history(self, history: dict):
        """ Set the history of the sampler from a dict in the format returned by :meth:`get_history`. """
        if history['metadata']['sampler_type'] != self.__class__.__name__:
            raise ValueError(f"Sampler type in history dictionary ({history['metadata']['sampler_type']}) does not match the type of the sampler ({self.__class__.__name__}).")

        for key, value in history['history'].items():
            if key in self._HISTORY_KEYS:
                setattr(self, key, value)
            else:
                raise ValueError(f"Key {key} not recognized in history dictionary of sampler {self.__class__.__name__}.")

    # ------------ Private methods ------------
    def _call_callback(self, sample, sample_index):
        """ Calls the callback function. Assumes input is sample and sample index"""
        if self.callback is not None:
            self.callback(sample, sample_index)

    def _validate_initialization(self):
        """ Validate the initialization of the sampler by checking all state and history keys are set.

        A key that was never assigned is treated the same as one set to None,
        so subclasses that forget a key get the intended ValueError rather than an AttributeError.
        """

        for key in self._STATE_KEYS:
            if getattr(self, key, None) is None:
                raise ValueError(f"Sampler state key {key} is not set after initialization.")

        for key in self._HISTORY_KEYS:
            if getattr(self, key, None) is None:
                raise ValueError(f"Sampler history key {key} is not set after initialization.")

    def _ensure_initialized(self):
        """ Ensure the sampler is initialized. If not initialize it. """
        if not self._is_initialized:
            self.initialize()

    def _get_default_initial_point(self, dim):
        """ Return the default initial point for the sampler. Defaults to an array of ones. """
        return np.ones(dim)

    def __repr__(self):
        """ Return a string representation of the sampler. """
        if self.target is None:
            return f"Sampler: {self.__class__.__name__} \n Target: None"
        else:
            msg = f"Sampler: {self.__class__.__name__} \n Target: \n \t {self.target} "

        if self._is_initialized:
            state = self.get_state()
            msg += f"\n Current state: \n"
            # Sort keys alphabetically
            keys = sorted(state['state'].keys())
            # Put _ in the end
            keys = [key for key in keys if key[0] != '_'] + [key for key in keys if key[0] == '_']
            for key in keys:
                value = state['state'][key]
                msg += f"\t {key}: {value} \n"
        return msg
413
-
414
class ProposalBasedSampler(Sampler, ABC):
    """ Abstract base class for samplers that draw candidate points from a proposal distribution.

    In addition to the base-class state, these samplers track ``scale`` and the
    target log-density at the current point (``current_target_logd``).
    """

    _STATE_KEYS = Sampler._STATE_KEYS.union({'current_target_logd', 'scale'})

    def __init__(self, target=None, proposal=None, scale=1, **kwargs):
        """ Store the configuration of a proposal-based sampler.

        As with :class:`Sampler`, subclasses should only store their inputs here.
        The real set-up happens in :meth:`initialize` and the subclass hook ``_initialize``.

        Parameters
        ----------
        target : cuqi.density.Density
            The target density.

        proposal : cuqi.distribution.Distribution, optional
            The proposal distribution. If omitted, :attr:`_default_proposal` is used at initialization.

        scale : float, optional
            The scale parameter for the proposal distribution. Copied into ``scale`` at initialization.

        **kwargs : dict
            Forwarded to the :class:`Sampler` initializer.

        """
        super().__init__(target, **kwargs)
        self.proposal = proposal
        self.initial_scale = scale

    def initialize(self):
        """ Set up state and history before sampling, filling in defaults where needed.

        Raises
        ------
        ValueError
            If already initialized, if no target is set, or if a state/history key is unset afterwards.
        """
        if self._is_initialized:
            raise ValueError("Sampler is already initialized.")
        if self.target is None:
            raise ValueError("Cannot initialize sampler without a target density.")

        # Defaults for anything the user left unspecified
        if self.initial_point is None:
            self.initial_point = self._get_default_initial_point(self.dim)
        if self.proposal is None:
            self.proposal = self._default_proposal

        # State: start point, proposal scale and the target log-density there
        self.current_point = self.initial_point
        self.scale = self.initial_scale
        self.current_target_logd = self.target.logd(self.current_point)

        # History, seeded the same way as in the base class
        self._samples, self._acc = [], [1]  # TODO. Check if we need to put 1 here.

        self._initialize()
        self._validate_initialization()
        self._is_initialized = True

    @abstractmethod
    def validate_proposal(self):
        """ Check that the proposal distribution is usable by this sampler. """
        pass

    @property
    def _default_proposal(self):
        """ Fallback proposal: a Gaussian with zero mean and unit variance of dimension ``self.dim``. """
        zero_mean = np.zeros(self.dim)
        return cuqi.distribution.Gaussian(zero_mean, 1)

    @property
    def proposal(self):
        """ The proposal distribution. """
        return self._proposal

    @proposal.setter
    def proposal(self, proposal):
        """ Assign the proposal distribution and validate it when it is not None. """
        self._proposal = proposal
        if proposal is not None:
            self.validate_proposal()
501
-
502
-
503
- class _BatchHandler:
504
- """ Utility class to handle batching of samples.
505
-
506
- If a batch size is specified, this class will save samples to disk in batches of the specified size.
507
-
508
- This is useful for very large sample sets that do not fit in memory.
509
-
510
- """
511
-
512
- def __init__(self, batch_size=0, sample_path='./CUQI_samples/'):
513
-
514
- if batch_size < 0:
515
- raise ValueError("Batch size should be a non-negative integer")
516
-
517
- self.sample_path = sample_path
518
- self._batch_size = batch_size
519
- self.current_batch = []
520
- self.num_batches_dumped = 0
521
-
522
- @property
523
- def sample_path(self):
524
- """ The path to save the samples. """
525
- return self._sample_path
526
-
527
- @sample_path.setter
528
- def sample_path(self, value):
529
- if not isinstance(value, str):
530
- raise TypeError("Sample path must be a string.")
531
- normalized_path = value.rstrip('/') + '/'
532
- if not os.path.isdir(normalized_path):
533
- try:
534
- os.makedirs(normalized_path, exist_ok=True)
535
- except Exception as e:
536
- raise ValueError(f"Could not create directory at {normalized_path}: {e}")
537
- self._sample_path = normalized_path
538
-
539
- def add_sample(self, sample):
540
- """ Add a sample to the batch if batching. If the batch is full, flush the batch to disk. """
541
-
542
- if self._batch_size <= 0:
543
- return # Batching not used
544
-
545
- self.current_batch.append(sample)
546
-
547
- if len(self.current_batch) >= self._batch_size:
548
- self.flush()
549
-
550
- def flush(self):
551
- """ Flush the current batch of samples to disk. """
552
-
553
- if not self.current_batch:
554
- return # No samples to flush
555
-
556
- # Save the current batch of samples
557
- batch_samples = np.array(self.current_batch)
558
- file_path = f'{self.sample_path}batch_{self.num_batches_dumped:04d}.npz'
559
- np.savez(file_path, samples=batch_samples, batch_id=self.num_batches_dumped)
560
-
561
- self.num_batches_dumped += 1
562
- self.current_batch = [] # Clear the batch after saving
563
-
564
- def finalize(self):
565
- """ Finalize the batch handler. Flush any remaining samples to disk. """
566
- self.flush()
@@ -1,17 +0,0 @@
1
- import cuqi
2
- import inspect
3
-
4
def find_valid_samplers(target):
    """ Find all samplers in the cuqi.experimental.mcmc module that accept the provided target.

    A sampler counts as valid when constructing it with ``target`` as its only
    argument succeeds. Construction sets the target, which runs the sampler's
    target validation. Samplers that raise, including abstract base classes
    that cannot be instantiated, are left out.

    Parameters
    ----------
    target : cuqi.density.Density
        The target density to test the samplers against.

    Returns
    -------
    list of str
        Class names of the samplers that accepted the target, in the order
        given by ``inspect.getmembers`` (sorted by name).
    """
    module = cuqi.experimental.mcmc
    valid_samplers = []

    for name, sampler in inspect.getmembers(module, inspect.isclass):
        if not issubclass(sampler, module.Sampler):
            continue
        try:
            sampler(target)
        except Exception:
            # Incompatible targets are signalled by an exception during
            # construction. Catch only ordinary errors so that
            # KeyboardInterrupt/SystemExit still propagate.
            continue
        valid_samplers.append(name)

    return valid_samplers