CUQIpy 1.0.0.post0.dev229__py3-none-any.whl → 1.0.0.post0.dev337__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of CUQIpy might be problematic. Click here for more details.

@@ -20,15 +20,35 @@ class SamplerNew(ABC):
20
20
 
21
21
  Samples are stored in a list to allow for dynamic growth of the sample set. Returning samples is done by creating a new Samples object from the list of samples.
22
22
 
23
+ The sampler maintains sets of state and history keys, which are used for features like checkpointing and resuming sampling.
24
+
25
+ The state of the sampler represents all variables that are updated (replaced) in a Markov chain Monte Carlo step, e.g. the current point of the sampler.
26
+
27
+ The history of the sampler represents all variables that are updated (appended) in a Markov chain Monte Carlo step, e.g. the samples and acceptance rates.
28
+
29
+ Subclasses should ensure that any new variables that are updated in a Markov chain Monte Carlo step are added to the state or history keys.
30
+
31
+ Saving and loading checkpoints saves and loads the state of the sampler (not the history).
32
+
33
+ Batching samples via the batch_size parameter saves the sampler history to disk in batches of the specified size.
34
+
35
+ Any other attribute stored as part of the sampler (e.g. target, initial_point) is not supposed to be updated
36
+ during sampling and should not be part of the state or history.
37
+
23
38
  """
39
+
24
40
  _STATE_KEYS = {'current_point'}
25
41
  """ Set of keys for the state dictionary. """
26
42
 
27
43
  _HISTORY_KEYS = {'_samples', '_acc'}
28
44
  """ Set of keys for the history dictionary. """
29
45
 
30
- def __init__(self, target: cuqi.density.Density, initial_point=None, callback=None):
46
+ def __init__(self, target:cuqi.density.Density=None, initial_point=None, callback=None):
31
47
  """ Initializer for abstract base class for all samplers.
48
+
49
+ Any subclassing samplers should simply store input parameters as part of the __init__ method.
50
+
51
+ The actual initialization of the sampler should be done in the _initialize method.
32
52
 
33
53
  Parameters
34
54
  ----------
@@ -45,26 +65,55 @@ class SamplerNew(ABC):
45
65
  """
46
66
 
47
67
  self.target = target
68
+ self.initial_point = initial_point
48
69
  self.callback = callback
70
+ self._is_initialized = False
49
71
 
50
- # Choose initial point if not given
51
- if initial_point is None:
52
- initial_point = np.ones(self.dim)
72
+ def initialize(self):
73
+ """ Initialize the sampler by setting and allocating the state and history before sampling starts. """
53
74
 
54
- self.initial_point = initial_point
75
+ if self._is_initialized:
76
+ raise ValueError("Sampler is already initialized.")
55
77
 
56
- self._samples = [initial_point] # Remove. See #324.
78
+ if self.target is None:
79
+ raise ValueError("Cannot initialize sampler without a target density.")
80
+
81
+ # Default values
82
+ if self.initial_point is None:
83
+ self.initial_point = self._default_initial_point
84
+
85
+ # State variables
86
+ self.current_point = self.initial_point
87
+
88
+ # History variables
89
+ self._samples = []
90
+ self._acc = [ 1 ] # TODO. Check if we need to put 1 here.
91
+
92
+ self._initialize() # Subclass specific initialization
93
+
94
+ self._validate_initialization()
95
+
96
+ self._is_initialized = True
57
97
 
58
98
  # ------------ Abstract methods to be implemented by subclasses ------------
59
-
60
99
  @abstractmethod
61
100
  def step(self):
62
101
  """ Perform one step of the sampler by transitioning the current point to a new point according to the sampler's transition kernel. """
63
102
  pass
64
103
 
65
104
  @abstractmethod
66
- def tune(self):
67
- """ Tune the parameters of the sampler. This method is called after each step of the warmup phase. """
105
+ def tune(self, skip_len, update_count):
106
+ """ Tune the parameters of the sampler. This method is called after each step of the warmup phase.
107
+
108
+ Parameters
109
+ ----------
110
+ skip_len : int
111
+ Defines the number of steps in between tuning (i.e. the tuning interval).
112
+
113
+ update_count : int
114
+ The number of times tuning has been performed. Can be used for internal bookkeeping.
115
+
116
+ """
68
117
  pass
69
118
 
70
119
  @abstractmethod
@@ -72,23 +121,19 @@ class SamplerNew(ABC):
72
121
  """ Validate the target is compatible with the sampler. Called when the target is set. Should raise an error if the target is not compatible. """
73
122
  pass
74
123
 
75
- # -- _pre_sample and _pre_warmup methods: can be overridden by subclasses --
76
- def _pre_sample(self):
77
- """ Any code that needs to be run before sampling. """
78
- pass
79
-
80
- def _pre_warmup(self):
81
- """ Any code that needs to be run before warmup. """
124
+ @abstractmethod
125
+ def _initialize(self):
126
+ """ Subclass specific sampler initialization. Called during the initialization of the sampler which is done before sampling starts. """
82
127
  pass
83
128
 
84
129
  # ------------ Public attributes ------------
85
130
  @property
86
- def dim(self):
131
+ def dim(self) -> int:
87
132
  """ Dimension of the target density. """
88
133
  return self.target.dim
89
134
 
90
135
  @property
91
- def geometry(self):
136
+ def geometry(self) -> cuqi.geometry.Geometry:
92
137
  """ Geometry of the target density. """
93
138
  return self.target.geometry
94
139
 
@@ -101,39 +146,49 @@ class SamplerNew(ABC):
101
146
  def target(self, value):
102
147
  """ Set the target density. Runs validation of the target. """
103
148
  self._target = value
104
- self.validate_target()
105
-
106
- @property
107
- def current_point(self):
108
- """ The current point of the sampler. """
109
- return self._current_point
110
-
111
- @current_point.setter
112
- def current_point(self, value):
113
- """ Set the current point of the sampler. """
114
- self._current_point = value
149
+ if self._target is not None:
150
+ self.validate_target()
115
151
 
116
152
  # ------------ Public methods ------------
117
-
118
153
  def get_samples(self) -> Samples:
119
154
  """ Return the samples. The internal data-structure for the samples is a dynamic list so this creates a copy. """
120
155
  return Samples(np.array(self._samples).T, self.target.geometry)
121
156
 
122
- def reset(self): # TODO. Issue here. Current point is not reset, and initial point is lost with this reset.
123
- self._samples.clear()
124
- self._acc.clear()
157
+ def reinitialize(self):
158
+ """ Re-initialize the sampler. This clears the state and history and initializes the sampler again by setting state and history to their original values. """
159
+
160
+ # Loop over state and reset to None
161
+ for key in self._STATE_KEYS:
162
+ setattr(self, key, None)
163
+
164
+ # Loop over history and reset to None
165
+ for key in self._HISTORY_KEYS:
166
+ setattr(self, key, None)
167
+
168
+ self._is_initialized = False
169
+
170
+ self.initialize()
125
171
 
126
172
  def save_checkpoint(self, path):
127
173
  """ Save the state of the sampler to a file. """
128
174
 
175
+ self._ensure_initialized()
176
+
129
177
  state = self.get_state()
130
178
 
179
+ # Convert all CUQIarrays to numpy arrays since CUQIarrays do not get pickled correctly
180
+ for key, value in state['state'].items():
181
+ if isinstance(value, cuqi.array.CUQIarray):
182
+ state['state'][key] = value.to_numpy()
183
+
131
184
  with open(path, 'wb') as handle:
132
185
  pkl.dump(state, handle, protocol=pkl.HIGHEST_PROTOCOL)
133
186
 
134
187
  def load_checkpoint(self, path):
135
188
  """ Load the state of the sampler from a file. """
136
189
 
190
+ self._ensure_initialized()
191
+
137
192
  with open(path, 'rb') as handle:
138
193
  state = pkl.load(handle)
139
194
 
@@ -155,12 +210,14 @@ class SamplerNew(ABC):
155
210
 
156
211
  """
157
212
 
213
+ self._ensure_initialized()
214
+
158
215
  # Initialize batch handler
159
216
  if batch_size > 0:
160
217
  batch_handler = _BatchHandler(batch_size, sample_path)
161
218
 
162
219
  # Any code that needs to be run before sampling
163
- self._pre_sample()
220
+ if hasattr(self, "_pre_sample"): self._pre_sample()
164
221
 
165
222
  # Draw samples
166
223
  for _ in progressbar( range(Ns) ):
@@ -195,10 +252,12 @@ class SamplerNew(ABC):
195
252
 
196
253
  """
197
254
 
255
+ self._ensure_initialized()
256
+
198
257
  tune_interval = max(int(tune_freq * Nb), 1)
199
258
 
200
259
  # Any code that needs to be run before warmup
201
- self._pre_warmup()
260
+ if hasattr(self, "_pre_warmup"): self._pre_warmup()
202
261
 
203
262
  # Draw warmup samples with tuning
204
263
  for idx in progressbar(range(Nb)):
@@ -306,20 +365,61 @@ class SamplerNew(ABC):
306
365
  raise ValueError(f"Key {key} not recognized in history dictionary of sampler {self.__class__.__name__}.")
307
366
 
308
367
  # ------------ Private methods ------------
309
-
310
368
  def _call_callback(self, sample, sample_index):
311
369
  """ Calls the callback function. Assumes input is sample and sample index"""
312
370
  if self.callback is not None:
313
371
  self.callback(sample, sample_index)
314
372
 
373
+ def _validate_initialization(self):
374
+ """ Validate the initialization of the sampler by checking all state and history keys are set. """
375
+
376
+ for key in self._STATE_KEYS:
377
+ if getattr(self, key) is None:
378
+ raise ValueError(f"Sampler state key {key} is not set after initialization.")
379
+
380
+ for key in self._HISTORY_KEYS:
381
+ if getattr(self, key) is None:
382
+ raise ValueError(f"Sampler history key {key} is not set after initialization.")
383
+
384
+ def _ensure_initialized(self):
385
+ """ Ensure the sampler is initialized. If not, initialize it. """
386
+ if not self._is_initialized:
387
+ self.initialize()
388
+
389
+ @property
390
+ def _default_initial_point(self):
391
+ """ Return the default initial point for the sampler. Defaults to an array of ones. """
392
+ return np.ones(self.dim)
393
+
394
+ def __repr__(self):
395
+ """ Return a string representation of the sampler. """
396
+ if self.target is None:
397
+ return f"Sampler: {self.__class__.__name__} \n Target: None"
398
+ self._ensure_initialized()
399
+ state = self.get_state()
400
+ msg = f" Sampler: \n\t {self.__class__.__name__} \n Target: \n \t {self.target} \n Current state: \n"
401
+ # Sort keys alphabetically
402
+ keys = sorted(state['state'].keys())
403
+ # Put _ in the end
404
+ keys = [key for key in keys if key[0] != '_'] + [key for key in keys if key[0] == '_']
405
+ for key in keys:
406
+ value = state['state'][key]
407
+ msg += f"\t {key}: {value} \n"
408
+ return msg
315
409
 
316
410
  class ProposalBasedSamplerNew(SamplerNew, ABC):
317
411
  """ Abstract base class for samplers that use a proposal distribution. """
318
412
 
319
413
  _STATE_KEYS = SamplerNew._STATE_KEYS.union({'current_target_logd', 'scale'})
320
414
 
321
- def __init__(self, target, proposal=None, scale=1, **kwargs):
322
- """ Initializer for proposal based samplers.
415
+ def __init__(self, target=None, proposal=None, scale=1, **kwargs):
416
+ """ Initializer for abstract base class for samplers that use a proposal distribution.
417
+
418
+ Any subclassing samplers should simply store input parameters as part of the __init__ method.
419
+
420
+ Initialization of the sampler should be done in the _initialize method.
421
+
422
+ See :class:`SamplerNew` for additional details.
323
423
 
324
424
  Parameters
325
425
  ----------
@@ -338,35 +438,62 @@ class ProposalBasedSamplerNew(SamplerNew, ABC):
338
438
  """
339
439
 
340
440
  super().__init__(target, **kwargs)
441
+ self.proposal = proposal
442
+ self.initial_scale = scale
443
+
444
+ def initialize(self):
445
+ """ Initialize the sampler by setting and allocating the state and history before sampling starts. """
341
446
 
447
+ if self._is_initialized:
448
+ raise ValueError("Sampler is already initialized.")
449
+
450
+ if self.target is None:
451
+ raise ValueError("Cannot initialize sampler without a target density.")
452
+
453
+ # Default values
454
+ if self.initial_point is None:
455
+ self.initial_point = self._default_initial_point
456
+
457
+ if self.proposal is None:
458
+ self.proposal = self._default_proposal
459
+
460
+ # State variables
342
461
  self.current_point = self.initial_point
462
+ self.scale = self.initial_scale
463
+
343
464
  self.current_target_logd = self.target.logd(self.current_point)
344
- self.proposal = proposal
345
- self.scale = scale
346
465
 
347
- self._acc = [ 1 ] # TODO. Check
466
+ # History variables
467
+ self._samples = []
468
+ self._acc = [ 1 ] # TODO. Check if we need to put 1 here.
348
469
 
349
- @property
350
- def proposal(self):
351
- return self._proposal
470
+ self._initialize() # Subclass specific initialization
471
+
472
+ self._validate_initialization()
473
+
474
+ self._is_initialized = True
475
+
476
+ @abstractmethod
477
+ def validate_proposal(self):
478
+ """ Validate the proposal distribution. """
479
+ pass
352
480
 
353
- @proposal.setter
354
- def proposal(self, value):
355
- self._proposal = value
481
+ @property
482
+ def _default_proposal(self):
483
+ """ Return the default proposal distribution. Defaults to a Gaussian distribution with zero mean and unit variance. """
484
+ return cuqi.distribution.Gaussian(np.zeros(self.dim), 1)
356
485
 
357
486
  @property
358
- def geometry(self): # TODO. Check if we can refactor this
359
- geom1, geom2 = None, None
360
- if hasattr(self, 'proposal') and hasattr(self.proposal, 'geometry') and self.proposal.geometry.par_dim is not None:
361
- geom1= self.proposal.geometry
362
- if hasattr(self, 'target') and hasattr(self.target, 'geometry') and self.target.geometry.par_dim is not None:
363
- geom2 = self.target.geometry
364
- if not isinstance(geom1,cuqi.geometry._DefaultGeometry) and geom1 is not None:
365
- return geom1
366
- elif not isinstance(geom2,cuqi.geometry._DefaultGeometry) and geom2 is not None:
367
- return geom2
368
- else:
369
- return cuqi.geometry._DefaultGeometry(self.dim)
487
+ def proposal(self):
488
+ """ The proposal distribution. """
489
+ return self._proposal
490
+
491
+ @proposal.setter
492
+ def proposal(self, proposal):
493
+ """ Set the proposal distribution. """
494
+ self._proposal = proposal
495
+ if self._proposal is not None:
496
+ self.validate_proposal()
370
497
 
371
498
 
372
499
  class _BatchHandler: