CUQIpy 1.0.0.post0.dev127__py3-none-any.whl → 1.0.0.post0.dev180__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of CUQIpy might be problematic. Click here for more details.
- {CUQIpy-1.0.0.post0.dev127.dist-info → CUQIpy-1.0.0.post0.dev180.dist-info}/METADATA +1 -1
- {CUQIpy-1.0.0.post0.dev127.dist-info → CUQIpy-1.0.0.post0.dev180.dist-info}/RECORD +15 -14
- cuqi/_version.py +3 -3
- cuqi/experimental/mcmc/__init__.py +1 -0
- cuqi/experimental/mcmc/_cwmh.py +2 -22
- cuqi/experimental/mcmc/_hmc.py +470 -0
- cuqi/experimental/mcmc/_langevin_algorithm.py +12 -41
- cuqi/experimental/mcmc/_mh.py +4 -12
- cuqi/experimental/mcmc/_pcn.py +5 -15
- cuqi/experimental/mcmc/_rto.py +2 -7
- cuqi/experimental/mcmc/_sampler.py +120 -20
- cuqi/sampler/_hmc.py +7 -3
- {CUQIpy-1.0.0.post0.dev127.dist-info → CUQIpy-1.0.0.post0.dev180.dist-info}/LICENSE +0 -0
- {CUQIpy-1.0.0.post0.dev127.dist-info → CUQIpy-1.0.0.post0.dev180.dist-info}/WHEEL +0 -0
- {CUQIpy-1.0.0.post0.dev127.dist-info → CUQIpy-1.0.0.post0.dev180.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: CUQIpy
|
|
3
|
-
Version: 1.0.0.post0.dev127
|
|
3
|
+
Version: 1.0.0.post0.dev180
|
|
4
4
|
Summary: Computational Uncertainty Quantification for Inverse problems in Python
|
|
5
5
|
Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
|
|
6
6
|
License: Apache License
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
cuqi/__init__.py,sha256=LsGilhl-hBLEn6Glt8S_l0OJzAA1sKit_rui8h-D-p0,488
|
|
2
2
|
cuqi/_messages.py,sha256=fzEBrZT2kbmfecBBPm7spVu7yHdxGARQB4QzXhJbCJ0,415
|
|
3
|
-
cuqi/_version.py,sha256=
|
|
3
|
+
cuqi/_version.py,sha256=nZ9Rz10NtUnQZt09kW9YMygjxV0G-qsipgAL3xBFeAU,510
|
|
4
4
|
cuqi/config.py,sha256=wcYvz19wkeKW2EKCGIKJiTpWt5kdaxyt4imyRkvtTRA,526
|
|
5
5
|
cuqi/diagnostics.py,sha256=5OrbJeqpynqRXOe5MtOKKhe7EAVdOEpHIqHnlMW9G_c,3029
|
|
6
6
|
cuqi/array/__init__.py,sha256=-EeiaiWGNsE3twRS4dD814BIlfxEsNkTCZUc5gjOXb0,30
|
|
@@ -32,13 +32,14 @@ cuqi/distribution/_normal.py,sha256=UeoTtGDT7YSf4ZNo2amlVF9K-YQpYbf8q76jcRJTVFw,
|
|
|
32
32
|
cuqi/distribution/_posterior.py,sha256=zAfL0GECxekZ2lBt1W6_LN0U_xskMwK4VNce5xAF7ig,5018
|
|
33
33
|
cuqi/distribution/_uniform.py,sha256=7xJmCZH_LPhuGkwEDGh-_CTtzcWKrXMOxtTJUFb7Ydo,1607
|
|
34
34
|
cuqi/experimental/__init__.py,sha256=vhZvyMX6rl8Y0haqCzGLPz6PSUKyu75XMQbeDHqTTrw,83
|
|
35
|
-
cuqi/experimental/mcmc/__init__.py,sha256=
|
|
36
|
-
cuqi/experimental/mcmc/_cwmh.py,sha256=
|
|
37
|
-
cuqi/experimental/mcmc/
|
|
38
|
-
cuqi/experimental/mcmc/
|
|
39
|
-
cuqi/experimental/mcmc/
|
|
40
|
-
cuqi/experimental/mcmc/
|
|
41
|
-
cuqi/experimental/mcmc/
|
|
35
|
+
cuqi/experimental/mcmc/__init__.py,sha256=S4aXYpnO75HQcwDYfr1-ki8UvlenPDXxshES5avtBF0,340
|
|
36
|
+
cuqi/experimental/mcmc/_cwmh.py,sha256=yRlTk5a1QYfH3JyCecfOOTeDf-4-tmJ3Tl2Bc3pyp1Y,7336
|
|
37
|
+
cuqi/experimental/mcmc/_hmc.py,sha256=qqAyoAajLE_JenYMgAbD3tknuEf75AJu-ufF69GKGk4,19384
|
|
38
|
+
cuqi/experimental/mcmc/_langevin_algorithm.py,sha256=MX48u3GYgCckB6Q5h5kXr_qdIaLQH2toOG5u29OY7gk,8245
|
|
39
|
+
cuqi/experimental/mcmc/_mh.py,sha256=aIV1Ntq0EAq3QJ1_X-DbP7eDAL-d_Or7d3RUO-R48I4,3090
|
|
40
|
+
cuqi/experimental/mcmc/_pcn.py,sha256=3M8zhQGQa53Gz04AkC8wJM61_5rIjGVnhPefi8m4dbY,3531
|
|
41
|
+
cuqi/experimental/mcmc/_rto.py,sha256=jSPznr34XPfWM6LysWIiN4hE-vtyti3cHyvzy9ruykg,11349
|
|
42
|
+
cuqi/experimental/mcmc/_sampler.py,sha256=_5Uo2F-Mta46w3lo7WBVNwvTLYhES_BzMTJrKxA00c8,14861
|
|
42
43
|
cuqi/geometry/__init__.py,sha256=Tz1WGzZBY-QGH3c0GiyKm9XHN8MGGcnU6TUHLZkzB3o,842
|
|
43
44
|
cuqi/geometry/_geometry.py,sha256=WYFC-4_VBTW73b2ldsnfGYKvdSiCE8plr89xTSmkadg,46804
|
|
44
45
|
cuqi/implicitprior/__init__.py,sha256=ZRZ9fgxgEl5n0A9F7WCl1_jid-GUiC8ZLkyTmGQmFlY,100
|
|
@@ -59,7 +60,7 @@ cuqi/sampler/_conjugate.py,sha256=ztmUR3V3qZk9zelKx48ULnmMs_zKTDUfohc256VOIe8,27
|
|
|
59
60
|
cuqi/sampler/_conjugate_approx.py,sha256=xX-X71EgxGnZooOY6CIBhuJTs3dhcKfoLnoFxX3CO2g,1938
|
|
60
61
|
cuqi/sampler/_cwmh.py,sha256=VlAVT1SXQU0yD5ZeR-_ckWvX-ifJrMweFFdFbxdfB_k,7775
|
|
61
62
|
cuqi/sampler/_gibbs.py,sha256=N7qcePwMkRtxINN5JF0FaMIdDCXZGqsfKjfha_KHCck,8627
|
|
62
|
-
cuqi/sampler/_hmc.py,sha256=
|
|
63
|
+
cuqi/sampler/_hmc.py,sha256=EUTefZir-wapoZ7OZFb5M5vayL8z6XksZRMY1BpbuXc,15027
|
|
63
64
|
cuqi/sampler/_langevin_algorithm.py,sha256=o5EyvaR6QGAD7LKwXVRC3WwAP5IYJf5GoMVWl9DrfOA,7861
|
|
64
65
|
cuqi/sampler/_laplace_approximation.py,sha256=u018Z5eqlcq_cIwD9yNOaA15dLQE_vUWaee5Xp8bcjg,6454
|
|
65
66
|
cuqi/sampler/_mh.py,sha256=V5tIdn-KdfWo4J_Nbf-AH6XwKWblWUyc4BeuSikUHsE,7062
|
|
@@ -75,8 +76,8 @@ cuqi/testproblem/_testproblem.py,sha256=x769LwwRdJdzIiZkcQUGb_5-vynNTNALXWKato7s
|
|
|
75
76
|
cuqi/utilities/__init__.py,sha256=EfxHLdsyDNugbmbzs43nV_AeKcycM9sVBjG9WZydagA,351
|
|
76
77
|
cuqi/utilities/_get_python_variable_name.py,sha256=QwlBVj2koJRA8s8pWd554p7-ElcI7HUwY32HknaR92E,1827
|
|
77
78
|
cuqi/utilities/_utilities.py,sha256=At3DOXRdF3GwLkVcM2FXooGyjAGfPkIM0bRzhTfLmWk,8046
|
|
78
|
-
CUQIpy-1.0.0.post0.
|
|
79
|
-
CUQIpy-1.0.0.post0.
|
|
80
|
-
CUQIpy-1.0.0.post0.
|
|
81
|
-
CUQIpy-1.0.0.post0.
|
|
82
|
-
CUQIpy-1.0.0.post0.
|
|
79
|
+
CUQIpy-1.0.0.post0.dev180.dist-info/LICENSE,sha256=kJWRPrtRoQoZGXyyvu50Uc91X6_0XRaVfT0YZssicys,10799
|
|
80
|
+
CUQIpy-1.0.0.post0.dev180.dist-info/METADATA,sha256=YSVheKM46Y6Al3zXEu11wlCkJxAasZw2onXB0nTLN08,18393
|
|
81
|
+
CUQIpy-1.0.0.post0.dev180.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
82
|
+
CUQIpy-1.0.0.post0.dev180.dist-info/top_level.txt,sha256=AgmgMc6TKfPPqbjV0kvAoCBN334i_Lwwojc7HE3ZwD0,5
|
|
83
|
+
CUQIpy-1.0.0.post0.dev180.dist-info/RECORD,,
|
cuqi/_version.py
CHANGED
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "2024-04-
|
|
11
|
+
"date": "2024-04-22T18:01:04+0200",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "1.0.0.post0.dev127"
|
|
14
|
+
"full-revisionid": "be4b485322b1be52f78dfe6d03694d6ba2000b11",
|
|
15
|
+
"version": "1.0.0.post0.dev180"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
cuqi/experimental/mcmc/_cwmh.py
CHANGED
|
@@ -144,7 +144,7 @@ class CWMHNew(ProposalBasedSamplerNew):
|
|
|
144
144
|
|
|
145
145
|
# Propose a sample x_all_components from the proposal distribution
|
|
146
146
|
# for all the components
|
|
147
|
-
target_eval_t = self.
|
|
147
|
+
target_eval_t = self.current_target_logd
|
|
148
148
|
if isinstance(self.proposal,cuqi.distribution.Distribution):
|
|
149
149
|
x_all_components = self.proposal(
|
|
150
150
|
location= self.current_point, scale=self.scale).sample()
|
|
@@ -175,7 +175,7 @@ class CWMHNew(ProposalBasedSamplerNew):
|
|
|
175
175
|
|
|
176
176
|
x_star = x_t.copy()
|
|
177
177
|
|
|
178
|
-
self.
|
|
178
|
+
self.current_target_logd = target_eval_t
|
|
179
179
|
self.current_point = x_t
|
|
180
180
|
|
|
181
181
|
return acc
|
|
@@ -199,23 +199,3 @@ class CWMHNew(ProposalBasedSamplerNew):
|
|
|
199
199
|
# Update the scale parameter
|
|
200
200
|
self.scale = np.minimum(scale_temp, np.ones(self.dim))
|
|
201
201
|
self._scale_temp = scale_temp
|
|
202
|
-
|
|
203
|
-
def get_state(self):
|
|
204
|
-
current_point = self.current_point
|
|
205
|
-
if isinstance(current_point, CUQIarray):
|
|
206
|
-
current_point = current_point.to_numpy()
|
|
207
|
-
|
|
208
|
-
return {'sampler_type': 'CWMH',
|
|
209
|
-
'current_point': current_point,
|
|
210
|
-
'current_target': self.current_target,
|
|
211
|
-
'scale': self.scale}
|
|
212
|
-
|
|
213
|
-
def set_state(self, state):
|
|
214
|
-
current_point = state['current_point']
|
|
215
|
-
if not isinstance(current_point, CUQIarray):
|
|
216
|
-
current_point = CUQIarray(current_point,
|
|
217
|
-
geometry=self.target.geometry)
|
|
218
|
-
|
|
219
|
-
self.current_point = current_point
|
|
220
|
-
self.current_target = state['current_target']
|
|
221
|
-
self.scale = state['scale']
|
|
@@ -0,0 +1,470 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import numpy as np
|
|
3
|
+
from cuqi.experimental.mcmc import SamplerNew
|
|
4
|
+
from cuqi.array import CUQIarray
|
|
5
|
+
from numbers import Number
|
|
6
|
+
|
|
7
|
+
class NUTSNew(SamplerNew):
|
|
8
|
+
"""No-U-Turn Sampler (Hoffman and Gelman, 2014).
|
|
9
|
+
|
|
10
|
+
Samples a distribution given its logpdf and gradient using a Hamiltonian
|
|
11
|
+
Monte Carlo (HMC) algorithm with automatic parameter tuning.
|
|
12
|
+
|
|
13
|
+
For more details see: Hoffman, M. D., & Gelman, A. (2014). The no-U-turn
|
|
14
|
+
sampler: Adaptively setting path lengths in Hamiltonian Monte Carlo. Journal
|
|
15
|
+
of Machine Learning Research, 15, 1593-1623.
|
|
16
|
+
|
|
17
|
+
Parameters
|
|
18
|
+
----------
|
|
19
|
+
target : `cuqi.distribution.Distribution`
|
|
20
|
+
The target distribution to sample. Must have logpdf and gradient method.
|
|
21
|
+
Custom logpdfs and gradients are supported by using a
|
|
22
|
+
:class:`cuqi.distribution.UserDefinedDistribution`.
|
|
23
|
+
|
|
24
|
+
initial_point : ndarray
|
|
25
|
+
Initial parameters. *Optional*. If not provided, the initial point is
|
|
26
|
+
an array of ones.
|
|
27
|
+
|
|
28
|
+
max_depth : int
|
|
29
|
+
Maximum depth of the tree >=0 and the default is 15.
|
|
30
|
+
|
|
31
|
+
step_size : None or float
|
|
32
|
+
If step_size is provided (as positive float), it will be used as initial
|
|
33
|
+
step size. If None, the step size will be estimated by the sampler.
|
|
34
|
+
|
|
35
|
+
opt_acc_rate : float
|
|
36
|
+
The optimal acceptance rate to reach if using adaptive step size.
|
|
37
|
+
Suggested values are 0.6 (default) or 0.8 (as in stan). In principle,
|
|
38
|
+
opt_acc_rate should be in (0, 1), however, choosing a value that is very
|
|
39
|
+
close to 1 or 0 might lead to poor performance of the sampler.
|
|
40
|
+
|
|
41
|
+
callback : callable, *Optional*
|
|
42
|
+
If set this function will be called after every sample.
|
|
43
|
+
The signature of the callback function is
|
|
44
|
+
`callback(sample, sample_index)`,
|
|
45
|
+
where `sample` is the current sample and `sample_index` is the index of
|
|
46
|
+
the sample.
|
|
47
|
+
An example is shown in demos/demo31_callback.py.
|
|
48
|
+
|
|
49
|
+
Example
|
|
50
|
+
-------
|
|
51
|
+
.. code-block:: python
|
|
52
|
+
|
|
53
|
+
# Import cuqi
|
|
54
|
+
import cuqi
|
|
55
|
+
|
|
56
|
+
# Define a target distribution
|
|
57
|
+
tp = cuqi.testproblem.WangCubic()
|
|
58
|
+
target = tp.posterior
|
|
59
|
+
|
|
60
|
+
# Set up sampler
|
|
61
|
+
sampler = cuqi.experimental.mcmc.NUTSNew(target)
|
|
62
|
+
|
|
63
|
+
# Sample
|
|
64
|
+
sampler.warmup(5000)
|
|
65
|
+
sampler.sample(10000)
|
|
66
|
+
|
|
67
|
+
# Get samples
|
|
68
|
+
samples = sampler.get_samples()
|
|
69
|
+
|
|
70
|
+
# Plot samples
|
|
71
|
+
samples.plot_pair()
|
|
72
|
+
|
|
73
|
+
After running the NUTS sampler, run diagnostics can be accessed via the
|
|
74
|
+
following attributes:
|
|
75
|
+
|
|
76
|
+
.. code-block:: python
|
|
77
|
+
|
|
78
|
+
# Number of tree nodes created each NUTS iteration
|
|
79
|
+
sampler.num_tree_node_list
|
|
80
|
+
|
|
81
|
+
# Step size used in each NUTS iteration
|
|
82
|
+
sampler.epsilon_list
|
|
83
|
+
|
|
84
|
+
# Suggested step size during adaptation (the value of this step size is
|
|
85
|
+
# only used after adaptation).
|
|
86
|
+
sampler.epsilon_bar_list
|
|
87
|
+
|
|
88
|
+
"""
|
|
89
|
+
|
|
90
|
+
_STATE_KEYS = SamplerNew._STATE_KEYS.union({'_epsilon', '_epsilon_bar',
|
|
91
|
+
'_H_bar', '_mu',
|
|
92
|
+
'_alpha', '_n_alpha'})
|
|
93
|
+
|
|
94
|
+
_HISTORY_KEYS = SamplerNew._HISTORY_KEYS.union({'num_tree_node_list',
|
|
95
|
+
'epsilon_list',
|
|
96
|
+
'epsilon_bar_list'})
|
|
97
|
+
|
|
98
|
+
def __init__(self, target, initial_point=None, max_depth=15,
|
|
99
|
+
step_size=None, opt_acc_rate=0.6, **kwargs):
|
|
100
|
+
super().__init__(target, initial_point=initial_point, **kwargs)
|
|
101
|
+
|
|
102
|
+
# Assign parameters as attributes
|
|
103
|
+
self.max_depth = max_depth
|
|
104
|
+
self.step_size = step_size
|
|
105
|
+
self.opt_acc_rate = opt_acc_rate
|
|
106
|
+
|
|
107
|
+
# Set current point
|
|
108
|
+
self.current_point = self.initial_point
|
|
109
|
+
|
|
110
|
+
# Initialize epsilon and epsilon_bar
|
|
111
|
+
# epsilon is the step size used in the current iteration
|
|
112
|
+
# after warm up and one sampling step, epsilon is updated
|
|
113
|
+
# to epsilon_bar for the remaining sampling steps.
|
|
114
|
+
self._epsilon = None
|
|
115
|
+
self._epsilon_bar = None
|
|
116
|
+
self._H_bar = None
|
|
117
|
+
|
|
118
|
+
# Arrays to store acceptance rate
|
|
119
|
+
self._acc = [None]
|
|
120
|
+
|
|
121
|
+
# NUTS run diagnostic:
|
|
122
|
+
# number of tree nodes created each NUTS iteration
|
|
123
|
+
self._num_tree_node = 0
|
|
124
|
+
# Create lists to store NUTS run diagnostics
|
|
125
|
+
self._create_run_diagnostic_attributes()
|
|
126
|
+
|
|
127
|
+
#=========================================================================
|
|
128
|
+
#============================== Properties ===============================
|
|
129
|
+
#=========================================================================
|
|
130
|
+
@property
|
|
131
|
+
def max_depth(self):
|
|
132
|
+
return self._max_depth
|
|
133
|
+
|
|
134
|
+
@max_depth.setter
|
|
135
|
+
def max_depth(self, value):
|
|
136
|
+
if not isinstance(value, int):
|
|
137
|
+
raise TypeError('max_depth must be an integer.')
|
|
138
|
+
if value < 0:
|
|
139
|
+
raise ValueError('max_depth must be >= 0.')
|
|
140
|
+
self._max_depth = value
|
|
141
|
+
|
|
142
|
+
@property
|
|
143
|
+
def step_size(self):
|
|
144
|
+
return self._step_size
|
|
145
|
+
|
|
146
|
+
@step_size.setter
|
|
147
|
+
def step_size(self, value):
|
|
148
|
+
if value is None:
|
|
149
|
+
pass # NUTS will adapt the step size
|
|
150
|
+
|
|
151
|
+
# step_size must be a positive float, raise error otherwise
|
|
152
|
+
elif isinstance(value, bool)\
|
|
153
|
+
or not isinstance(value, Number)\
|
|
154
|
+
or value <= 0:
|
|
155
|
+
raise TypeError('step_size must be a positive float or None.')
|
|
156
|
+
self._step_size = value
|
|
157
|
+
|
|
158
|
+
@property
|
|
159
|
+
def opt_acc_rate(self):
|
|
160
|
+
return self._opt_acc_rate
|
|
161
|
+
|
|
162
|
+
@opt_acc_rate.setter
|
|
163
|
+
def opt_acc_rate(self, value):
|
|
164
|
+
if not isinstance(value, Number) or value <= 0 or value >= 1:
|
|
165
|
+
raise ValueError('opt_acc_rate must be a float in (0, 1).')
|
|
166
|
+
self._opt_acc_rate = value
|
|
167
|
+
|
|
168
|
+
#=========================================================================
|
|
169
|
+
#================== Implement methods required by SamplerNew =============
|
|
170
|
+
#=========================================================================
|
|
171
|
+
def validate_target(self):
|
|
172
|
+
# Check if the target has logd and gradient methods
|
|
173
|
+
try:
|
|
174
|
+
current_target_logd, current_target_grad =\
|
|
175
|
+
self._nuts_target(np.ones(self.dim))
|
|
176
|
+
except:
|
|
177
|
+
raise ValueError('Target must have logd and gradient methods.')
|
|
178
|
+
|
|
179
|
+
def reset(self):
|
|
180
|
+
# Call the parent reset method
|
|
181
|
+
super().reset()
|
|
182
|
+
# Reset NUTS run diagnostic attributes
|
|
183
|
+
self._reset_run_diagnostic_attributes()
|
|
184
|
+
|
|
185
|
+
def step(self):
|
|
186
|
+
# Convert current_point, logd, and grad to numpy arrays
|
|
187
|
+
# if they are CUQIarray objects
|
|
188
|
+
if isinstance(self.current_point, CUQIarray):
|
|
189
|
+
self.current_point = self.current_point.to_numpy()
|
|
190
|
+
if isinstance(self.current_target_logd, CUQIarray):
|
|
191
|
+
self.current_target_logd = self.current_target_logd.to_numpy()
|
|
192
|
+
if isinstance(self.current_target_grad, CUQIarray):
|
|
193
|
+
self.current_target_grad = self.current_target_grad.to_numpy()
|
|
194
|
+
|
|
195
|
+
# reset number of tree nodes for each iteration
|
|
196
|
+
self._num_tree_node = 0
|
|
197
|
+
|
|
198
|
+
# copy current point, logd, and grad in local variables
|
|
199
|
+
point_k = self.current_point.copy() # initial position (parameters)
|
|
200
|
+
logd_k = self.current_target_logd
|
|
201
|
+
grad_k = self.current_target_grad.copy() # initial gradient
|
|
202
|
+
|
|
203
|
+
# compute r_k and Hamiltonian
|
|
204
|
+
r_k = self._Kfun(1, 'sample') # resample momentum vector
|
|
205
|
+
Ham = logd_k - self._Kfun(r_k, 'eval') # Hamiltonian
|
|
206
|
+
|
|
207
|
+
# slice variable
|
|
208
|
+
log_u = Ham - np.random.exponential(1, size=1)
|
|
209
|
+
|
|
210
|
+
# initialization
|
|
211
|
+
j, s, n = 0, 1, 1
|
|
212
|
+
point_minus, point_plus = np.copy(point_k), np.copy(point_k)
|
|
213
|
+
grad_minus, grad_plus = np.copy(grad_k), np.copy(grad_k)
|
|
214
|
+
r_minus, r_plus = np.copy(r_k), np.copy(r_k)
|
|
215
|
+
|
|
216
|
+
# run NUTS
|
|
217
|
+
while (s == 1) and (j <= self.max_depth):
|
|
218
|
+
# sample a direction
|
|
219
|
+
v = int(2*(np.random.rand() < 0.5)-1)
|
|
220
|
+
|
|
221
|
+
# build tree: doubling procedure
|
|
222
|
+
if (v == -1):
|
|
223
|
+
point_minus, r_minus, grad_minus, _, _, _, \
|
|
224
|
+
point_prime, logd_prime, grad_prime,\
|
|
225
|
+
n_prime, s_prime, alpha, n_alpha = \
|
|
226
|
+
self._BuildTree(point_minus, r_minus, grad_minus,
|
|
227
|
+
Ham, log_u, v, j, self._epsilon)
|
|
228
|
+
else:
|
|
229
|
+
_, _, _, point_plus, r_plus, grad_plus, \
|
|
230
|
+
point_prime, logd_prime, grad_prime,\
|
|
231
|
+
n_prime, s_prime, alpha, n_alpha = \
|
|
232
|
+
self._BuildTree(point_plus, r_plus, grad_plus,
|
|
233
|
+
Ham, log_u, v, j, self._epsilon)
|
|
234
|
+
|
|
235
|
+
# Metropolis step
|
|
236
|
+
alpha2 = min(1, (n_prime/n)) #min(0, np.log(n_p) - np.log(n))
|
|
237
|
+
if (s_prime == 1) and (np.random.rand() <= alpha2):
|
|
238
|
+
self.current_point = point_prime
|
|
239
|
+
self.current_target_logd = logd_prime
|
|
240
|
+
self.current_target_grad = np.copy(grad_prime)
|
|
241
|
+
self._acc.append(1)
|
|
242
|
+
else:
|
|
243
|
+
self._acc.append(0)
|
|
244
|
+
|
|
245
|
+
# update number of particles, tree level, and stopping criterion
|
|
246
|
+
n += n_prime
|
|
247
|
+
dpoints = point_plus - point_minus
|
|
248
|
+
s = s_prime *\
|
|
249
|
+
int((dpoints @ r_minus.T) >= 0) * int((dpoints @ r_plus.T) >= 0)
|
|
250
|
+
j += 1
|
|
251
|
+
self._alpha = alpha
|
|
252
|
+
self._n_alpha = n_alpha
|
|
253
|
+
|
|
254
|
+
# update run diagnostic attributes
|
|
255
|
+
self._update_run_diagnostic_attributes(
|
|
256
|
+
self._num_tree_node, self._epsilon, self._epsilon_bar)
|
|
257
|
+
|
|
258
|
+
self._epsilon = self._epsilon_bar
|
|
259
|
+
if np.isnan(self.current_target_logd):
|
|
260
|
+
raise NameError('NaN potential func')
|
|
261
|
+
|
|
262
|
+
def tune(self, skip_len, update_count):
|
|
263
|
+
""" adapt epsilon during burn-in using dual averaging"""
|
|
264
|
+
k = update_count+1
|
|
265
|
+
|
|
266
|
+
# Fixed parameters that do not change during the run
|
|
267
|
+
gamma, t_0, kappa = 0.05, 10, 0.75 # kappa in (0.5, 1]
|
|
268
|
+
|
|
269
|
+
eta1 = 1/(k + t_0)
|
|
270
|
+
self._H_bar = (1-eta1)*self._H_bar +\
|
|
271
|
+
eta1*(self.opt_acc_rate - (self._alpha/self._n_alpha))
|
|
272
|
+
self._epsilon = np.exp(self._mu - (np.sqrt(k)/gamma)*self._H_bar)
|
|
273
|
+
eta = k**(-kappa)
|
|
274
|
+
self._epsilon_bar =\
|
|
275
|
+
np.exp(eta*np.log(self._epsilon) +(1-eta)*np.log(self._epsilon_bar))
|
|
276
|
+
|
|
277
|
+
def _pre_warmup(self):
|
|
278
|
+
super()._pre_warmup()
|
|
279
|
+
|
|
280
|
+
self.current_target_logd, self.current_target_grad =\
|
|
281
|
+
self._nuts_target(self.current_point)
|
|
282
|
+
|
|
283
|
+
# Set up tuning parameters (only first time tuning is called)
|
|
284
|
+
# Note:
|
|
285
|
+
# Parameters that change during the tune run
|
|
286
|
+
# self._epsilon_bar
|
|
287
|
+
# self._H_bar
|
|
288
|
+
# self._epsilon
|
|
289
|
+
# Parameters that do not change during the run
|
|
290
|
+
# self._mu
|
|
291
|
+
|
|
292
|
+
if self._epsilon is None:
|
|
293
|
+
# parameters dual averaging
|
|
294
|
+
self._epsilon = self._FindGoodEpsilon()
|
|
295
|
+
# Parameter mu, does not change during the run
|
|
296
|
+
self._mu = np.log(10*self._epsilon)
|
|
297
|
+
|
|
298
|
+
if self._epsilon_bar is None: # Initial value of epsilon_bar
|
|
299
|
+
self._epsilon_bar = 1
|
|
300
|
+
|
|
301
|
+
if self._H_bar is None: # Initial value of H_bar
|
|
302
|
+
self._H_bar = 0
|
|
303
|
+
|
|
304
|
+
def _pre_sample(self):
|
|
305
|
+
super()._pre_sample()
|
|
306
|
+
|
|
307
|
+
self.current_target_logd, self.current_target_grad =\
|
|
308
|
+
self._nuts_target(self.current_point)
|
|
309
|
+
|
|
310
|
+
# Set up epsilon and epsilon_bar if not set
|
|
311
|
+
if self._epsilon is None:
|
|
312
|
+
if self.step_size is None:
|
|
313
|
+
step_size = self._FindGoodEpsilon()
|
|
314
|
+
else:
|
|
315
|
+
step_size = self.step_size
|
|
316
|
+
self._epsilon = step_size
|
|
317
|
+
self._epsilon_bar = step_size
|
|
318
|
+
|
|
319
|
+
#=========================================================================
|
|
320
|
+
def _nuts_target(self, x): # returns logposterior tuple evaluation-gradient
|
|
321
|
+
return self.target.logd(x), self.target.gradient(x)
|
|
322
|
+
|
|
323
|
+
#=========================================================================
|
|
324
|
+
# auxiliary standard Gaussian PDF: kinetic energy function
|
|
325
|
+
# d_log_2pi = d*np.log(2*np.pi)
|
|
326
|
+
def _Kfun(self, r, flag):
|
|
327
|
+
if flag == 'eval': # evaluate
|
|
328
|
+
return 0.5*(r.T @ r) #+ d_log_2pi
|
|
329
|
+
if flag == 'sample': # sample
|
|
330
|
+
return np.random.standard_normal(size=self.dim)
|
|
331
|
+
|
|
332
|
+
#=========================================================================
|
|
333
|
+
def _FindGoodEpsilon(self, epsilon=1):
|
|
334
|
+
point_k = self.current_point
|
|
335
|
+
self.current_target_logd, self.current_target_grad = self._nuts_target(
|
|
336
|
+
point_k)
|
|
337
|
+
logd = self.current_target_logd
|
|
338
|
+
grad = self.current_target_grad
|
|
339
|
+
|
|
340
|
+
r = self._Kfun(1, 'sample') # resample a momentum
|
|
341
|
+
Ham = logd - self._Kfun(r, 'eval') # initial Hamiltonian
|
|
342
|
+
_, r_prime, logd_prime, grad_prime = self._Leapfrog(
|
|
343
|
+
point_k, r, grad, epsilon)
|
|
344
|
+
|
|
345
|
+
# trick to make sure the step is not huge, leading to infinite values of
|
|
346
|
+
# the likelihood
|
|
347
|
+
k = 1
|
|
348
|
+
while np.isinf(logd_prime) or np.isinf(grad_prime).any():
|
|
349
|
+
k *= 0.5
|
|
350
|
+
_, r_prime, logd_prime, grad_prime = self._Leapfrog(
|
|
351
|
+
point_k, r, grad, epsilon*k)
|
|
352
|
+
epsilon = 0.5*k*epsilon
|
|
353
|
+
|
|
354
|
+
# doubles/halves the value of epsilon until the accprob of the Langevin
|
|
355
|
+
# proposal crosses 0.5
|
|
356
|
+
Ham_prime = logd_prime - self._Kfun(r_prime, 'eval')
|
|
357
|
+
log_ratio = Ham_prime - Ham
|
|
358
|
+
a = 1 if log_ratio > np.log(0.5) else -1
|
|
359
|
+
while (a*log_ratio > -a*np.log(2)):
|
|
360
|
+
epsilon = (2**a)*epsilon
|
|
361
|
+
_, r_prime, logd_prime, _ = self._Leapfrog(
|
|
362
|
+
point_k, r, grad, epsilon)
|
|
363
|
+
Ham_prime = logd_prime - self._Kfun(r_prime, 'eval')
|
|
364
|
+
log_ratio = Ham_prime - Ham
|
|
365
|
+
return epsilon
|
|
366
|
+
|
|
367
|
+
#=========================================================================
|
|
368
|
+
def _Leapfrog(self, point_old, r_old, grad_old, epsilon):
|
|
369
|
+
# symplectic integrator: trajectories preserve phase space volume
|
|
370
|
+
r_new = r_old + 0.5*epsilon*grad_old # half-step
|
|
371
|
+
point_new = point_old + epsilon*r_new # full-step
|
|
372
|
+
logd_new, grad_new = self._nuts_target(point_new) # new gradient
|
|
373
|
+
r_new += 0.5*epsilon*grad_new # half-step
|
|
374
|
+
return point_new, r_new, logd_new, grad_new
|
|
375
|
+
|
|
376
|
+
#=========================================================================
|
|
377
|
+
def _BuildTree(
|
|
378
|
+
self, point_k, r, grad, Ham, log_u, v, j, epsilon, Delta_max=1000):
|
|
379
|
+
# Increment the number of tree nodes counter
|
|
380
|
+
self._num_tree_node += 1
|
|
381
|
+
|
|
382
|
+
if (j == 0): # base case
|
|
383
|
+
# single leapfrog step in the direction v
|
|
384
|
+
point_prime, r_prime, logd_prime, grad_prime = self._Leapfrog(
|
|
385
|
+
point_k, r, grad, v*epsilon)
|
|
386
|
+
Ham_prime = logd_prime - self._Kfun(r_prime, 'eval') # Hamiltonian
|
|
387
|
+
# eval
|
|
388
|
+
n_prime = int(log_u <= Ham_prime) # if particle is in the slice
|
|
389
|
+
s_prime = int(log_u < Delta_max + Ham_prime) # check U-turn
|
|
390
|
+
#
|
|
391
|
+
diff_Ham = Ham_prime - Ham
|
|
392
|
+
|
|
393
|
+
# Compute the acceptance probability
|
|
394
|
+
# alpha_prime = min(1, np.exp(diff_Ham))
|
|
395
|
+
# written in a stable way to avoid overflow when computing
|
|
396
|
+
# exp(diff_Ham) for large values of diff_Ham
|
|
397
|
+
alpha_prime = 1 if diff_Ham > 0 else np.exp(diff_Ham)
|
|
398
|
+
n_alpha_prime = 1
|
|
399
|
+
#
|
|
400
|
+
point_minus, point_plus = point_prime, point_prime
|
|
401
|
+
r_minus, r_plus = r_prime, r_prime
|
|
402
|
+
grad_minus, grad_plus = grad_prime, grad_prime
|
|
403
|
+
else:
|
|
404
|
+
# recursion: build the left/right subtrees
|
|
405
|
+
point_minus, r_minus, grad_minus, point_plus, r_plus, grad_plus, \
|
|
406
|
+
point_prime, logd_prime, grad_prime,\
|
|
407
|
+
n_prime, s_prime, alpha_prime, n_alpha_prime = \
|
|
408
|
+
self._BuildTree(point_k, r, grad,
|
|
409
|
+
Ham, log_u, v, j-1, epsilon)
|
|
410
|
+
if (s_prime == 1): # do only if the stopping criteria does not
|
|
411
|
+
# verify at the first subtree
|
|
412
|
+
if (v == -1):
|
|
413
|
+
point_minus, r_minus, grad_minus, _, _, _, \
|
|
414
|
+
point_2prime, logd_2prime, grad_2prime,\
|
|
415
|
+
n_2prime, s_2prime, alpha_2prime, n_alpha_2prime = \
|
|
416
|
+
self._BuildTree(point_minus, r_minus, grad_minus,
|
|
417
|
+
Ham, log_u, v, j-1, epsilon)
|
|
418
|
+
else:
|
|
419
|
+
_, _, _, point_plus, r_plus, grad_plus, \
|
|
420
|
+
point_2prime, logd_2prime, grad_2prime,\
|
|
421
|
+
n_2prime, s_2prime, alpha_2prime, n_alpha_2prime = \
|
|
422
|
+
self._BuildTree(point_plus, r_plus, grad_plus,
|
|
423
|
+
Ham, log_u, v, j-1, epsilon)
|
|
424
|
+
|
|
425
|
+
# Metropolis step
|
|
426
|
+
alpha2 = n_2prime / max(1, (n_prime + n_2prime))
|
|
427
|
+
if (np.random.rand() <= alpha2):
|
|
428
|
+
point_prime = np.copy(point_2prime)
|
|
429
|
+
logd_prime = np.copy(logd_2prime)
|
|
430
|
+
grad_prime = np.copy(grad_2prime)
|
|
431
|
+
|
|
432
|
+
# update number of particles and stopping criterion
|
|
433
|
+
alpha_prime += alpha_2prime
|
|
434
|
+
n_alpha_prime += n_alpha_2prime
|
|
435
|
+
dpoints = point_plus - point_minus
|
|
436
|
+
s_prime = s_2prime *\
|
|
437
|
+
int((dpoints@r_minus.T)>=0) * int((dpoints@r_plus.T)>=0)
|
|
438
|
+
n_prime += n_2prime
|
|
439
|
+
|
|
440
|
+
return point_minus, r_minus, grad_minus, point_plus, r_plus, grad_plus,\
|
|
441
|
+
point_prime, logd_prime, grad_prime,\
|
|
442
|
+
n_prime, s_prime, alpha_prime, n_alpha_prime
|
|
443
|
+
|
|
444
|
+
#=========================================================================
|
|
445
|
+
#======================== Diagnostic methods =============================
|
|
446
|
+
#=========================================================================
|
|
447
|
+
|
|
448
|
+
def _create_run_diagnostic_attributes(self):
|
|
449
|
+
"""A method to create attributes to store NUTS run diagnostic."""
|
|
450
|
+
self._reset_run_diagnostic_attributes()
|
|
451
|
+
|
|
452
|
+
def _reset_run_diagnostic_attributes(self):
|
|
453
|
+
"""A method to reset attributes to store NUTS run diagnostic."""
|
|
454
|
+
# List to store number of tree nodes created each NUTS iteration
|
|
455
|
+
self.num_tree_node_list = []
|
|
456
|
+
# List of step size used in each NUTS iteration
|
|
457
|
+
self.epsilon_list = []
|
|
458
|
+
# List of burn-in step size suggestion during adaptation
|
|
459
|
+
# only used when adaptation is done
|
|
460
|
+
# remains fixed after adaptation (after burn-in)
|
|
461
|
+
self.epsilon_bar_list = []
|
|
462
|
+
|
|
463
|
+
def _update_run_diagnostic_attributes(self, n_tree, eps, eps_bar):
|
|
464
|
+
"""A method to update attributes to store NUTS run diagnostic."""
|
|
465
|
+
# Store the number of tree nodes created in iteration k
|
|
466
|
+
self.num_tree_node_list.append(n_tree)
|
|
467
|
+
# Store the step size used in iteration k
|
|
468
|
+
self.epsilon_list.append(eps)
|
|
469
|
+
# Store the step size suggestion during adaptation in iteration k
|
|
470
|
+
self.epsilon_bar_list.append(eps_bar)
|
|
@@ -60,14 +60,17 @@ class ULANew(SamplerNew): # Refactor to Proposal-based sampler?
|
|
|
60
60
|
A Deblur example can be found in demos/demo27_ULA.py
|
|
61
61
|
# TODO: update demo once sampler merged
|
|
62
62
|
"""
|
|
63
|
+
|
|
64
|
+
_STATE_KEYS = SamplerNew._STATE_KEYS.union({'current_target_logd', 'scale', 'current_target_grad'})
|
|
65
|
+
|
|
63
66
|
def __init__(self, target, scale=1.0, **kwargs):
|
|
64
67
|
|
|
65
68
|
super().__init__(target, **kwargs)
|
|
66
69
|
|
|
67
70
|
self.scale = scale
|
|
68
71
|
self.current_point = self.initial_point
|
|
69
|
-
self.
|
|
70
|
-
self.
|
|
72
|
+
self.current_target_logd = self.target.logd(self.current_point)
|
|
73
|
+
self.current_target_grad = self.target.gradient(self.current_point)
|
|
71
74
|
self._acc = [1] # TODO. Check if we need this
|
|
72
75
|
|
|
73
76
|
def validate_target(self):
|
|
@@ -99,15 +102,15 @@ class ULANew(SamplerNew): # Refactor to Proposal-based sampler?
|
|
|
99
102
|
1 (accepted)
|
|
100
103
|
"""
|
|
101
104
|
self.current_point = x_star
|
|
102
|
-
self.
|
|
103
|
-
self.
|
|
105
|
+
self.current_target_logd = target_eval_star
|
|
106
|
+
self.current_target_grad = target_grad_star
|
|
104
107
|
acc = 1
|
|
105
108
|
return acc
|
|
106
109
|
|
|
107
110
|
def step(self):
|
|
108
111
|
# propose state
|
|
109
112
|
xi = cuqi.distribution.Normal(mean=np.zeros(self.dim), std=np.sqrt(self.scale)).sample()
|
|
110
|
-
x_star = self.current_point + 0.5*self.scale*self.
|
|
113
|
+
x_star = self.current_point + 0.5*self.scale*self.current_target_grad + xi
|
|
111
114
|
|
|
112
115
|
# evaluate target
|
|
113
116
|
target_eval_star, target_grad_star = self.target.logd(x_star), self.target.gradient(x_star)
|
|
@@ -120,26 +123,6 @@ class ULANew(SamplerNew): # Refactor to Proposal-based sampler?
|
|
|
120
123
|
def tune(self, skip_len, update_count):
|
|
121
124
|
pass
|
|
122
125
|
|
|
123
|
-
def get_state(self):
|
|
124
|
-
if isinstance(self.current_point, CUQIarray):
|
|
125
|
-
self.current_point = self.current_point.to_numpy()
|
|
126
|
-
if isinstance(self.current_target_eval, CUQIarray):
|
|
127
|
-
self.current_target_eval = self.current_target_eval.to_numpy()
|
|
128
|
-
if isinstance(self.current_target_grad_eval, CUQIarray):
|
|
129
|
-
self.current_target_grad_eval = self.current_target_grad_eval.to_numpy()
|
|
130
|
-
return {'sampler_type': 'ULA', 'current_point': self.current_point, \
|
|
131
|
-
'current_target_eval': self.current_target_eval, \
|
|
132
|
-
'current_target_grad_eval': self.current_target_grad_eval, \
|
|
133
|
-
'scale': self.scale}
|
|
134
|
-
|
|
135
|
-
def set_state(self, state):
|
|
136
|
-
temp = CUQIarray(state['current_point'] , geometry=self.target.geometry)
|
|
137
|
-
self.current_point = temp
|
|
138
|
-
temp = CUQIarray(state['current_target_eval'] , geometry=self.target.geometry)
|
|
139
|
-
self.current_target_eval = temp
|
|
140
|
-
temp = CUQIarray(state['current_target_grad_eval'] , geometry=self.target.geometry)
|
|
141
|
-
self.current_target_grad_eval = temp
|
|
142
|
-
self.scale = state['scale']
|
|
143
126
|
|
|
144
127
|
class MALANew(ULANew): # Refactor to Proposal-based sampler?
|
|
145
128
|
""" Metropolis-adjusted Langevin algorithm (MALA) (Roberts and Tweedie, 1996)
|
|
@@ -219,9 +202,9 @@ class MALANew(ULANew): # Refactor to Proposal-based sampler?
|
|
|
219
202
|
scaler
|
|
220
203
|
1 if accepted, 0 otherwise
|
|
221
204
|
"""
|
|
222
|
-
log_target_ratio = target_eval_star - self.current_target_eval
|
|
205
|
+
log_target_ratio = target_eval_star - self.current_target_logd
|
|
223
206
|
log_prop_ratio = self._log_proposal(self.current_point, x_star, target_grad_star) \
|
|
224
|
-
- self._log_proposal(x_star, self.current_point, self.current_target_grad_eval)
|
|
207
|
+
- self._log_proposal(x_star, self.current_point, self.current_target_grad)
|
|
225
208
|
log_alpha = min(0, log_target_ratio + log_prop_ratio)
|
|
226
209
|
|
|
227
210
|
# accept/reject with Metropolis
|
|
@@ -229,8 +212,8 @@ class MALANew(ULANew): # Refactor to Proposal-based sampler?
|
|
|
229
212
|
log_u = np.log(np.random.rand())
|
|
230
213
|
if (log_u <= log_alpha) and (np.isnan(target_eval_star) == False):
|
|
231
214
|
self.current_point = x_star
|
|
232
|
-
self.current_target_eval = target_eval_star
|
|
233
|
-
self.current_target_grad_eval = target_grad_star
|
|
215
|
+
self.current_target_logd = target_eval_star
|
|
216
|
+
self.current_target_grad = target_grad_star
|
|
234
217
|
acc = 1
|
|
235
218
|
return acc
|
|
236
219
|
|
|
@@ -241,15 +224,3 @@ class MALANew(ULANew): # Refactor to Proposal-based sampler?
|
|
|
241
224
|
mu = theta_k + ((self.scale)/2)*g_logpi_k
|
|
242
225
|
misfit = theta_star - mu
|
|
243
226
|
return -0.5*((1/(self.scale))*(misfit.T @ misfit))
|
|
244
|
-
|
|
245
|
-
def get_state(self):
|
|
246
|
-
if isinstance(self.current_point, CUQIarray):
|
|
247
|
-
self.current_point = self.current_point.to_numpy()
|
|
248
|
-
if isinstance(self.current_target_eval, CUQIarray):
|
|
249
|
-
self.current_target_eval = self.current_target_eval.to_numpy()
|
|
250
|
-
if isinstance(self.current_target_grad_eval, CUQIarray):
|
|
251
|
-
self.current_target_grad_eval = self.current_target_grad_eval.to_numpy()
|
|
252
|
-
return {'sampler_type': 'MALA', 'current_point': self.current_point, \
|
|
253
|
-
'current_target_eval': self.current_target_eval, \
|
|
254
|
-
'current_target_grad_eval': self.current_target_grad_eval, \
|
|
255
|
-
'scale': self.scale}
|
cuqi/experimental/mcmc/_mh.py
CHANGED
|
@@ -23,6 +23,8 @@ class MHNew(ProposalBasedSamplerNew):
|
|
|
23
23
|
|
|
24
24
|
"""
|
|
25
25
|
|
|
26
|
+
_STATE_KEYS = ProposalBasedSamplerNew._STATE_KEYS.union({'scale', '_scale_temp'})
|
|
27
|
+
|
|
26
28
|
def __init__(self, target, proposal=None, scale=1, **kwargs):
|
|
27
29
|
super().__init__(target, proposal=proposal, scale=scale, **kwargs)
|
|
28
30
|
# Due to a bug? in old MH, we must keep track of this extra variable to match behavior.
|
|
@@ -54,7 +56,7 @@ class MHNew(ProposalBasedSamplerNew):
|
|
|
54
56
|
target_eval_star = self.target.logd(x_star)
|
|
55
57
|
|
|
56
58
|
# ratio and acceptance probability
|
|
57
|
-
ratio = target_eval_star - self.current_target # proposal is symmetric
|
|
59
|
+
ratio = target_eval_star - self.current_target_logd # proposal is symmetric
|
|
58
60
|
alpha = min(0, ratio)
|
|
59
61
|
|
|
60
62
|
# accept/reject
|
|
@@ -62,7 +64,7 @@ class MHNew(ProposalBasedSamplerNew):
|
|
|
62
64
|
acc = 0
|
|
63
65
|
if (u_theta <= alpha):
|
|
64
66
|
self.current_point = x_star
|
|
65
|
-
self.current_target = target_eval_star
|
|
67
|
+
self.current_target_logd = target_eval_star
|
|
66
68
|
acc = 1
|
|
67
69
|
|
|
68
70
|
return acc
|
|
@@ -79,13 +81,3 @@ class MHNew(ProposalBasedSamplerNew):
|
|
|
79
81
|
|
|
80
82
|
# update parameters
|
|
81
83
|
self.scale = min(self._scale_temp, 1)
|
|
82
|
-
|
|
83
|
-
def get_state(self):
|
|
84
|
-
return {'sampler_type': 'MH', 'current_point': self.current_point.to_numpy(), 'current_target': self.current_target.to_numpy(), 'scale': self.scale}
|
|
85
|
-
|
|
86
|
-
def set_state(self, state):
|
|
87
|
-
temp = CUQIarray(state['current_point'] , geometry=self.target.geometry)
|
|
88
|
-
self.current_point = temp
|
|
89
|
-
temp = CUQIarray(state['current_target'] , geometry=self.target.geometry)
|
|
90
|
-
self.current_target = temp
|
|
91
|
-
self.scale = state['scale']
|
cuqi/experimental/mcmc/_pcn.py
CHANGED
|
@@ -5,13 +5,15 @@ from cuqi.array import CUQIarray
|
|
|
5
5
|
|
|
6
6
|
class pCNNew(SamplerNew): # Refactor to Proposal-based sampler?
|
|
7
7
|
|
|
8
|
+
_STATE_KEYS = SamplerNew._STATE_KEYS.union({'scale', 'current_likelihood_logd'})
|
|
9
|
+
|
|
8
10
|
def __init__(self, target, scale=1.0, **kwargs):
|
|
9
11
|
|
|
10
12
|
super().__init__(target, **kwargs)
|
|
11
13
|
|
|
12
14
|
self.scale = scale
|
|
13
15
|
self.current_point = self.initial_point
|
|
14
|
-
self.current_loglike_eval = self._loglikelihood(self.current_point)
|
|
16
|
+
self.current_likelihood_logd = self._loglikelihood(self.current_point)
|
|
15
17
|
|
|
16
18
|
self._acc = [1] # TODO. Check if we need this
|
|
17
19
|
|
|
@@ -33,7 +35,7 @@ class pCNNew(SamplerNew): # Refactor to Proposal-based sampler?
|
|
|
33
35
|
loglike_eval_star = self._loglikelihood(x_star)
|
|
34
36
|
|
|
35
37
|
# ratio and acceptance probability
|
|
36
|
-
ratio = loglike_eval_star - self.current_loglike_eval # proposal is symmetric
|
|
38
|
+
ratio = loglike_eval_star - self.current_likelihood_logd # proposal is symmetric
|
|
37
39
|
alpha = min(0, ratio)
|
|
38
40
|
|
|
39
41
|
# accept/reject
|
|
@@ -41,7 +43,7 @@ class pCNNew(SamplerNew): # Refactor to Proposal-based sampler?
|
|
|
41
43
|
u_theta = np.log(np.random.rand())
|
|
42
44
|
if (u_theta <= alpha):
|
|
43
45
|
self.current_point = x_star
|
|
44
|
-
self.current_loglike_eval = loglike_eval_star
|
|
46
|
+
self.current_likelihood_logd = loglike_eval_star
|
|
45
47
|
acc = 1
|
|
46
48
|
|
|
47
49
|
return acc
|
|
@@ -87,15 +89,3 @@ class pCNNew(SamplerNew): # Refactor to Proposal-based sampler?
|
|
|
87
89
|
|
|
88
90
|
def tune(self, skip_len, update_count):
|
|
89
91
|
pass
|
|
90
|
-
|
|
91
|
-
def get_state(self):
|
|
92
|
-
return {'sampler_type': 'PCN', 'current_point': self.current_point.to_numpy(), \
|
|
93
|
-
'current_loglike_eval': self.current_loglike_eval.to_numpy(), \
|
|
94
|
-
'scale': self.scale}
|
|
95
|
-
|
|
96
|
-
def set_state(self, state):
|
|
97
|
-
temp = CUQIarray(state['current_point'] , geometry=self.target.geometry)
|
|
98
|
-
self.current_point = temp
|
|
99
|
-
temp = CUQIarray(state['current_loglike_eval'] , geometry=self.target.geometry)
|
|
100
|
-
self.current_loglike_eval = temp
|
|
101
|
-
self.scale = state['scale']
|
cuqi/experimental/mcmc/_rto.py
CHANGED
|
@@ -50,6 +50,7 @@ class LinearRTONew(SamplerNew):
|
|
|
50
50
|
|
|
51
51
|
if initial_point is None: #TODO: Replace later with a getter
|
|
52
52
|
self.initial_point = np.zeros(self.dim)
|
|
53
|
+
self._samples = [self.initial_point]
|
|
53
54
|
|
|
54
55
|
self.current_point = self.initial_point
|
|
55
56
|
self._acc = [1] # TODO. Check if we need this
|
|
@@ -188,12 +189,6 @@ class LinearRTONew(SamplerNew):
|
|
|
188
189
|
if not hasattr(self.prior, "sqrtprecTimesMean"):
|
|
189
190
|
raise TypeError("Prior must contain a sqrtprecTimesMean attribute")
|
|
190
191
|
|
|
191
|
-
def get_state(self): #TODO: LinearRTO only need initial_point for reproducibility?
|
|
192
|
-
return {'sampler_type': 'LinearRTO'}
|
|
193
|
-
|
|
194
|
-
def set_state(self, state): #TODO: LinearRTO only need initial_point for reproducibility?
|
|
195
|
-
pass
|
|
196
|
-
|
|
197
192
|
class RegularizedLinearRTONew(LinearRTONew):
|
|
198
193
|
"""
|
|
199
194
|
Regularized Linear RTO (Randomize-Then-Optimize) sampler.
|
|
@@ -272,4 +267,4 @@ class RegularizedLinearRTONew(LinearRTONew):
|
|
|
272
267
|
maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
|
|
273
268
|
self.current_point, _ = sim.solve()
|
|
274
269
|
acc = 1
|
|
275
|
-
return acc
|
|
270
|
+
return acc
|
|
@@ -21,6 +21,11 @@ class SamplerNew(ABC):
|
|
|
21
21
|
Samples are stored in a list to allow for dynamic growth of the sample set. Returning samples is done by creating a new Samples object from the list of samples.
|
|
22
22
|
|
|
23
23
|
"""
|
|
24
|
+
_STATE_KEYS = {'current_point'}
|
|
25
|
+
""" Set of keys for the state dictionary. """
|
|
26
|
+
|
|
27
|
+
_HISTORY_KEYS = {'_samples', '_acc'}
|
|
28
|
+
""" Set of keys for the history dictionary. """
|
|
24
29
|
|
|
25
30
|
def __init__(self, target: cuqi.density.Density, initial_point=None, callback=None):
|
|
26
31
|
""" Initializer for abstract base class for all samplers.
|
|
@@ -45,8 +50,10 @@ class SamplerNew(ABC):
|
|
|
45
50
|
# Choose initial point if not given
|
|
46
51
|
if initial_point is None:
|
|
47
52
|
initial_point = np.ones(self.dim)
|
|
53
|
+
|
|
54
|
+
self.initial_point = initial_point
|
|
48
55
|
|
|
49
|
-
self._samples = [initial_point]
|
|
56
|
+
self._samples = [initial_point] # Remove. See #324.
|
|
50
57
|
|
|
51
58
|
# ------------ Abstract methods to be implemented by subclasses ------------
|
|
52
59
|
|
|
@@ -65,29 +72,16 @@ class SamplerNew(ABC):
|
|
|
65
72
|
""" Validate the target is compatible with the sampler. Called when the target is set. Should raise an error if the target is not compatible. """
|
|
66
73
|
pass
|
|
67
74
|
|
|
68
|
-
|
|
69
|
-
def
|
|
70
|
-
"""
|
|
75
|
+
# -- _pre_sample and _pre_warmup methods: can be overridden by subclasses --
|
|
76
|
+
def _pre_sample(self):
|
|
77
|
+
""" Any code that needs to be run before sampling. """
|
|
71
78
|
pass
|
|
72
79
|
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
""" Set the state of the sampler. """
|
|
80
|
+
def _pre_warmup(self):
|
|
81
|
+
""" Any code that needs to be run before warmup. """
|
|
76
82
|
pass
|
|
77
83
|
|
|
78
|
-
|
|
79
84
|
# ------------ Public attributes ------------
|
|
80
|
-
|
|
81
|
-
@property
|
|
82
|
-
def initial_point(self):
|
|
83
|
-
""" Return the initial point of the sampler. This is always the first sample. """
|
|
84
|
-
return self._samples[0]
|
|
85
|
-
|
|
86
|
-
@initial_point.setter
|
|
87
|
-
def initial_point(self, value):
|
|
88
|
-
""" Set the initial point of the sampler. """
|
|
89
|
-
self._samples[0] = value
|
|
90
|
-
|
|
91
85
|
@property
|
|
92
86
|
def dim(self):
|
|
93
87
|
""" Dimension of the target density. """
|
|
@@ -109,6 +103,15 @@ class SamplerNew(ABC):
|
|
|
109
103
|
self._target = value
|
|
110
104
|
self.validate_target()
|
|
111
105
|
|
|
106
|
+
@property
|
|
107
|
+
def current_point(self):
|
|
108
|
+
""" The current point of the sampler. """
|
|
109
|
+
return self._current_point
|
|
110
|
+
|
|
111
|
+
@current_point.setter
|
|
112
|
+
def current_point(self, value):
|
|
113
|
+
""" Set the current point of the sampler. """
|
|
114
|
+
self._current_point = value
|
|
112
115
|
|
|
113
116
|
# ------------ Public methods ------------
|
|
114
117
|
|
|
@@ -156,6 +159,9 @@ class SamplerNew(ABC):
|
|
|
156
159
|
if batch_size > 0:
|
|
157
160
|
batch_handler = _BatchHandler(batch_size, sample_path)
|
|
158
161
|
|
|
162
|
+
# Any code that needs to be run before sampling
|
|
163
|
+
self._pre_sample()
|
|
164
|
+
|
|
159
165
|
# Draw samples
|
|
160
166
|
for _ in progressbar( range(Ns) ):
|
|
161
167
|
|
|
@@ -191,6 +197,9 @@ class SamplerNew(ABC):
|
|
|
191
197
|
|
|
192
198
|
tune_interval = max(int(tune_freq * Nb), 1)
|
|
193
199
|
|
|
200
|
+
# Any code that needs to be run before warmup
|
|
201
|
+
self._pre_warmup()
|
|
202
|
+
|
|
194
203
|
# Draw warmup samples with tuning
|
|
195
204
|
for idx in progressbar(range(Nb)):
|
|
196
205
|
|
|
@@ -209,6 +218,94 @@ class SamplerNew(ABC):
|
|
|
209
218
|
self._call_callback(self.current_point, len(self._samples)-1)
|
|
210
219
|
|
|
211
220
|
return self
|
|
221
|
+
|
|
222
|
+
def get_state(self) -> dict:
|
|
223
|
+
""" Return the state of the sampler.
|
|
224
|
+
|
|
225
|
+
The state is used when checkpointing the sampler.
|
|
226
|
+
|
|
227
|
+
The state of the sampler is a dictionary with keys 'metadata' and 'state'.
|
|
228
|
+
The 'metadata' key contains information about the sampler type.
|
|
229
|
+
The 'state' key contains the state of the sampler.
|
|
230
|
+
|
|
231
|
+
For example, the state of a "MH" sampler could be:
|
|
232
|
+
|
|
233
|
+
state = {
|
|
234
|
+
'metadata': {
|
|
235
|
+
'sampler_type': 'MH'
|
|
236
|
+
},
|
|
237
|
+
'state': {
|
|
238
|
+
'current_point': np.array([...]),
|
|
239
|
+
'current_target_logd': -123.45,
|
|
240
|
+
'scale': 1.0,
|
|
241
|
+
...
|
|
242
|
+
}
|
|
243
|
+
}
|
|
244
|
+
"""
|
|
245
|
+
state = {
|
|
246
|
+
'metadata': {
|
|
247
|
+
'sampler_type': self.__class__.__name__
|
|
248
|
+
},
|
|
249
|
+
'state': {
|
|
250
|
+
key: getattr(self, key) for key in self._STATE_KEYS
|
|
251
|
+
}
|
|
252
|
+
}
|
|
253
|
+
return state
|
|
254
|
+
|
|
255
|
+
def set_state(self, state: dict):
|
|
256
|
+
""" Set the state of the sampler.
|
|
257
|
+
|
|
258
|
+
The state is used when loading the sampler from a checkpoint.
|
|
259
|
+
|
|
260
|
+
The state of the sampler is a dictionary with keys 'metadata' and 'state'.
|
|
261
|
+
|
|
262
|
+
For example, the state of a "MH" sampler could be:
|
|
263
|
+
|
|
264
|
+
state = {
|
|
265
|
+
'metadata': {
|
|
266
|
+
'sampler_type': 'MH'
|
|
267
|
+
},
|
|
268
|
+
'state': {
|
|
269
|
+
'current_point': np.array([...]),
|
|
270
|
+
'current_target_logd': -123.45,
|
|
271
|
+
'scale': 1.0,
|
|
272
|
+
...
|
|
273
|
+
}
|
|
274
|
+
}
|
|
275
|
+
"""
|
|
276
|
+
if state['metadata']['sampler_type'] != self.__class__.__name__:
|
|
277
|
+
raise ValueError(f"Sampler type in state dictionary ({state['metadata']['sampler_type']}) does not match the type of the sampler ({self.__class__.__name__}).")
|
|
278
|
+
|
|
279
|
+
for key, value in state['state'].items():
|
|
280
|
+
if key in self._STATE_KEYS:
|
|
281
|
+
setattr(self, key, value)
|
|
282
|
+
else:
|
|
283
|
+
raise ValueError(f"Key {key} not recognized in state dictionary of sampler {self.__class__.__name__}.")
|
|
284
|
+
|
|
285
|
+
def get_history(self) -> dict:
|
|
286
|
+
""" Return the history of the sampler. """
|
|
287
|
+
history = {
|
|
288
|
+
'metadata': {
|
|
289
|
+
'sampler_type': self.__class__.__name__
|
|
290
|
+
},
|
|
291
|
+
'history': {
|
|
292
|
+
key: getattr(self, key) for key in self._HISTORY_KEYS
|
|
293
|
+
}
|
|
294
|
+
}
|
|
295
|
+
return history
|
|
296
|
+
|
|
297
|
+
def set_history(self, history: dict):
|
|
298
|
+
""" Set the history of the sampler. """
|
|
299
|
+
if history['metadata']['sampler_type'] != self.__class__.__name__:
|
|
300
|
+
raise ValueError(f"Sampler type in history dictionary ({history['metadata']['sampler_type']}) does not match the type of the sampler ({self.__class__.__name__}).")
|
|
301
|
+
|
|
302
|
+
for key, value in history['history'].items():
|
|
303
|
+
if key in self._HISTORY_KEYS:
|
|
304
|
+
setattr(self, key, value)
|
|
305
|
+
else:
|
|
306
|
+
raise ValueError(f"Key {key} not recognized in history dictionary of sampler {self.__class__.__name__}.")
|
|
307
|
+
|
|
308
|
+
# ------------ Private methods ------------
|
|
212
309
|
|
|
213
310
|
def _call_callback(self, sample, sample_index):
|
|
214
311
|
""" Calls the callback function. Assumes input is sample and sample index"""
|
|
@@ -218,6 +315,9 @@ class SamplerNew(ABC):
|
|
|
218
315
|
|
|
219
316
|
class ProposalBasedSamplerNew(SamplerNew, ABC):
|
|
220
317
|
""" Abstract base class for samplers that use a proposal distribution. """
|
|
318
|
+
|
|
319
|
+
_STATE_KEYS = SamplerNew._STATE_KEYS.union({'current_target_logd', 'scale'})
|
|
320
|
+
|
|
221
321
|
def __init__(self, target, proposal=None, scale=1, **kwargs):
|
|
222
322
|
""" Initializer for proposal based samplers.
|
|
223
323
|
|
|
@@ -240,7 +340,7 @@ class ProposalBasedSamplerNew(SamplerNew, ABC):
|
|
|
240
340
|
super().__init__(target, **kwargs)
|
|
241
341
|
|
|
242
342
|
self.current_point = self.initial_point
|
|
243
|
-
self.current_target = self.target.logd(self.current_point)
|
|
343
|
+
self.current_target_logd = self.target.logd(self.current_point)
|
|
244
344
|
self.proposal = proposal
|
|
245
345
|
self.scale = scale
|
|
246
346
|
|
cuqi/sampler/_hmc.py
CHANGED
|
@@ -83,6 +83,9 @@ class NUTS(Sampler):
|
|
|
83
83
|
self.max_depth = max_depth
|
|
84
84
|
self.adapt_step_size = adapt_step_size
|
|
85
85
|
self.opt_acc_rate = opt_acc_rate
|
|
86
|
+
# if this flag is True, the samples and the burn-in will be returned
|
|
87
|
+
# otherwise, the burn-in will be truncated
|
|
88
|
+
self._return_burnin = False
|
|
86
89
|
|
|
87
90
|
# NUTS run diagnostic
|
|
88
91
|
# number of tree nodes created each NUTS iteration
|
|
@@ -226,9 +229,10 @@ class NUTS(Sampler):
|
|
|
226
229
|
if np.isnan(joint_eval[k]):
|
|
227
230
|
raise NameError('NaN potential func')
|
|
228
231
|
|
|
229
|
-
# apply burn-in
|
|
230
|
-
|
|
231
|
-
|
|
232
|
+
# apply burn-in
|
|
233
|
+
if not self._return_burnin:
|
|
234
|
+
theta = theta[:, Nb:]
|
|
235
|
+
joint_eval = joint_eval[Nb:]
|
|
232
236
|
return theta, joint_eval, step_sizes
|
|
233
237
|
|
|
234
238
|
#=========================================================================
|
|
File without changes
|
|
File without changes
|
|
File without changes
|