CUQIpy 1.0.0.post0.dev145__py3-none-any.whl → 1.0.0.post0.dev180__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of CUQIpy might be problematic; see the package registry page for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: CUQIpy
3
- Version: 1.0.0.post0.dev145
3
+ Version: 1.0.0.post0.dev180
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -1,6 +1,6 @@
1
1
  cuqi/__init__.py,sha256=LsGilhl-hBLEn6Glt8S_l0OJzAA1sKit_rui8h-D-p0,488
2
2
  cuqi/_messages.py,sha256=fzEBrZT2kbmfecBBPm7spVu7yHdxGARQB4QzXhJbCJ0,415
3
- cuqi/_version.py,sha256=Qc2vgfclx0EbaHAQZJpzQEPFGq1LO7Yj1wMffmN8Llo,510
3
+ cuqi/_version.py,sha256=nZ9Rz10NtUnQZt09kW9YMygjxV0G-qsipgAL3xBFeAU,510
4
4
  cuqi/config.py,sha256=wcYvz19wkeKW2EKCGIKJiTpWt5kdaxyt4imyRkvtTRA,526
5
5
  cuqi/diagnostics.py,sha256=5OrbJeqpynqRXOe5MtOKKhe7EAVdOEpHIqHnlMW9G_c,3029
6
6
  cuqi/array/__init__.py,sha256=-EeiaiWGNsE3twRS4dD814BIlfxEsNkTCZUc5gjOXb0,30
@@ -32,13 +32,14 @@ cuqi/distribution/_normal.py,sha256=UeoTtGDT7YSf4ZNo2amlVF9K-YQpYbf8q76jcRJTVFw,
32
32
  cuqi/distribution/_posterior.py,sha256=zAfL0GECxekZ2lBt1W6_LN0U_xskMwK4VNce5xAF7ig,5018
33
33
  cuqi/distribution/_uniform.py,sha256=7xJmCZH_LPhuGkwEDGh-_CTtzcWKrXMOxtTJUFb7Ydo,1607
34
34
  cuqi/experimental/__init__.py,sha256=vhZvyMX6rl8Y0haqCzGLPz6PSUKyu75XMQbeDHqTTrw,83
35
- cuqi/experimental/mcmc/__init__.py,sha256=vVnohcm4EIUwbp1sr3LbB0BkXO8jyZsbiKMJmIgetYY,314
35
+ cuqi/experimental/mcmc/__init__.py,sha256=S4aXYpnO75HQcwDYfr1-ki8UvlenPDXxshES5avtBF0,340
36
36
  cuqi/experimental/mcmc/_cwmh.py,sha256=yRlTk5a1QYfH3JyCecfOOTeDf-4-tmJ3Tl2Bc3pyp1Y,7336
37
+ cuqi/experimental/mcmc/_hmc.py,sha256=qqAyoAajLE_JenYMgAbD3tknuEf75AJu-ufF69GKGk4,19384
37
38
  cuqi/experimental/mcmc/_langevin_algorithm.py,sha256=MX48u3GYgCckB6Q5h5kXr_qdIaLQH2toOG5u29OY7gk,8245
38
39
  cuqi/experimental/mcmc/_mh.py,sha256=aIV1Ntq0EAq3QJ1_X-DbP7eDAL-d_Or7d3RUO-R48I4,3090
39
40
  cuqi/experimental/mcmc/_pcn.py,sha256=3M8zhQGQa53Gz04AkC8wJM61_5rIjGVnhPefi8m4dbY,3531
40
41
  cuqi/experimental/mcmc/_rto.py,sha256=jSPznr34XPfWM6LysWIiN4hE-vtyti3cHyvzy9ruykg,11349
41
- cuqi/experimental/mcmc/_sampler.py,sha256=hbZwUHcEZFSSVd2tICcp9FdcK9UKB_-izdM7w4xwijs,14408
42
+ cuqi/experimental/mcmc/_sampler.py,sha256=_5Uo2F-Mta46w3lo7WBVNwvTLYhES_BzMTJrKxA00c8,14861
42
43
  cuqi/geometry/__init__.py,sha256=Tz1WGzZBY-QGH3c0GiyKm9XHN8MGGcnU6TUHLZkzB3o,842
43
44
  cuqi/geometry/_geometry.py,sha256=WYFC-4_VBTW73b2ldsnfGYKvdSiCE8plr89xTSmkadg,46804
44
45
  cuqi/implicitprior/__init__.py,sha256=ZRZ9fgxgEl5n0A9F7WCl1_jid-GUiC8ZLkyTmGQmFlY,100
@@ -59,7 +60,7 @@ cuqi/sampler/_conjugate.py,sha256=ztmUR3V3qZk9zelKx48ULnmMs_zKTDUfohc256VOIe8,27
59
60
  cuqi/sampler/_conjugate_approx.py,sha256=xX-X71EgxGnZooOY6CIBhuJTs3dhcKfoLnoFxX3CO2g,1938
60
61
  cuqi/sampler/_cwmh.py,sha256=VlAVT1SXQU0yD5ZeR-_ckWvX-ifJrMweFFdFbxdfB_k,7775
61
62
  cuqi/sampler/_gibbs.py,sha256=N7qcePwMkRtxINN5JF0FaMIdDCXZGqsfKjfha_KHCck,8627
62
- cuqi/sampler/_hmc.py,sha256=76nPkvNU0wLSg4qvm-1s048MzQasl5Qk94sHpyeJ5hM,14819
63
+ cuqi/sampler/_hmc.py,sha256=EUTefZir-wapoZ7OZFb5M5vayL8z6XksZRMY1BpbuXc,15027
63
64
  cuqi/sampler/_langevin_algorithm.py,sha256=o5EyvaR6QGAD7LKwXVRC3WwAP5IYJf5GoMVWl9DrfOA,7861
64
65
  cuqi/sampler/_laplace_approximation.py,sha256=u018Z5eqlcq_cIwD9yNOaA15dLQE_vUWaee5Xp8bcjg,6454
65
66
  cuqi/sampler/_mh.py,sha256=V5tIdn-KdfWo4J_Nbf-AH6XwKWblWUyc4BeuSikUHsE,7062
@@ -75,8 +76,8 @@ cuqi/testproblem/_testproblem.py,sha256=x769LwwRdJdzIiZkcQUGb_5-vynNTNALXWKato7s
75
76
  cuqi/utilities/__init__.py,sha256=EfxHLdsyDNugbmbzs43nV_AeKcycM9sVBjG9WZydagA,351
76
77
  cuqi/utilities/_get_python_variable_name.py,sha256=QwlBVj2koJRA8s8pWd554p7-ElcI7HUwY32HknaR92E,1827
77
78
  cuqi/utilities/_utilities.py,sha256=At3DOXRdF3GwLkVcM2FXooGyjAGfPkIM0bRzhTfLmWk,8046
78
- CUQIpy-1.0.0.post0.dev145.dist-info/LICENSE,sha256=kJWRPrtRoQoZGXyyvu50Uc91X6_0XRaVfT0YZssicys,10799
79
- CUQIpy-1.0.0.post0.dev145.dist-info/METADATA,sha256=REboBBeNZ3O4SrH9hNrNAYgvQPU9XP49ezl4bHyCigQ,18393
80
- CUQIpy-1.0.0.post0.dev145.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
81
- CUQIpy-1.0.0.post0.dev145.dist-info/top_level.txt,sha256=AgmgMc6TKfPPqbjV0kvAoCBN334i_Lwwojc7HE3ZwD0,5
82
- CUQIpy-1.0.0.post0.dev145.dist-info/RECORD,,
79
+ CUQIpy-1.0.0.post0.dev180.dist-info/LICENSE,sha256=kJWRPrtRoQoZGXyyvu50Uc91X6_0XRaVfT0YZssicys,10799
80
+ CUQIpy-1.0.0.post0.dev180.dist-info/METADATA,sha256=YSVheKM46Y6Al3zXEu11wlCkJxAasZw2onXB0nTLN08,18393
81
+ CUQIpy-1.0.0.post0.dev180.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
+ CUQIpy-1.0.0.post0.dev180.dist-info/top_level.txt,sha256=AgmgMc6TKfPPqbjV0kvAoCBN334i_Lwwojc7HE3ZwD0,5
83
+ CUQIpy-1.0.0.post0.dev180.dist-info/RECORD,,
cuqi/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2024-04-15T12:29:37+0200",
11
+ "date": "2024-04-22T18:01:04+0200",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "744dccfaf90c7e50a11b7c5a4a59d5438c37525d",
15
- "version": "1.0.0.post0.dev145"
14
+ "full-revisionid": "be4b485322b1be52f78dfe6d03694d6ba2000b11",
15
+ "version": "1.0.0.post0.dev180"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -6,3 +6,4 @@ from ._mh import MHNew
6
6
  from ._pcn import pCNNew
7
7
  from ._rto import LinearRTONew, RegularizedLinearRTONew
8
8
  from ._cwmh import CWMHNew
9
+ from ._hmc import NUTSNew
@@ -0,0 +1,470 @@
1
+ import numpy as np
2
+ import numpy as np
3
+ from cuqi.experimental.mcmc import SamplerNew
4
+ from cuqi.array import CUQIarray
5
+ from numbers import Number
6
+
7
class NUTSNew(SamplerNew):
    """No-U-Turn Sampler (Hoffman and Gelman, 2014).

    Samples a distribution given its logpdf and gradient using a Hamiltonian
    Monte Carlo (HMC) algorithm with automatic parameter tuning.

    For more details see: Hoffman, M. D., & Gelman, A. (2014). The no-U-turn
    sampler: Adaptively setting path lengths in Hamiltonian Monte Carlo. Journal
    of Machine Learning Research, 15, 1593-1623.

    Parameters
    ----------
    target : `cuqi.distribution.Distribution`
        The target distribution to sample. Must have logd and gradient methods.
        Custom logpdfs and gradients are supported by using a
        :class:`cuqi.distribution.UserDefinedDistribution`.

    initial_point : ndarray
        Initial parameters. *Optional*. If not provided, the initial point is
        an array of ones.

    max_depth : int
        Maximum depth of the tree (must be >= 0). The default is 15.

    step_size : None or float
        If step_size is provided (as positive float), it will be used as the
        initial step size, both for sampling without warmup and as the
        starting value of the step size adaptation during warmup. If None,
        the initial step size will be estimated by the sampler.

    opt_acc_rate : float
        The optimal acceptance rate to reach if using adaptive step size.
        Suggested values are 0.6 (default) or 0.8 (as in stan). In principle,
        opt_acc_rate should be in (0, 1), however, choosing a value that is very
        close to 1 or 0 might lead to poor performance of the sampler.

    callback : callable, *Optional*
        If set this function will be called after every sample.
        The signature of the callback function is
        `callback(sample, sample_index)`,
        where `sample` is the current sample and `sample_index` is the index of
        the sample.
        An example is shown in demos/demo31_callback.py.

    Example
    -------
    .. code-block:: python

        # Import cuqi
        import cuqi

        # Define a target distribution
        tp = cuqi.testproblem.WangCubic()
        target = tp.posterior

        # Set up sampler
        sampler = cuqi.experimental.mcmc.NUTSNew(target)

        # Sample
        sampler.warmup(5000)
        sampler.sample(10000)

        # Get samples
        samples = sampler.get_samples()

        # Plot samples
        samples.plot_pair()

    After running the NUTS sampler, run diagnostics can be accessed via the
    following attributes:

    .. code-block:: python

        # Number of tree nodes created each NUTS iteration
        sampler.num_tree_node_list

        # Step size used in each NUTS iteration
        sampler.epsilon_list

        # Suggested step size during adaptation (the value of this step size is
        # only used after adaptation).
        sampler.epsilon_bar_list

    """

    _STATE_KEYS = SamplerNew._STATE_KEYS.union({'_epsilon', '_epsilon_bar',
                                                 '_H_bar', '_mu',
                                                 '_alpha', '_n_alpha'})

    _HISTORY_KEYS = SamplerNew._HISTORY_KEYS.union({'num_tree_node_list',
                                                     'epsilon_list',
                                                     'epsilon_bar_list'})

    def __init__(self, target, initial_point=None, max_depth=15,
                 step_size=None, opt_acc_rate=0.6, **kwargs):
        super().__init__(target, initial_point=initial_point, **kwargs)

        # Assign parameters as attributes (validated by the property setters)
        self.max_depth = max_depth
        self.step_size = step_size
        self.opt_acc_rate = opt_acc_rate

        # Set current point
        self.current_point = self.initial_point

        # Initialize epsilon and epsilon_bar
        # epsilon is the step size used in the current iteration
        # after warm up and one sampling step, epsilon is updated
        # to epsilon_bar for the remaining sampling steps.
        self._epsilon = None
        self._epsilon_bar = None
        self._H_bar = None

        # Remaining keys listed in _STATE_KEYS. They are initialised here so
        # that every state key exists on a freshly constructed sampler; they
        # are filled in by _pre_warmup (_mu) and step (_alpha, _n_alpha).
        self._mu = None
        self._alpha = None
        self._n_alpha = None

        # Arrays to store acceptance rate
        self._acc = [None]

        # NUTS run diagnostic:
        # number of tree nodes created each NUTS iteration
        self._num_tree_node = 0
        # Create lists to store NUTS run diagnostics
        self._create_run_diagnostic_attributes()

    #=========================================================================
    #============================== Properties ===============================
    #=========================================================================
    @property
    def max_depth(self):
        """Maximum depth of the NUTS tree (non-negative integer)."""
        return self._max_depth

    @max_depth.setter
    def max_depth(self, value):
        if not isinstance(value, int):
            raise TypeError('max_depth must be an integer.')
        if value < 0:
            raise ValueError('max_depth must be >= 0.')
        self._max_depth = value

    @property
    def step_size(self):
        """User-provided initial step size, or None to estimate it."""
        return self._step_size

    @step_size.setter
    def step_size(self, value):
        if value is None:
            pass  # NUTS will adapt the step size

        # step_size must be a positive float, raise error otherwise
        elif isinstance(value, bool)\
                or not isinstance(value, Number)\
                or value <= 0:
            raise TypeError('step_size must be a positive float or None.')
        self._step_size = value

    @property
    def opt_acc_rate(self):
        """Target acceptance rate for dual-averaging step size adaptation."""
        return self._opt_acc_rate

    @opt_acc_rate.setter
    def opt_acc_rate(self, value):
        if not isinstance(value, Number) or value <= 0 or value >= 1:
            raise ValueError('opt_acc_rate must be a float in (0, 1).')
        self._opt_acc_rate = value

    #=========================================================================
    #================== Implement methods required by SamplerNew =============
    #=========================================================================
    def validate_target(self):
        """Check that the target provides both logd and gradient.

        Raises
        ------
        ValueError
            If evaluating ``target.logd`` or ``target.gradient`` at a vector
            of ones fails for any reason (the original error is chained).
        """
        try:
            self._nuts_target(np.ones(self.dim))
        except Exception as err:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # are not converted into a ValueError; keep the cause for debugging.
            raise ValueError(
                'Target must have logd and gradient methods.') from err

    def reset(self):
        """Reset the sampler (via the parent) and the NUTS run diagnostics."""
        # Call the parent reset method
        super().reset()
        # Reset NUTS run diagnostic attributes
        self._reset_run_diagnostic_attributes()

    def step(self):
        """Perform one NUTS iteration starting from ``self.current_point``.

        Builds a trajectory by repeated tree doubling in random directions
        until a U-turn is detected, the slice criterion fails or
        ``max_depth`` is exceeded, and possibly moves ``current_point`` to a
        point of the trajectory. Stores ``_alpha``/``_n_alpha`` for ``tune``
        and appends the run diagnostics.

        Raises
        ------
        NameError
            If the log-density at the new current point is NaN.
        """
        # Convert current_point, logd, and grad to numpy arrays
        # if they are CUQIarray objects
        if isinstance(self.current_point, CUQIarray):
            self.current_point = self.current_point.to_numpy()
        if isinstance(self.current_target_logd, CUQIarray):
            self.current_target_logd = self.current_target_logd.to_numpy()
        if isinstance(self.current_target_grad, CUQIarray):
            self.current_target_grad = self.current_target_grad.to_numpy()

        # reset number of tree nodes for each iteration
        self._num_tree_node = 0

        # copy current point, logd, and grad in local variables
        point_k = self.current_point.copy()  # initial position (parameters)
        logd_k = self.current_target_logd
        grad_k = self.current_target_grad.copy()  # initial gradient

        # compute r_k and Hamiltonian
        r_k = self._Kfun(1, 'sample')  # resample momentum vector
        Ham = logd_k - self._Kfun(r_k, 'eval')  # Hamiltonian

        # slice variable. Drawn as a scalar (not size=1) so that the
        # int(log_u <= ...) conversions in _BuildTree operate on scalars
        # rather than on 1-element arrays (deprecated in NumPy >= 1.25).
        log_u = Ham - np.random.exponential(1)

        # initialization
        j, s, n = 0, 1, 1
        point_minus, point_plus = np.copy(point_k), np.copy(point_k)
        grad_minus, grad_plus = np.copy(grad_k), np.copy(grad_k)
        r_minus, r_plus = np.copy(r_k), np.copy(r_k)

        # run NUTS
        while (s == 1) and (j <= self.max_depth):
            # sample a direction
            v = int(2*(np.random.rand() < 0.5)-1)

            # build tree: doubling procedure
            if (v == -1):
                point_minus, r_minus, grad_minus, _, _, _, \
                    point_prime, logd_prime, grad_prime,\
                    n_prime, s_prime, alpha, n_alpha = \
                    self._BuildTree(point_minus, r_minus, grad_minus,
                                    Ham, log_u, v, j, self._epsilon)
            else:
                _, _, _, point_plus, r_plus, grad_plus, \
                    point_prime, logd_prime, grad_prime,\
                    n_prime, s_prime, alpha, n_alpha = \
                    self._BuildTree(point_plus, r_plus, grad_plus,
                                    Ham, log_u, v, j, self._epsilon)

            # Metropolis step
            alpha2 = min(1, (n_prime/n))  # min(0, np.log(n_p) - np.log(n))
            if (s_prime == 1) and (np.random.rand() <= alpha2):
                self.current_point = point_prime
                self.current_target_logd = logd_prime
                self.current_target_grad = np.copy(grad_prime)
                self._acc.append(1)
            else:
                self._acc.append(0)

            # update number of particles, tree level, and stopping criterion
            n += n_prime
            dpoints = point_plus - point_minus
            s = s_prime *\
                int((dpoints @ r_minus.T) >= 0) * int((dpoints @ r_plus.T) >= 0)
            j += 1
            # acceptance statistics of the last subtree, used by tune()
            self._alpha = alpha
            self._n_alpha = n_alpha

        # update run diagnostic attributes
        self._update_run_diagnostic_attributes(
            self._num_tree_node, self._epsilon, self._epsilon_bar)

        # Outside of tuning, the step size used next is the averaged one.
        self._epsilon = self._epsilon_bar
        if np.isnan(self.current_target_logd):
            raise NameError('NaN potential func')

    def tune(self, skip_len, update_count):
        """Adapt epsilon during burn-in using dual averaging.

        Parameters
        ----------
        skip_len : int
            Unused; part of the SamplerNew ``tune`` signature.
        update_count : int
            Number of previous tuning updates; ``update_count + 1`` is the
            iteration index ``k`` of the dual-averaging scheme.
        """
        k = update_count+1

        # Fixed parameters that do not change during the run
        gamma, t_0, kappa = 0.05, 10, 0.75  # kappa in (0.5, 1]

        eta1 = 1/(k + t_0)
        self._H_bar = (1-eta1)*self._H_bar +\
            eta1*(self.opt_acc_rate - (self._alpha/self._n_alpha))
        self._epsilon = np.exp(self._mu - (np.sqrt(k)/gamma)*self._H_bar)
        eta = k**(-kappa)
        self._epsilon_bar =\
            np.exp(eta*np.log(self._epsilon) + (1-eta)*np.log(self._epsilon_bar))

    def _pre_warmup(self):
        """Prepare the dual-averaging state before the warmup loop.

        Evaluates the target at the current point and initialises, only if
        not already set: ``_epsilon`` (the user ``step_size`` if given,
        otherwise an estimate), ``_mu`` (``log(10*_epsilon)``),
        ``_epsilon_bar`` (1) and ``_H_bar`` (0).
        """
        super()._pre_warmup()

        self.current_target_logd, self.current_target_grad =\
            self._nuts_target(self.current_point)

        # Set up tuning parameters (only first time tuning is called)
        # Note:
        #   Parameters changes during the tune run
        #     self._epsilon_bar
        #     self._H_bar
        #     self._epsilon
        #   Parameters that does not change during the run
        #     self._mu

        if self._epsilon is None:
            # Honour a user-provided step size as the starting point of the
            # adaptation; estimate one only when step_size is None.
            self._epsilon = self._initial_step_size()

        # Checked separately from _epsilon: if sample() ran before warmup(),
        # _epsilon is already set but _mu has never been computed.
        if self._mu is None:
            # Parameter mu, does not change during the run
            self._mu = np.log(10*self._epsilon)

        if self._epsilon_bar is None:  # Initial value of epsilon_bar
            self._epsilon_bar = 1

        if self._H_bar is None:  # Initial value of H_bar
            self._H_bar = 0

    def _pre_sample(self):
        """Prepare the sampler before the sampling loop.

        Evaluates the target at the current point and, if no step size has
        been set yet (no prior warmup or sampling), sets both ``_epsilon``
        and ``_epsilon_bar`` to the user ``step_size`` or an estimate.
        """
        super()._pre_sample()

        self.current_target_logd, self.current_target_grad =\
            self._nuts_target(self.current_point)

        # Set up epsilon and epsilon_bar if not set
        if self._epsilon is None:
            step_size = self._initial_step_size()
            self._epsilon = step_size
            self._epsilon_bar = step_size

    def _initial_step_size(self):
        """Return the user-provided step size, or estimate one if it is None."""
        if self.step_size is None:
            return self._FindGoodEpsilon()
        return self.step_size

    #=========================================================================
    def _nuts_target(self, x):
        """Return the tuple ``(target.logd(x), target.gradient(x))``."""
        return self.target.logd(x), self.target.gradient(x)

    #=========================================================================
    # auxiliary standard Gaussian PDF: kinetic energy function
    # d_log_2pi = d*np.log(2*np.pi)
    def _Kfun(self, r, flag):
        """Kinetic energy helper.

        ``flag == 'eval'`` returns ``0.5 * r.T @ r``; ``flag == 'sample'``
        returns a standard-normal momentum vector of length ``self.dim``
        (``r`` is ignored). Any other flag returns None.
        """
        if flag == 'eval':  # evaluate
            return 0.5*(r.T @ r)  # + d_log_2pi
        if flag == 'sample':  # sample
            return np.random.standard_normal(size=self.dim)

    #=========================================================================
    def _FindGoodEpsilon(self, epsilon=1):
        """Heuristically find an initial step size (Hoffman & Gelman, Alg. 4).

        Also refreshes ``current_target_logd``/``current_target_grad`` at
        ``current_point``.
        """
        point_k = self.current_point
        self.current_target_logd, self.current_target_grad = self._nuts_target(
            point_k)
        logd = self.current_target_logd
        grad = self.current_target_grad

        r = self._Kfun(1, 'sample')  # resample a momentum
        Ham = logd - self._Kfun(r, 'eval')  # initial Hamiltonian
        _, r_prime, logd_prime, grad_prime = self._Leapfrog(
            point_k, r, grad, epsilon)

        # trick to make sure the step is not huge, leading to infinite values of
        # the likelihood
        k = 1
        while np.isinf(logd_prime) or np.isinf(grad_prime).any():
            k *= 0.5
            _, r_prime, logd_prime, grad_prime = self._Leapfrog(
                point_k, r, grad, epsilon*k)
        epsilon = 0.5*k*epsilon

        # doubles/halves the value of epsilon until the accprob of the Langevin
        # proposal crosses 0.5
        Ham_prime = logd_prime - self._Kfun(r_prime, 'eval')
        log_ratio = Ham_prime - Ham
        a = 1 if log_ratio > np.log(0.5) else -1
        while (a*log_ratio > -a*np.log(2)):
            epsilon = (2**a)*epsilon
            _, r_prime, logd_prime, _ = self._Leapfrog(
                point_k, r, grad, epsilon)
            Ham_prime = logd_prime - self._Kfun(r_prime, 'eval')
            log_ratio = Ham_prime - Ham
        return epsilon

    #=========================================================================
    def _Leapfrog(self, point_old, r_old, grad_old, epsilon):
        """One leapfrog step; returns ``(point, momentum, logd, grad)``."""
        # symplectic integrator: trajectories preserve phase space volume
        r_new = r_old + 0.5*epsilon*grad_old  # half-step
        point_new = point_old + epsilon*r_new  # full-step
        logd_new, grad_new = self._nuts_target(point_new)  # new gradient
        r_new += 0.5*epsilon*grad_new  # half-step
        return point_new, r_new, logd_new, grad_new

    #=========================================================================
    def _BuildTree(
            self, point_k, r, grad, Ham, log_u, v, j, epsilon, Delta_max=1000):
        """Recursively build a NUTS subtree of depth ``j`` in direction ``v``.

        Returns the leftmost/rightmost points, momenta and gradients, the
        proposed point with its logd and gradient, the number of valid
        points ``n_prime``, the continuation flag ``s_prime`` and the
        acceptance statistics ``alpha_prime``/``n_alpha_prime``.
        """
        # Increment the number of tree nodes counter
        self._num_tree_node += 1

        if (j == 0):  # base case
            # single leapfrog step in the direction v
            point_prime, r_prime, logd_prime, grad_prime = self._Leapfrog(
                point_k, r, grad, v*epsilon)
            Ham_prime = logd_prime - self._Kfun(r_prime, 'eval')  # Hamiltonian
            # eval
            n_prime = int(log_u <= Ham_prime)  # if particle is in the slice
            s_prime = int(log_u < Delta_max + Ham_prime)  # check U-turn
            #
            diff_Ham = Ham_prime - Ham

            # Compute the acceptance probability
            # alpha_prime = min(1, np.exp(diff_Ham))
            # written in a stable way to avoid overflow when computing
            # exp(diff_Ham) for large values of diff_Ham
            alpha_prime = 1 if diff_Ham > 0 else np.exp(diff_Ham)
            n_alpha_prime = 1
            #
            point_minus, point_plus = point_prime, point_prime
            r_minus, r_plus = r_prime, r_prime
            grad_minus, grad_plus = grad_prime, grad_prime
        else:
            # recursion: build the left/right subtrees
            point_minus, r_minus, grad_minus, point_plus, r_plus, grad_plus, \
                point_prime, logd_prime, grad_prime,\
                n_prime, s_prime, alpha_prime, n_alpha_prime = \
                self._BuildTree(point_k, r, grad,
                                Ham, log_u, v, j-1, epsilon)
            if (s_prime == 1):  # do only if the stopping criteria does not
                                # verify at the first subtree
                if (v == -1):
                    point_minus, r_minus, grad_minus, _, _, _, \
                        point_2prime, logd_2prime, grad_2prime,\
                        n_2prime, s_2prime, alpha_2prime, n_alpha_2prime = \
                        self._BuildTree(point_minus, r_minus, grad_minus,
                                        Ham, log_u, v, j-1, epsilon)
                else:
                    _, _, _, point_plus, r_plus, grad_plus, \
                        point_2prime, logd_2prime, grad_2prime,\
                        n_2prime, s_2prime, alpha_2prime, n_alpha_2prime = \
                        self._BuildTree(point_plus, r_plus, grad_plus,
                                        Ham, log_u, v, j-1, epsilon)

                # Metropolis step
                alpha2 = n_2prime / max(1, (n_prime + n_2prime))
                if (np.random.rand() <= alpha2):
                    point_prime = np.copy(point_2prime)
                    logd_prime = np.copy(logd_2prime)
                    grad_prime = np.copy(grad_2prime)

                # update number of particles and stopping criterion
                alpha_prime += alpha_2prime
                n_alpha_prime += n_alpha_2prime
                dpoints = point_plus - point_minus
                s_prime = s_2prime *\
                    int((dpoints@r_minus.T) >= 0) * int((dpoints@r_plus.T) >= 0)
                n_prime += n_2prime

        return point_minus, r_minus, grad_minus, point_plus, r_plus, grad_plus,\
            point_prime, logd_prime, grad_prime,\
            n_prime, s_prime, alpha_prime, n_alpha_prime

    #=========================================================================
    #======================== Diagnostic methods =============================
    #=========================================================================

    def _create_run_diagnostic_attributes(self):
        """A method to create attributes to store NUTS run diagnostic."""
        self._reset_run_diagnostic_attributes()

    def _reset_run_diagnostic_attributes(self):
        """A method to reset attributes to store NUTS run diagnostic."""
        # List to store number of tree nodes created each NUTS iteration
        self.num_tree_node_list = []
        # List of step size used in each NUTS iteration
        self.epsilon_list = []
        # List of burn-in step size suggestion during adaptation
        # only used when adaptation is done
        # remains fixed after adaptation (after burn-in)
        self.epsilon_bar_list = []

    def _update_run_diagnostic_attributes(self, n_tree, eps, eps_bar):
        """A method to update attributes to store NUTS run diagnostic."""
        # Store the number of tree nodes created in iteration k
        self.num_tree_node_list.append(n_tree)
        # Store the step size used in iteration k
        self.epsilon_list.append(eps)
        # Store the step size suggestion during adaptation in iteration k
        self.epsilon_bar_list.append(eps_bar)
@@ -72,6 +72,15 @@ class SamplerNew(ABC):
72
72
  """ Validate the target is compatible with the sampler. Called when the target is set. Should raise an error if the target is not compatible. """
73
73
  pass
74
74
 
75
+ # -- _pre_sample and _pre_warmup methods: can be overridden by subclasses --
76
+ def _pre_sample(self):
77
+ """ Hook called once at the start of sample(), before the sampling loop; no-op by default, subclasses may override. """
78
+ pass
79
+
80
+ def _pre_warmup(self):
81
+ """ Hook called once at the start of warmup(), before the warmup loop; no-op by default, subclasses may override. """
82
+ pass
83
+
75
84
  # ------------ Public attributes ------------
76
85
  @property
77
86
  def dim(self):
@@ -150,6 +159,9 @@ class SamplerNew(ABC):
150
159
  if batch_size > 0:
151
160
  batch_handler = _BatchHandler(batch_size, sample_path)
152
161
 
162
+ # Any code that needs to be run before sampling
163
+ self._pre_sample()
164
+
153
165
  # Draw samples
154
166
  for _ in progressbar( range(Ns) ):
155
167
 
@@ -185,6 +197,9 @@ class SamplerNew(ABC):
185
197
 
186
198
  tune_interval = max(int(tune_freq * Nb), 1)
187
199
 
200
+ # Any code that needs to be run before warmup
201
+ self._pre_warmup()
202
+
188
203
  # Draw warmup samples with tuning
189
204
  for idx in progressbar(range(Nb)):
190
205
 
cuqi/sampler/_hmc.py CHANGED
@@ -83,6 +83,9 @@ class NUTS(Sampler):
83
83
  self.max_depth = max_depth
84
84
  self.adapt_step_size = adapt_step_size
85
85
  self.opt_acc_rate = opt_acc_rate
86
+ # if this flag is True, the samples and the burn-in will be returned
87
+ # otherwise, the burn-in will be truncated
88
+ self._return_burnin = False
86
89
 
87
90
  # NUTS run diagnostic
88
91
  # number of tree nodes created each NUTS iteration
@@ -226,9 +229,10 @@ class NUTS(Sampler):
226
229
  if np.isnan(joint_eval[k]):
227
230
  raise NameError('NaN potential func')
228
231
 
229
- # apply burn-in
230
- theta = theta[:, Nb:]
231
- joint_eval = joint_eval[Nb:]
232
+ # apply burn-in
233
+ if not self._return_burnin:
234
+ theta = theta[:, Nb:]
235
+ joint_eval = joint_eval[Nb:]
232
236
  return theta, joint_eval, step_sizes
233
237
 
234
238
  #=========================================================================