CUQIpy 1.4.0.post0.dev13__py3-none-any.whl → 1.4.0.post0.dev45__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of CUQIpy might be problematic. Click here for more details.

Files changed (48)
  1. cuqi/__init__.py +1 -0
  2. cuqi/_version.py +3 -3
  3. cuqi/experimental/__init__.py +1 -2
  4. cuqi/experimental/_recommender.py +4 -4
  5. cuqi/legacy/__init__.py +2 -0
  6. cuqi/legacy/sampler/__init__.py +11 -0
  7. cuqi/legacy/sampler/_conjugate.py +55 -0
  8. cuqi/legacy/sampler/_conjugate_approx.py +52 -0
  9. cuqi/legacy/sampler/_cwmh.py +196 -0
  10. cuqi/legacy/sampler/_gibbs.py +231 -0
  11. cuqi/legacy/sampler/_hmc.py +335 -0
  12. cuqi/legacy/sampler/_langevin_algorithm.py +198 -0
  13. cuqi/legacy/sampler/_laplace_approximation.py +184 -0
  14. cuqi/legacy/sampler/_mh.py +190 -0
  15. cuqi/legacy/sampler/_pcn.py +244 -0
  16. cuqi/legacy/sampler/_rto.py +284 -0
  17. cuqi/legacy/sampler/_sampler.py +182 -0
  18. cuqi/problem/_problem.py +87 -80
  19. cuqi/sampler/__init__.py +120 -8
  20. cuqi/sampler/_conjugate.py +376 -35
  21. cuqi/sampler/_conjugate_approx.py +40 -16
  22. cuqi/sampler/_cwmh.py +132 -138
  23. cuqi/{experimental/mcmc → sampler}/_direct.py +1 -1
  24. cuqi/sampler/_gibbs.py +269 -130
  25. cuqi/sampler/_hmc.py +328 -201
  26. cuqi/sampler/_langevin_algorithm.py +282 -98
  27. cuqi/sampler/_laplace_approximation.py +87 -117
  28. cuqi/sampler/_mh.py +47 -157
  29. cuqi/sampler/_pcn.py +56 -211
  30. cuqi/sampler/_rto.py +206 -140
  31. cuqi/sampler/_sampler.py +540 -135
  32. {cuqipy-1.4.0.post0.dev13.dist-info → cuqipy-1.4.0.post0.dev45.dist-info}/METADATA +1 -1
  33. {cuqipy-1.4.0.post0.dev13.dist-info → cuqipy-1.4.0.post0.dev45.dist-info}/RECORD +36 -35
  34. cuqi/experimental/mcmc/__init__.py +0 -122
  35. cuqi/experimental/mcmc/_conjugate.py +0 -396
  36. cuqi/experimental/mcmc/_conjugate_approx.py +0 -76
  37. cuqi/experimental/mcmc/_cwmh.py +0 -190
  38. cuqi/experimental/mcmc/_gibbs.py +0 -366
  39. cuqi/experimental/mcmc/_hmc.py +0 -462
  40. cuqi/experimental/mcmc/_langevin_algorithm.py +0 -382
  41. cuqi/experimental/mcmc/_laplace_approximation.py +0 -154
  42. cuqi/experimental/mcmc/_mh.py +0 -80
  43. cuqi/experimental/mcmc/_pcn.py +0 -89
  44. cuqi/experimental/mcmc/_rto.py +0 -350
  45. cuqi/experimental/mcmc/_sampler.py +0 -582
  46. {cuqipy-1.4.0.post0.dev13.dist-info → cuqipy-1.4.0.post0.dev45.dist-info}/WHEEL +0 -0
  47. {cuqipy-1.4.0.post0.dev13.dist-info → cuqipy-1.4.0.post0.dev45.dist-info}/licenses/LICENSE +0 -0
  48. {cuqipy-1.4.0.post0.dev13.dist-info → cuqipy-1.4.0.post0.dev45.dist-info}/top_level.txt +0 -0
cuqi/sampler/_sampler.py CHANGED
@@ -1,177 +1,582 @@
1
1
  from abc import ABC, abstractmethod
2
- import sys
2
+ import os
3
3
  import numpy as np
4
+ import pickle as pkl
5
+ import warnings
4
6
  import cuqi
5
7
  from cuqi.samples import Samples
6
8
 
9
+ try:
10
+ from tqdm import tqdm
11
+ except ImportError:
12
+ def tqdm(iterable, **kwargs):
13
+ warnings.warn("Module mcmc: tqdm not found. Install tqdm to get sampling progress.")
14
+ return iterable
15
+
7
16
  class Sampler(ABC):
17
+ """ Abstract base class for all samplers.
18
+
19
+ Provides a common interface for all samplers. The interface includes methods for sampling, warmup and getting the samples in an object-oriented way.
8
20
 
9
- def __init__(self, target, x0=None, dim=None, callback=None):
21
+ Samples are stored in a list to allow for dynamic growth of the sample set. Returning samples is done by creating a new Samples object from the list of samples.
10
22
 
11
- self._dim = dim
12
- if hasattr(target,'dim'):
13
- if self._dim is None:
14
- self._dim = target.dim
15
- elif self._dim != target.dim:
16
- raise ValueError("'dim' need to be None or equal to 'target.dim'")
17
- elif x0 is not None:
18
- self._dim = len(x0)
23
+ The sampler maintains sets of state and history keys, which are used for features like checkpointing and resuming sampling.
19
24
 
20
- self.target = target
25
+ The state of the sampler represents all variables that are updated (replaced) in a Markov chain Monte Carlo step, e.g. the current point of the sampler.
21
26
 
22
- if x0 is None:
23
- x0 = np.ones(self.dim)
24
- self.x0 = x0
27
+ The history of the sampler represents all variables that are updated (appended) in a Markov chain Monte Carlo step, e.g. the samples and acceptance rates.
25
28
 
26
- self.callback = callback
29
+ Subclasses should ensure that any new variables that are updated in a Markov chain Monte Carlo step are added to the state or history keys.
27
30
 
28
- def step(self, x):
29
- """
30
- Perform a single MCMC step
31
- """
32
- # Currently a hack to get step method for any sampler
33
- self.x0 = x
34
- return self.sample(2).samples[:,-1]
31
+ Saving and loading checkpoints saves and loads the state of the sampler (not the history).
35
32
 
36
- def step_tune(self, x, *args, **kwargs):
37
- """
38
- Perform a single MCMC step and tune the sampler. This is used during burn-in.
39
- """
40
- # Currently a hack to get step method for any sampler
41
- out = self.step(x)
42
- self.tune(*args, *kwargs)
43
- return out
33
+ Batching samples via the batch_size parameter saves the sampler history to disk in batches of the specified size.
34
+
35
+ Any other attribute stored as part of the sampler (e.g. target, initial_point) is not supposed to be updated
36
+ during sampling and should not be part of the state or history.
37
+
38
+ """
44
39
 
45
- def tune(self):
40
+ _STATE_KEYS = {'current_point'}
41
+ """ Set of keys for the state dictionary. """
42
+
43
+ _HISTORY_KEYS = {'_samples', '_acc'}
44
+ """ Set of keys for the history dictionary. """
45
+
46
+ def __init__(self, target:cuqi.density.Density=None, initial_point=None, callback=None):
47
+ """ Initializer for abstract base class for all samplers.
48
+
49
+ Any subclassing samplers should simply store input parameters as part of the __init__ method.
50
+
51
+ The actual initialization of the sampler should be done in the _initialize method.
52
+
53
+ Parameters
54
+ ----------
55
+ target : cuqi.density.Density
56
+ The target density.
57
+
58
+ initial_point : array-like, optional
59
+ The initial point for the sampler. If not given, the sampler will choose an initial point.
60
+
61
+ callback : callable, optional
62
+ A function that will be called after each sampling step. It can be useful for monitoring the sampler during sampling.
63
+ The function should take three arguments: the sampler object, the index of the current sampling step, the total number of requested samples. The last two arguments are integers. An example of the callback function signature is: `callback(sampler, sample_index, num_of_samples)`.
46
64
  """
47
- Tune the sampler parameters.
65
+
66
+ self.target = target
67
+ self.initial_point = initial_point
68
+ self.callback = callback
69
+ self._is_initialized = False
70
+
71
+ def initialize(self):
72
+ """ Initialize the sampler by setting and allocating the state and history before sampling starts. """
73
+
74
+ if self._is_initialized:
75
+ raise ValueError("Sampler is already initialized.")
76
+
77
+ if self.target is None:
78
+ raise ValueError("Cannot initialize sampler without a target density.")
79
+
80
+ # Default values
81
+ if self.initial_point is None:
82
+ self.initial_point = self._get_default_initial_point(self.dim)
83
+
84
+ # State variables
85
+ self.current_point = self.initial_point
86
+
87
+ # History variables
88
+ self._samples = []
89
+ self._acc = [ 1 ] # TODO. Check if we need to put 1 here.
90
+
91
+ self._initialize() # Subclass specific initialization
92
+
93
+ self._validate_initialization()
94
+
95
+ self._is_initialized = True
96
+
97
+ # ------------ Abstract methods to be implemented by subclasses ------------
98
+ @abstractmethod
99
+ def step(self):
100
+ """ Perform one step of the sampler by transitioning the current point to a new point according to the sampler's transition kernel. """
101
+ pass
102
+
103
+ @abstractmethod
104
+ def tune(self, skip_len, update_count):
105
+ """ Tune the parameters of the sampler. During the warmup phase this method is called once every skip_len steps (the tuning interval), not after every step.
106
+
107
+ Parameters
108
+ ----------
109
+ skip_len : int
110
+ Defines the number of steps in between tuning (i.e. the tuning interval).
111
+
112
+ update_count : int
113
+ The number of times tuning has been performed. Can be used for internal bookkeeping.
114
+
48
115
  """
49
116
  pass
50
117
 
118
+ @abstractmethod
119
+ def validate_target(self):
120
+ """ Validate the target is compatible with the sampler. Called when the target is set. Should raise an error if the target is not compatible. """
121
+ pass
51
122
 
52
- @property
53
- def geometry(self):
54
- if hasattr(self, 'target') and hasattr(self.target, 'geometry'):
55
- geom = self.target.geometry
56
- else:
57
- geom = cuqi.geometry._DefaultGeometry1D(self.dim)
58
- return geom
123
+ @abstractmethod
124
+ def _initialize(self):
125
+ """ Subclass specific sampler initialization. Called during the initialization of the sampler which is done before sampling starts. """
126
+ pass
59
127
 
60
- @property
61
- def target(self):
62
- return self._target
128
+ # ------------ Public attributes ------------
129
+ @property
130
+ def dim(self) -> int:
131
+ """ Dimension of the target density. """
132
+ return self.target.dim
63
133
 
64
- @target.setter
134
+ @property
135
+ def geometry(self) -> cuqi.geometry.Geometry:
136
+ """ Geometry of the target density. """
137
+ return self.target.geometry
138
+
139
+ @property
140
+ def target(self) -> cuqi.density.Density:
141
+ """ Return the target density. """
142
+ return self._target
143
+
144
+ @target.setter
65
145
  def target(self, value):
66
- if not isinstance(value, cuqi.distribution.Distribution) and callable(value):
67
- # obtain self.dim
68
- if self.dim is not None:
69
- dim = self.dim
70
- else:
71
- raise ValueError(f"If 'target' is a lambda function, the parameter 'dim' need to be specified when initializing {self.__class__}.")
146
+ """ Set the target density. Runs validation of the target. """
147
+ self._target = value
148
+ if self._target is not None:
149
+ self.validate_target()
72
150
 
73
- # set target
74
- self._target = cuqi.distribution.UserDefinedDistribution(logpdf_func=value, dim = dim)
151
+ @property
152
+ def current_point(self):
153
+ """ The current point of the sampler. """
154
+ return self._current_point
75
155
 
76
- elif isinstance(value, cuqi.distribution.Distribution):
77
- self._target = value
78
- else:
79
- raise ValueError("'target' need to be either a lambda function or of type 'cuqi.distribution.Distribution'")
156
+ @current_point.setter
157
+ def current_point(self, value):
158
+ """ Set the current point of the sampler. """
159
+ self._current_point = value
80
160
 
161
+ # ------------ Public methods ------------
162
+ def get_samples(self) -> Samples:
163
+ """ Return the samples. The internal data-structure for the samples is a dynamic list so this creates a copy. """
164
+ return Samples(np.array(self._samples).T, self.target.geometry)
165
+
166
+ def reinitialize(self):
167
+ """ Re-initialize the sampler. This clears the state and history and initializes the sampler again by setting state and history to their original values. """
81
168
 
82
- @property
83
- def dim(self):
84
- if hasattr(self,'target') and hasattr(self.target,'dim'):
85
- self._dim = self.target.dim
86
- return self._dim
169
+ # Loop over state and reset to None
170
+ for key in self._STATE_KEYS:
171
+ setattr(self, key, None)
172
+
173
+ # Loop over history and reset to None
174
+ for key in self._HISTORY_KEYS:
175
+ setattr(self, key, None)
176
+
177
+ self._is_initialized = False
178
+
179
+ self.initialize()
87
180
 
181
+ def save_checkpoint(self, path):
182
+ """ Save the state of the sampler to a file. """
88
183
 
89
- def sample(self,N,Nb=0):
90
- # Get samples from the samplers sample method
91
- result = self._sample(N,Nb)
92
- return self._create_Sample_object(result,N+Nb)
93
-
94
- def sample_adapt(self,N,Nb=0):
95
- # Get samples from the samplers sample method
96
- result = self._sample_adapt(N,Nb)
97
- return self._create_Sample_object(result,N+Nb)
98
-
99
- def _create_Sample_object(self,result,N):
100
- loglike_eval = None
101
- acc_rate = None
102
- if isinstance(result,tuple):
103
- #Unpack samples+loglike+acc_rate
104
- s = result[0]
105
- if len(result)>1: loglike_eval = result[1]
106
- if len(result)>2: acc_rate = result[2]
107
- if len(result)>3: raise TypeError("Expected tuple of at most 3 elements from sampling method.")
108
- else:
109
- s = result
184
+ self._ensure_initialized()
185
+
186
+ state = self.get_state()
187
+
188
+ # Convert all CUQIarrays to numpy arrays since CUQIarrays do not get pickled correctly
189
+ for key, value in state['state'].items():
190
+ if isinstance(value, cuqi.array.CUQIarray):
191
+ state['state'][key] = value.to_numpy()
192
+
193
+ with open(path, 'wb') as handle:
194
+ pkl.dump(state, handle, protocol=pkl.HIGHEST_PROTOCOL)
195
+
196
+ def load_checkpoint(self, path):
197
+ """ Load the state of the sampler from a file. """
198
+
199
+ self._ensure_initialized()
200
+
201
+ with open(path, 'rb') as handle:
202
+ state = pkl.load(handle)
203
+
204
+ self.set_state(state)
205
+
206
+ def sample(self, Ns, Nt=1, batch_size=0, sample_path='./CUQI_samples/') -> 'Sampler':
207
+ """ Sample Ns samples from the target density.
208
+
209
+ Parameters
210
+ ----------
211
+ Ns : int
212
+ The number of samples to draw.
213
+
214
+ Nt : int, optional, default=1
215
+ The thinning interval. If Nt >= 1, every Nt'th sample is stored. The larger Nt, the fewer samples are stored.
216
+
217
+ batch_size : int, optional
218
+ The batch size for saving samples to disk. If 0, no batching is used. If positive, samples are saved to disk in batches of the specified size.
219
+
220
+ sample_path : str, optional
221
+ The path to save the samples. If not specified, the samples are saved to the current working directory under a folder called 'CUQI_samples'.
222
+
223
+ """
224
+ self._ensure_initialized()
225
+
226
+ # Initialize batch handler
227
+ if batch_size > 0:
228
+ batch_handler = _BatchHandler(batch_size, sample_path)
229
+
230
+ # Draw samples
231
+ pbar = tqdm(range(Ns), "Sample: ")
232
+ for idx in pbar:
233
+
234
+ # Perform one step of the sampler
235
+ acc = self.step()
236
+
237
+ # Store samples
238
+ self._acc.append(acc)
239
+ if (Nt > 0) and ((idx + 1) % Nt == 0):
240
+ self._samples.append(self.current_point)
241
+
242
+ # display acc rate at progress bar
243
+ pbar.set_postfix_str(f"acc rate: {np.mean(self._acc[-1-idx:]):.2%}")
244
+
245
+ # Add sample to batch
246
+ if batch_size > 0:
247
+ batch_handler.add_sample(self.current_point)
248
+
249
+ # Call callback function if specified
250
+ self._call_callback(idx, Ns)
110
251
 
111
- #Store samples in cuqi samples object if more than 1 sample
112
- if N==1:
113
- if len(s) == 1 and isinstance(s,np.ndarray): #Extract single value from numpy array
114
- s = s.ravel()[0]
252
+ return self
253
+
254
+
255
+ def warmup(self, Nb, Nt=1, tune_freq=0.1) -> 'Sampler':
256
+ """ Warmup the sampler by drawing Nb samples.
257
+
258
+ Parameters
259
+ ----------
260
+ Nb : int
261
+ The number of samples to draw during warmup.
262
+
263
+ Nt : int, optional, default=1
264
+ The thinning interval. If Nt >= 1, every Nt'th sample is stored. The larger Nt, the fewer samples are stored.
265
+
266
+ tune_freq : float, optional
267
+ The frequency of tuning. Tuning is performed every tune_freq*Nb samples.
268
+
269
+ """
270
+
271
+ self._ensure_initialized()
272
+
273
+ tune_interval = max(int(tune_freq * Nb), 1)
274
+
275
+ # Draw warmup samples with tuning
276
+ pbar = tqdm(range(Nb), "Warmup: ")
277
+ for idx in pbar:
278
+
279
+ # Perform one step of the sampler
280
+ acc = self.step()
281
+
282
+ # Tune the sampler at tuning intervals
283
+ if (idx + 1) % tune_interval == 0:
284
+ self.tune(tune_interval, idx // tune_interval)
285
+
286
+ # Store samples
287
+ self._acc.append(acc)
288
+ if (Nt > 0) and ((idx + 1) % Nt == 0):
289
+ self._samples.append(self.current_point)
290
+
291
+ # display acc rate at progress bar
292
+ pbar.set_postfix_str(f"acc rate: {np.mean(self._acc[-1-idx:]):.2%}")
293
+
294
+ # Call callback function if specified
295
+ self._call_callback(idx, Nb)
296
+
297
+ return self
298
+
299
+ def get_state(self) -> dict:
300
+ """ Return the state of the sampler.
301
+
302
+ The state is used when checkpointing the sampler.
303
+
304
+ The state of the sampler is a dictionary with keys 'metadata' and 'state'.
305
+ The 'metadata' key contains information about the sampler type.
306
+ The 'state' key contains the state of the sampler.
307
+
308
+ For example, the state of a "MH" sampler could be:
309
+
310
+ state = {
311
+ 'metadata': {
312
+ 'sampler_type': 'MH'
313
+ },
314
+ 'state': {
315
+ 'current_point': np.array([...]),
316
+ 'current_target_logd': -123.45,
317
+ 'scale': 1.0,
318
+ ...
319
+ }
320
+ }
321
+ """
322
+ state = {
323
+ 'metadata': {
324
+ 'sampler_type': self.__class__.__name__
325
+ },
326
+ 'state': {
327
+ key: getattr(self, key) for key in self._STATE_KEYS
328
+ }
329
+ }
330
+ return state
331
+
332
+ def set_state(self, state: dict):
333
+ """ Set the state of the sampler.
334
+
335
+ The state is used when loading the sampler from a checkpoint.
336
+
337
+ The state of the sampler is a dictionary with keys 'metadata' and 'state'.
338
+
339
+ For example, the state of a "MH" sampler could be:
340
+
341
+ state = {
342
+ 'metadata': {
343
+ 'sampler_type': 'MH'
344
+ },
345
+ 'state': {
346
+ 'current_point': np.array([...]),
347
+ 'current_target_logd': -123.45,
348
+ 'scale': 1.0,
349
+ ...
350
+ }
351
+ }
352
+ """
353
+ if state['metadata']['sampler_type'] != self.__class__.__name__:
354
+ raise ValueError(f"Sampler type in state dictionary ({state['metadata']['sampler_type']}) does not match the type of the sampler ({self.__class__.__name__}).")
355
+
356
+ for key, value in state['state'].items():
357
+ if key in self._STATE_KEYS:
358
+ setattr(self, key, value)
115
359
  else:
116
- s = s.flatten()
360
+ raise ValueError(f"Key {key} not recognized in state dictionary of sampler {self.__class__.__name__}.")
361
+
362
+ def get_history(self) -> dict:
363
+ """ Return the history of the sampler. """
364
+ history = {
365
+ 'metadata': {
366
+ 'sampler_type': self.__class__.__name__
367
+ },
368
+ 'history': {
369
+ key: getattr(self, key) for key in self._HISTORY_KEYS
370
+ }
371
+ }
372
+ return history
373
+
374
+ def set_history(self, history: dict):
375
+ """ Set the history of the sampler. """
376
+ if history['metadata']['sampler_type'] != self.__class__.__name__:
377
+ raise ValueError(f"Sampler type in history dictionary ({history['metadata']['sampler_type']}) does not match the type of the sampler ({self.__class__.__name__}).")
378
+
379
+ for key, value in history['history'].items():
380
+ if key in self._HISTORY_KEYS:
381
+ setattr(self, key, value)
382
+ else:
383
+ raise ValueError(f"Key {key} not recognized in history dictionary of sampler {self.__class__.__name__}.")
384
+
385
+ # ------------ Private methods ------------
386
+ def _call_callback(self, sample_index, num_of_samples):
387
+ """ Calls the callback function. Assumes input is sampler, sample index, and total number of samples """
388
+ if self.callback is not None:
389
+ self.callback(self, sample_index, num_of_samples)
390
+
391
+ def _validate_initialization(self):
392
+ """ Validate the initialization of the sampler by checking all state and history keys are set. """
393
+
394
+ for key in self._STATE_KEYS:
395
+ if getattr(self, key) is None:
396
+ raise ValueError(f"Sampler state key {key} is not set after initialization.")
397
+
398
+ for key in self._HISTORY_KEYS:
399
+ if getattr(self, key) is None:
400
+ raise ValueError(f"Sampler history key {key} is not set after initialization.")
401
+
402
+ def _ensure_initialized(self):
403
+ """ Ensure the sampler is initialized. If it is not, initialize it. """
404
+ if not self._is_initialized:
405
+ self.initialize()
406
+
407
+ def _get_default_initial_point(self, dim):
408
+ """ Return the default initial point for the sampler. Defaults to an array of ones. """
409
+ return np.ones(dim)
410
+
411
+ def __repr__(self):
412
+ """ Return a string representation of the sampler. """
413
+ if self.target is None:
414
+ return f"Sampler: {self.__class__.__name__} \n Target: None"
117
415
  else:
118
- s = Samples(s, self.geometry)#, geometry = self.geometry)
119
- s.loglike_eval = loglike_eval
120
- s.acc_rate = acc_rate
121
- return s
416
+ msg = f"Sampler: {self.__class__.__name__} \n Target: \n \t {self.target} "
417
+
418
+ if self._is_initialized:
419
+ state = self.get_state()
420
+ msg += f"\n Current state: \n"
421
+ # Sort keys alphabetically
422
+ keys = sorted(state['state'].keys())
423
+ # Put _ in the end
424
+ keys = [key for key in keys if key[0] != '_'] + [key for key in keys if key[0] == '_']
425
+ for key in keys:
426
+ value = state['state'][key]
427
+ msg += f"\t {key}: {value} \n"
428
+ return msg
122
429
 
123
- @abstractmethod
124
- def _sample(self,N,Nb):
125
- pass
430
+ class ProposalBasedSampler(Sampler, ABC):
431
+ """ Abstract base class for samplers that use a proposal distribution. """
126
432
 
127
- @abstractmethod
128
- def _sample_adapt(self,N,Nb):
129
- pass
433
+ _STATE_KEYS = Sampler._STATE_KEYS.union({'current_target_logd', 'scale'})
130
434
 
131
- def _print_progress(self,s,Ns):
132
- """Prints sampling progress"""
133
- if Ns > 2:
134
- if (s % (max(Ns//100,1))) == 0:
135
- msg = f'Sample {s} / {Ns}'
136
- sys.stdout.write('\r'+msg)
137
- if s==Ns:
138
- msg = f'Sample {s} / {Ns}'
139
- sys.stdout.write('\r'+msg+'\n')
140
-
141
- def _call_callback(self, sample, sample_index):
142
- """ Calls the callback function. Assumes input is sample and sample index"""
143
- if self.callback is not None:
144
- self.callback(sample, sample_index)
435
+ def __init__(self, target=None, proposal=None, scale=1, **kwargs):
436
+ """ Initializer for abstract base class for samplers that use a proposal distribution.
437
+
438
+ Any subclassing samplers should simply store input parameters as part of the __init__ method.
145
439
 
146
- class ProposalBasedSampler(Sampler,ABC):
147
- def __init__(self, target, proposal=None, scale=1, x0=None, dim=None, **kwargs):
148
- #TODO: after fixing None dim
149
- #if dim is None and hasattr(proposal,'dim'):
150
- # dim = proposal.dim
151
- super().__init__(target, x0=x0, dim=dim, **kwargs)
440
+ Initialization of the sampler should be done in the _initialize method.
152
441
 
153
- self.proposal =proposal
154
- self.scale = scale
442
+ See :class:`Sampler` for additional details.
155
443
 
444
+ Parameters
445
+ ----------
446
+ target : cuqi.density.Density
447
+ The target density.
156
448
 
157
- @property
449
+ proposal : cuqi.distribution.Distribution, optional
450
+ The proposal distribution. If not specified, the default proposal is used.
451
+
452
+ scale : float, optional
453
+ The scale parameter for the proposal distribution.
454
+
455
+ **kwargs : dict
456
+ Additional keyword arguments passed to the :class:`Sampler` initializer.
457
+
458
+ """
459
+
460
+ super().__init__(target, **kwargs)
461
+ self.proposal = proposal
462
+ self.initial_scale = scale
463
+
464
+ def initialize(self):
465
+ """ Initialize the sampler by setting and allocating the state and history before sampling starts. """
466
+
467
+ if self._is_initialized:
468
+ raise ValueError("Sampler is already initialized.")
469
+
470
+ if self.target is None:
471
+ raise ValueError("Cannot initialize sampler without a target density.")
472
+
473
+ # Default values
474
+ if self.initial_point is None:
475
+ self.initial_point = self._get_default_initial_point(self.dim)
476
+
477
+ if self.proposal is None:
478
+ self.proposal = self._default_proposal
479
+
480
+ # State variables
481
+ self.current_point = self.initial_point
482
+ self.scale = self.initial_scale
483
+
484
+ self.current_target_logd = self.target.logd(self.current_point)
485
+
486
+ # History variables
487
+ self._samples = []
488
+ self._acc = [ 1 ] # TODO. Check if we need to put 1 here.
489
+
490
+ self._initialize() # Subclass specific initialization
491
+
492
+ self._validate_initialization()
493
+
494
+ self._is_initialized = True
495
+
496
+ @abstractmethod
497
+ def validate_proposal(self):
498
+ """ Validate the proposal distribution. """
499
+ pass
500
+
501
+ @property
502
+ def _default_proposal(self):
503
+ """ Return the default proposal distribution. Defaults to a Gaussian distribution with zero mean and unit variance. """
504
+ return cuqi.distribution.Gaussian(np.zeros(self.dim), 1)
505
+
506
+ @property
158
507
  def proposal(self):
159
- return self._proposal
508
+ """ The proposal distribution. """
509
+ return self._proposal
510
+
511
+ @proposal.setter
512
+ def proposal(self, proposal):
513
+ """ Set the proposal distribution. """
514
+ self._proposal = proposal
515
+ if self._proposal is not None:
516
+ self.validate_proposal()
517
+
160
518
 
161
- @proposal.setter
162
- def proposal(self, value):
163
- self._proposal = value
519
+ class _BatchHandler:
520
+ """ Utility class to handle batching of samples.
521
+
522
+ If a batch size is specified, this class will save samples to disk in batches of the specified size.
523
+
524
+ This is useful for very large sample sets that do not fit in memory.
525
+
526
+ """
527
+
528
+ def __init__(self, batch_size=0, sample_path='./CUQI_samples/'):
529
+
530
+ if batch_size < 0:
531
+ raise ValueError("Batch size should be a non-negative integer")
532
+
533
+ self.sample_path = sample_path
534
+ self._batch_size = batch_size
535
+ self.current_batch = []
536
+ self.num_batches_dumped = 0
164
537
 
165
538
  @property
166
- def geometry(self):
167
- geom1, geom2 = None, None
168
- if hasattr(self, 'proposal') and hasattr(self.proposal, 'geometry') and self.proposal.geometry.par_dim is not None:
169
- geom1= self.proposal.geometry
170
- if hasattr(self, 'target') and hasattr(self.target, 'geometry') and self.target.geometry.par_dim is not None:
171
- geom2 = self.target.geometry
172
- if not isinstance(geom1,cuqi.geometry._DefaultGeometry) and geom1 is not None:
173
- return geom1
174
- elif not isinstance(geom2,cuqi.geometry._DefaultGeometry) and geom2 is not None:
175
- return geom2
176
- else:
177
- return cuqi.geometry._DefaultGeometry1D(self.dim)
539
+ def sample_path(self):
540
+ """ The path to save the samples. """
541
+ return self._sample_path
542
+
543
+ @sample_path.setter
544
+ def sample_path(self, value):
545
+ if not isinstance(value, str):
546
+ raise TypeError("Sample path must be a string.")
547
+ normalized_path = value.rstrip('/') + '/'
548
+ if not os.path.isdir(normalized_path):
549
+ try:
550
+ os.makedirs(normalized_path, exist_ok=True)
551
+ except Exception as e:
552
+ raise ValueError(f"Could not create directory at {normalized_path}: {e}")
553
+ self._sample_path = normalized_path
554
+
555
+ def add_sample(self, sample):
556
+ """ Add a sample to the batch if batching. If the batch is full, flush the batch to disk. """
557
+
558
+ if self._batch_size <= 0:
559
+ return # Batching not used
560
+
561
+ self.current_batch.append(sample)
562
+
563
+ if len(self.current_batch) >= self._batch_size:
564
+ self.flush()
565
+
566
+ def flush(self):
567
+ """ Flush the current batch of samples to disk. """
568
+
569
+ if not self.current_batch:
570
+ return # No samples to flush
571
+
572
+ # Save the current batch of samples
573
+ batch_samples = np.array(self.current_batch)
574
+ file_path = f'{self.sample_path}batch_{self.num_batches_dumped:04d}.npz'
575
+ np.savez(file_path, samples=batch_samples, batch_id=self.num_batches_dumped)
576
+
577
+ self.num_batches_dumped += 1
578
+ self.current_batch = [] # Clear the batch after saving
579
+
580
+ def finalize(self):
581
+ """ Finalize the batch handler. Flush any remaining samples to disk. """
582
+ self.flush()