CUQIpy 1.0.0.post0.dev229__py3-none-any.whl → 1.0.0.post0.dev337__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of CUQIpy has been flagged as potentially problematic; refer to the original release page for more details.
- {CUQIpy-1.0.0.post0.dev229.dist-info → CUQIpy-1.0.0.post0.dev337.dist-info}/METADATA +1 -1
- {CUQIpy-1.0.0.post0.dev229.dist-info → CUQIpy-1.0.0.post0.dev337.dist-info}/RECORD +19 -15
- cuqi/_version.py +3 -3
- cuqi/experimental/mcmc/__init__.py +4 -0
- cuqi/experimental/mcmc/_conjugate.py +77 -0
- cuqi/experimental/mcmc/_conjugate_approx.py +75 -0
- cuqi/experimental/mcmc/_cwmh.py +29 -42
- cuqi/experimental/mcmc/_direct.py +28 -0
- cuqi/experimental/mcmc/_gibbs.py +267 -0
- cuqi/experimental/mcmc/_hmc.py +34 -34
- cuqi/experimental/mcmc/_langevin_algorithm.py +5 -4
- cuqi/experimental/mcmc/_laplace_approximation.py +11 -13
- cuqi/experimental/mcmc/_mh.py +9 -16
- cuqi/experimental/mcmc/_pcn.py +14 -34
- cuqi/experimental/mcmc/_rto.py +25 -52
- cuqi/experimental/mcmc/_sampler.py +186 -59
- {CUQIpy-1.0.0.post0.dev229.dist-info → CUQIpy-1.0.0.post0.dev337.dist-info}/LICENSE +0 -0
- {CUQIpy-1.0.0.post0.dev229.dist-info → CUQIpy-1.0.0.post0.dev337.dist-info}/WHEEL +0 -0
- {CUQIpy-1.0.0.post0.dev229.dist-info → CUQIpy-1.0.0.post0.dev337.dist-info}/top_level.txt +0 -0
|
@@ -20,15 +20,35 @@ class SamplerNew(ABC):
|
|
|
20
20
|
|
|
21
21
|
Samples are stored in a list to allow for dynamic growth of the sample set. Returning samples is done by creating a new Samples object from the list of samples.
|
|
22
22
|
|
|
23
|
+
The sampler maintains sets of state and history keys, which are used for features like checkpointing and resuming sampling.
|
|
24
|
+
|
|
25
|
+
The state of the sampler represents all variables that are updated (replaced) in a Markov Monte Carlo step, e.g. the current point of the sampler.
|
|
26
|
+
|
|
27
|
+
The history of the sampler represents all variables that are updated (appended) in a Markov Monte Carlo step, e.g. the samples and acceptance rates.
|
|
28
|
+
|
|
29
|
+
Subclasses should ensure that any new variables that are updated in a Markov Monte Carlo step are added to the state or history keys.
|
|
30
|
+
|
|
31
|
+
Saving and loading checkpoints saves and loads the state of the sampler (not the history).
|
|
32
|
+
|
|
33
|
+
Batching samples via the batch_size parameter saves the sampler history to disk in batches of the specified size.
|
|
34
|
+
|
|
35
|
+
Any other attribute stored as part of the sampler (e.g. target, initial_point) is not supposed to be updated
|
|
36
|
+
during sampling and should not be part of the state or history.
|
|
37
|
+
|
|
23
38
|
"""
|
|
39
|
+
|
|
24
40
|
_STATE_KEYS = {'current_point'}
|
|
25
41
|
""" Set of keys for the state dictionary. """
|
|
26
42
|
|
|
27
43
|
_HISTORY_KEYS = {'_samples', '_acc'}
|
|
28
44
|
""" Set of keys for the history dictionary. """
|
|
29
45
|
|
|
30
|
-
def __init__(self, target:
|
|
46
|
+
def __init__(self, target:cuqi.density.Density=None, initial_point=None, callback=None):
|
|
31
47
|
""" Initializer for abstract base class for all samplers.
|
|
48
|
+
|
|
49
|
+
Any subclassing samplers should simply store input parameters as part of the __init__ method.
|
|
50
|
+
|
|
51
|
+
The actual initialization of the sampler should be done in the _initialize method.
|
|
32
52
|
|
|
33
53
|
Parameters
|
|
34
54
|
----------
|
|
@@ -45,26 +65,55 @@ class SamplerNew(ABC):
|
|
|
45
65
|
"""
|
|
46
66
|
|
|
47
67
|
self.target = target
|
|
68
|
+
self.initial_point = initial_point
|
|
48
69
|
self.callback = callback
|
|
70
|
+
self._is_initialized = False
|
|
49
71
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
initial_point = np.ones(self.dim)
|
|
72
|
+
def initialize(self):
|
|
73
|
+
""" Initialize the sampler by setting and allocating the state and history before sampling starts. """
|
|
53
74
|
|
|
54
|
-
self.
|
|
75
|
+
if self._is_initialized:
|
|
76
|
+
raise ValueError("Sampler is already initialized.")
|
|
55
77
|
|
|
56
|
-
self.
|
|
78
|
+
if self.target is None:
|
|
79
|
+
raise ValueError("Cannot initialize sampler without a target density.")
|
|
80
|
+
|
|
81
|
+
# Default values
|
|
82
|
+
if self.initial_point is None:
|
|
83
|
+
self.initial_point = self._default_initial_point
|
|
84
|
+
|
|
85
|
+
# State variables
|
|
86
|
+
self.current_point = self.initial_point
|
|
87
|
+
|
|
88
|
+
# History variables
|
|
89
|
+
self._samples = []
|
|
90
|
+
self._acc = [ 1 ] # TODO. Check if we need to put 1 here.
|
|
91
|
+
|
|
92
|
+
self._initialize() # Subclass specific initialization
|
|
93
|
+
|
|
94
|
+
self._validate_initialization()
|
|
95
|
+
|
|
96
|
+
self._is_initialized = True
|
|
57
97
|
|
|
58
98
|
# ------------ Abstract methods to be implemented by subclasses ------------
|
|
59
|
-
|
|
60
99
|
@abstractmethod
|
|
61
100
|
def step(self):
|
|
62
101
|
""" Perform one step of the sampler by transitioning the current point to a new point according to the sampler's transition kernel. """
|
|
63
102
|
pass
|
|
64
103
|
|
|
65
104
|
@abstractmethod
|
|
66
|
-
def tune(self):
|
|
67
|
-
""" Tune the parameters of the sampler. This method is called after each step of the warmup phase.
|
|
105
|
+
def tune(self, skip_len, update_count):
|
|
106
|
+
""" Tune the parameters of the sampler. This method is called after each step of the warmup phase.
|
|
107
|
+
|
|
108
|
+
Parameters
|
|
109
|
+
----------
|
|
110
|
+
skip_len : int
|
|
111
|
+
Defines the number of steps in between tuning (i.e. the tuning interval).
|
|
112
|
+
|
|
113
|
+
update_count : int
|
|
114
|
+
The number of times tuning has been performed. Can be used for internal bookkeeping.
|
|
115
|
+
|
|
116
|
+
"""
|
|
68
117
|
pass
|
|
69
118
|
|
|
70
119
|
@abstractmethod
|
|
@@ -72,23 +121,19 @@ class SamplerNew(ABC):
|
|
|
72
121
|
""" Validate the target is compatible with the sampler. Called when the target is set. Should raise an error if the target is not compatible. """
|
|
73
122
|
pass
|
|
74
123
|
|
|
75
|
-
|
|
76
|
-
def
|
|
77
|
-
"""
|
|
78
|
-
pass
|
|
79
|
-
|
|
80
|
-
def _pre_warmup(self):
|
|
81
|
-
""" Any code that needs to be run before warmup. """
|
|
124
|
+
@abstractmethod
|
|
125
|
+
def _initialize(self):
|
|
126
|
+
""" Subclass specific sampler initialization. Called during the initialization of the sampler which is done before sampling starts. """
|
|
82
127
|
pass
|
|
83
128
|
|
|
84
129
|
# ------------ Public attributes ------------
|
|
85
130
|
@property
|
|
86
|
-
def dim(self):
|
|
131
|
+
def dim(self) -> int:
|
|
87
132
|
""" Dimension of the target density. """
|
|
88
133
|
return self.target.dim
|
|
89
134
|
|
|
90
135
|
@property
|
|
91
|
-
def geometry(self):
|
|
136
|
+
def geometry(self) -> cuqi.geometry.Geometry:
|
|
92
137
|
""" Geometry of the target density. """
|
|
93
138
|
return self.target.geometry
|
|
94
139
|
|
|
@@ -101,39 +146,49 @@ class SamplerNew(ABC):
|
|
|
101
146
|
def target(self, value):
|
|
102
147
|
""" Set the target density. Runs validation of the target. """
|
|
103
148
|
self._target = value
|
|
104
|
-
self.
|
|
105
|
-
|
|
106
|
-
@property
|
|
107
|
-
def current_point(self):
|
|
108
|
-
""" The current point of the sampler. """
|
|
109
|
-
return self._current_point
|
|
110
|
-
|
|
111
|
-
@current_point.setter
|
|
112
|
-
def current_point(self, value):
|
|
113
|
-
""" Set the current point of the sampler. """
|
|
114
|
-
self._current_point = value
|
|
149
|
+
if self._target is not None:
|
|
150
|
+
self.validate_target()
|
|
115
151
|
|
|
116
152
|
# ------------ Public methods ------------
|
|
117
|
-
|
|
118
153
|
def get_samples(self) -> Samples:
|
|
119
154
|
""" Return the samples. The internal data-structure for the samples is a dynamic list so this creates a copy. """
|
|
120
155
|
return Samples(np.array(self._samples).T, self.target.geometry)
|
|
121
156
|
|
|
122
|
-
def
|
|
123
|
-
|
|
124
|
-
|
|
157
|
+
def reinitialize(self):
|
|
158
|
+
""" Re-initialize the sampler. This clears the state and history and initializes the sampler again by setting state and history to their original values. """
|
|
159
|
+
|
|
160
|
+
# Loop over state and reset to None
|
|
161
|
+
for key in self._STATE_KEYS:
|
|
162
|
+
setattr(self, key, None)
|
|
163
|
+
|
|
164
|
+
# Loop over history and reset to None
|
|
165
|
+
for key in self._HISTORY_KEYS:
|
|
166
|
+
setattr(self, key, None)
|
|
167
|
+
|
|
168
|
+
self._is_initialized = False
|
|
169
|
+
|
|
170
|
+
self.initialize()
|
|
125
171
|
|
|
126
172
|
def save_checkpoint(self, path):
|
|
127
173
|
""" Save the state of the sampler to a file. """
|
|
128
174
|
|
|
175
|
+
self._ensure_initialized()
|
|
176
|
+
|
|
129
177
|
state = self.get_state()
|
|
130
178
|
|
|
179
|
+
# Convert all CUQIarrays to numpy arrays since CUQIarrays do not get pickled correctly
|
|
180
|
+
for key, value in state['state'].items():
|
|
181
|
+
if isinstance(value, cuqi.array.CUQIarray):
|
|
182
|
+
state['state'][key] = value.to_numpy()
|
|
183
|
+
|
|
131
184
|
with open(path, 'wb') as handle:
|
|
132
185
|
pkl.dump(state, handle, protocol=pkl.HIGHEST_PROTOCOL)
|
|
133
186
|
|
|
134
187
|
def load_checkpoint(self, path):
|
|
135
188
|
""" Load the state of the sampler from a file. """
|
|
136
189
|
|
|
190
|
+
self._ensure_initialized()
|
|
191
|
+
|
|
137
192
|
with open(path, 'rb') as handle:
|
|
138
193
|
state = pkl.load(handle)
|
|
139
194
|
|
|
@@ -155,12 +210,14 @@ class SamplerNew(ABC):
|
|
|
155
210
|
|
|
156
211
|
"""
|
|
157
212
|
|
|
213
|
+
self._ensure_initialized()
|
|
214
|
+
|
|
158
215
|
# Initialize batch handler
|
|
159
216
|
if batch_size > 0:
|
|
160
217
|
batch_handler = _BatchHandler(batch_size, sample_path)
|
|
161
218
|
|
|
162
219
|
# Any code that needs to be run before sampling
|
|
163
|
-
self._pre_sample()
|
|
220
|
+
if hasattr(self, "_pre_sample"): self._pre_sample()
|
|
164
221
|
|
|
165
222
|
# Draw samples
|
|
166
223
|
for _ in progressbar( range(Ns) ):
|
|
@@ -195,10 +252,12 @@ class SamplerNew(ABC):
|
|
|
195
252
|
|
|
196
253
|
"""
|
|
197
254
|
|
|
255
|
+
self._ensure_initialized()
|
|
256
|
+
|
|
198
257
|
tune_interval = max(int(tune_freq * Nb), 1)
|
|
199
258
|
|
|
200
259
|
# Any code that needs to be run before warmup
|
|
201
|
-
self._pre_warmup()
|
|
260
|
+
if hasattr(self, "_pre_warmup"): self._pre_warmup()
|
|
202
261
|
|
|
203
262
|
# Draw warmup samples with tuning
|
|
204
263
|
for idx in progressbar(range(Nb)):
|
|
@@ -306,20 +365,61 @@ class SamplerNew(ABC):
|
|
|
306
365
|
raise ValueError(f"Key {key} not recognized in history dictionary of sampler {self.__class__.__name__}.")
|
|
307
366
|
|
|
308
367
|
# ------------ Private methods ------------
|
|
309
|
-
|
|
310
368
|
def _call_callback(self, sample, sample_index):
|
|
311
369
|
""" Calls the callback function. Assumes input is sample and sample index"""
|
|
312
370
|
if self.callback is not None:
|
|
313
371
|
self.callback(sample, sample_index)
|
|
314
372
|
|
|
373
|
+
def _validate_initialization(self):
|
|
374
|
+
""" Validate the initialization of the sampler by checking all state and history keys are set. """
|
|
375
|
+
|
|
376
|
+
for key in self._STATE_KEYS:
|
|
377
|
+
if getattr(self, key) is None:
|
|
378
|
+
raise ValueError(f"Sampler state key {key} is not set after initialization.")
|
|
379
|
+
|
|
380
|
+
for key in self._HISTORY_KEYS:
|
|
381
|
+
if getattr(self, key) is None:
|
|
382
|
+
raise ValueError(f"Sampler history key {key} is not set after initialization.")
|
|
383
|
+
|
|
384
|
+
def _ensure_initialized(self):
|
|
385
|
+
""" Ensure the sampler is initialized. If not initialize it. """
|
|
386
|
+
if not self._is_initialized:
|
|
387
|
+
self.initialize()
|
|
388
|
+
|
|
389
|
+
@property
|
|
390
|
+
def _default_initial_point(self):
|
|
391
|
+
""" Return the default initial point for the sampler. Defaults to an array of ones. """
|
|
392
|
+
return np.ones(self.dim)
|
|
393
|
+
|
|
394
|
+
def __repr__(self):
|
|
395
|
+
""" Return a string representation of the sampler. """
|
|
396
|
+
if self.target is None:
|
|
397
|
+
return f"Sampler: {self.__class__.__name__} \n Target: None"
|
|
398
|
+
self._ensure_initialized()
|
|
399
|
+
state = self.get_state()
|
|
400
|
+
msg = f" Sampler: \n\t {self.__class__.__name__} \n Target: \n \t {self.target} \n Current state: \n"
|
|
401
|
+
# Sort keys alphabetically
|
|
402
|
+
keys = sorted(state['state'].keys())
|
|
403
|
+
# Put _ in the end
|
|
404
|
+
keys = [key for key in keys if key[0] != '_'] + [key for key in keys if key[0] == '_']
|
|
405
|
+
for key in keys:
|
|
406
|
+
value = state['state'][key]
|
|
407
|
+
msg += f"\t {key}: {value} \n"
|
|
408
|
+
return msg
|
|
315
409
|
|
|
316
410
|
class ProposalBasedSamplerNew(SamplerNew, ABC):
|
|
317
411
|
""" Abstract base class for samplers that use a proposal distribution. """
|
|
318
412
|
|
|
319
413
|
_STATE_KEYS = SamplerNew._STATE_KEYS.union({'current_target_logd', 'scale'})
|
|
320
414
|
|
|
321
|
-
def __init__(self, target, proposal=None, scale=1, **kwargs):
|
|
322
|
-
""" Initializer for
|
|
415
|
+
def __init__(self, target=None, proposal=None, scale=1, **kwargs):
|
|
416
|
+
""" Initializer for abstract base class for samplers that use a proposal distribution.
|
|
417
|
+
|
|
418
|
+
Any subclassing samplers should simply store input parameters as part of the __init__ method.
|
|
419
|
+
|
|
420
|
+
Initialization of the sampler should be done in the _initialize method.
|
|
421
|
+
|
|
422
|
+
See :class:`SamplerNew` for additional details.
|
|
323
423
|
|
|
324
424
|
Parameters
|
|
325
425
|
----------
|
|
@@ -338,35 +438,62 @@ class ProposalBasedSamplerNew(SamplerNew, ABC):
|
|
|
338
438
|
"""
|
|
339
439
|
|
|
340
440
|
super().__init__(target, **kwargs)
|
|
441
|
+
self.proposal = proposal
|
|
442
|
+
self.initial_scale = scale
|
|
443
|
+
|
|
444
|
+
def initialize(self):
|
|
445
|
+
""" Initialize the sampler by setting and allocating the state and history before sampling starts. """
|
|
341
446
|
|
|
447
|
+
if self._is_initialized:
|
|
448
|
+
raise ValueError("Sampler is already initialized.")
|
|
449
|
+
|
|
450
|
+
if self.target is None:
|
|
451
|
+
raise ValueError("Cannot initialize sampler without a target density.")
|
|
452
|
+
|
|
453
|
+
# Default values
|
|
454
|
+
if self.initial_point is None:
|
|
455
|
+
self.initial_point = self._default_initial_point
|
|
456
|
+
|
|
457
|
+
if self.proposal is None:
|
|
458
|
+
self.proposal = self._default_proposal
|
|
459
|
+
|
|
460
|
+
# State variables
|
|
342
461
|
self.current_point = self.initial_point
|
|
462
|
+
self.scale = self.initial_scale
|
|
463
|
+
|
|
343
464
|
self.current_target_logd = self.target.logd(self.current_point)
|
|
344
|
-
self.proposal = proposal
|
|
345
|
-
self.scale = scale
|
|
346
465
|
|
|
347
|
-
|
|
466
|
+
# History variables
|
|
467
|
+
self._samples = []
|
|
468
|
+
self._acc = [ 1 ] # TODO. Check if we need to put 1 here.
|
|
348
469
|
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
470
|
+
self._initialize() # Subclass specific initialization
|
|
471
|
+
|
|
472
|
+
self._validate_initialization()
|
|
473
|
+
|
|
474
|
+
self._is_initialized = True
|
|
475
|
+
|
|
476
|
+
@abstractmethod
|
|
477
|
+
def validate_proposal(self):
|
|
478
|
+
""" Validate the proposal distribution. """
|
|
479
|
+
pass
|
|
352
480
|
|
|
353
|
-
@
|
|
354
|
-
def
|
|
355
|
-
|
|
481
|
+
@property
|
|
482
|
+
def _default_proposal(self):
|
|
483
|
+
""" Return the default proposal distribution. Defaults to a Gaussian distribution with zero mean and unit variance. """
|
|
484
|
+
return cuqi.distribution.Gaussian(np.zeros(self.dim), 1)
|
|
356
485
|
|
|
357
486
|
@property
|
|
358
|
-
def
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
else:
|
|
369
|
-
return cuqi.geometry._DefaultGeometry(self.dim)
|
|
487
|
+
def proposal(self):
|
|
488
|
+
""" The proposal distribution. """
|
|
489
|
+
return self._proposal
|
|
490
|
+
|
|
491
|
+
@proposal.setter
|
|
492
|
+
def proposal(self, proposal):
|
|
493
|
+
""" Set the proposal distribution. """
|
|
494
|
+
self._proposal = proposal
|
|
495
|
+
if self._proposal is not None:
|
|
496
|
+
self.validate_proposal()
|
|
370
497
|
|
|
371
498
|
|
|
372
499
|
class _BatchHandler:
|
|
File without changes
|
|
File without changes
|
|
File without changes
|