CUQIpy 1.3.0.post0.dev401-py3-none-any.whl → 1.4.0.post0.dev41-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of CUQIpy might be problematic.

Files changed (50)
  1. cuqi/__init__.py +1 -0
  2. cuqi/_version.py +3 -3
  3. cuqi/density/_density.py +9 -1
  4. cuqi/distribution/_joint_distribution.py +96 -11
  5. cuqi/experimental/__init__.py +1 -2
  6. cuqi/experimental/_recommender.py +4 -4
  7. cuqi/legacy/__init__.py +2 -0
  8. cuqi/legacy/sampler/__init__.py +11 -0
  9. cuqi/legacy/sampler/_conjugate.py +55 -0
  10. cuqi/legacy/sampler/_conjugate_approx.py +52 -0
  11. cuqi/legacy/sampler/_cwmh.py +196 -0
  12. cuqi/legacy/sampler/_gibbs.py +231 -0
  13. cuqi/legacy/sampler/_hmc.py +335 -0
  14. cuqi/legacy/sampler/_langevin_algorithm.py +198 -0
  15. cuqi/legacy/sampler/_laplace_approximation.py +184 -0
  16. cuqi/legacy/sampler/_mh.py +190 -0
  17. cuqi/legacy/sampler/_pcn.py +244 -0
  18. cuqi/legacy/sampler/_rto.py +284 -0
  19. cuqi/legacy/sampler/_sampler.py +182 -0
  20. cuqi/problem/_problem.py +87 -80
  21. cuqi/sampler/__init__.py +120 -8
  22. cuqi/sampler/_conjugate.py +376 -35
  23. cuqi/sampler/_conjugate_approx.py +40 -16
  24. cuqi/sampler/_cwmh.py +132 -138
  25. cuqi/{experimental/mcmc → sampler}/_direct.py +1 -1
  26. cuqi/sampler/_gibbs.py +269 -130
  27. cuqi/sampler/_hmc.py +328 -201
  28. cuqi/sampler/_langevin_algorithm.py +282 -98
  29. cuqi/sampler/_laplace_approximation.py +87 -117
  30. cuqi/sampler/_mh.py +47 -157
  31. cuqi/sampler/_pcn.py +56 -211
  32. cuqi/sampler/_rto.py +206 -140
  33. cuqi/sampler/_sampler.py +540 -135
  34. {cuqipy-1.3.0.post0.dev401.dist-info → cuqipy-1.4.0.post0.dev41.dist-info}/METADATA +1 -1
  35. {cuqipy-1.3.0.post0.dev401.dist-info → cuqipy-1.4.0.post0.dev41.dist-info}/RECORD +38 -37
  36. cuqi/experimental/mcmc/__init__.py +0 -122
  37. cuqi/experimental/mcmc/_conjugate.py +0 -396
  38. cuqi/experimental/mcmc/_conjugate_approx.py +0 -76
  39. cuqi/experimental/mcmc/_cwmh.py +0 -190
  40. cuqi/experimental/mcmc/_gibbs.py +0 -366
  41. cuqi/experimental/mcmc/_hmc.py +0 -462
  42. cuqi/experimental/mcmc/_langevin_algorithm.py +0 -382
  43. cuqi/experimental/mcmc/_laplace_approximation.py +0 -154
  44. cuqi/experimental/mcmc/_mh.py +0 -80
  45. cuqi/experimental/mcmc/_pcn.py +0 -89
  46. cuqi/experimental/mcmc/_rto.py +0 -350
  47. cuqi/experimental/mcmc/_sampler.py +0 -582
  48. {cuqipy-1.3.0.post0.dev401.dist-info → cuqipy-1.4.0.post0.dev41.dist-info}/WHEEL +0 -0
  49. {cuqipy-1.3.0.post0.dev401.dist-info → cuqipy-1.4.0.post0.dev41.dist-info}/licenses/LICENSE +0 -0
  50. {cuqipy-1.3.0.post0.dev401.dist-info → cuqipy-1.4.0.post0.dev41.dist-info}/top_level.txt +0 -0
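
The pattern of the file moves above: the samplers formerly under cuqi.experimental.mcmc are promoted to cuqi.sampler, while the previous cuqi.sampler implementations are preserved under cuqi.legacy.sampler. A minimal migration sketch, assuming the class names (here MH) and the warmup/sample/get_samples interface carry over unchanged from the experimental module — these names are assumptions, not taken from this diff:

.. code-block:: python

    # Hypothetical migration sketch; class and method names are assumed.
    import cuqi

    target = cuqi.testproblem.Deconvolution1D().posterior

    # 1.3.x experimental import (module removed by this release):
    # sampler = cuqi.experimental.mcmc.MH(target)

    # 1.4.0: the same sampler now lives in cuqi.sampler
    sampler = cuqi.sampler.MH(target)
    sampler.warmup(200)             # assumed interface, as in the experimental module
    sampler.sample(1000)
    samples = sampler.get_samples()

    # The classic 1.3.x API remains available, now with a deprecation warning:
    legacy_sampler = cuqi.legacy.sampler.MH(target)

The two largest new files, cuqi/legacy/sampler/_gibbs.py and cuqi/legacy/sampler/_hmc.py, are shown in full below.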
cuqi/legacy/sampler/_gibbs.py
@@ -0,0 +1,231 @@
+ from cuqi.distribution import JointDistribution
+ from cuqi.legacy.sampler import Sampler
+ from cuqi.samples import Samples
+ from typing import Dict, Union
+ import numpy as np
+ import sys
+ import warnings
+
+ class Gibbs:
+     """
+     Gibbs sampler for sampling a joint distribution.
+
+     Gibbs sampling samples the variables of the distribution sequentially,
+     one variable at a time. When a variable represents a random vector, the
+     whole vector is sampled simultaneously.
+
+     The sampling of each variable is done by sampling from the conditional
+     distribution of that variable given the values of the other variables.
+     This is often a very efficient way of sampling from a joint distribution
+     if the conditional distributions are easy to sample from.
+
+     Parameters
+     ----------
+     target : cuqi.distribution.JointDistribution
+         Target distribution to sample from.
+
+     sampling_strategy : dict
+         Dictionary of sampling strategies for each parameter.
+         Keys are parameter names.
+         Values are sampler classes, which are instantiated internally for
+         each conditional distribution.
+
+     Example
+     -------
+     .. code-block:: python
+
+         import cuqi
+         import numpy as np
+
+         # Model and data
+         A, y_obs, probinfo = cuqi.testproblem.Deconvolution1D(phantom='square').get_components()
+         n = A.domain_dim
+
+         # Define distributions
+         d = cuqi.distribution.Gamma(1, 1e-4)
+         l = cuqi.distribution.Gamma(1, 1e-4)
+         x = cuqi.distribution.GMRF(np.zeros(n), lambda d: d)
+         y = cuqi.distribution.Gaussian(A, lambda l: 1/l)
+
+         # Combine into a joint distribution and create posterior
+         joint = cuqi.distribution.JointDistribution(d, l, x, y)
+         posterior = joint(y=y_obs)
+
+         # Define sampling strategy
+         sampling_strategy = {
+             'x': cuqi.legacy.sampler.LinearRTO,
+             ('d', 'l'): cuqi.legacy.sampler.Conjugate,
+         }
+
+         # Define Gibbs sampler
+         sampler = cuqi.legacy.sampler.Gibbs(posterior, sampling_strategy)
+
+         # Run sampler
+         samples = sampler.sample(Ns=1000, Nb=200)
+
+         # Plot results
+         samples['x'].plot_ci(exact=probinfo.exactSolution)
+         samples['d'].plot_trace(figsize=(8,2))
+         samples['l'].plot_trace(figsize=(8,2))
+
+     """
+
+     def __init__(self, target: JointDistribution, sampling_strategy: Dict[Union[str, tuple], Sampler]):
+
+         warnings.warn(f"\nYou are using the legacy sampler '{self.__class__.__name__}'.\n"
+                       f"This will be removed in a future release of CUQIpy.\n"
+                       f"Please consider using the new samplers in the 'cuqi.sampler' module.\n", UserWarning, stacklevel=2)
+
+         # Store target and allow conditioning to reduce to a single density
+         self.target = target()  # Create a copy of target distribution (to avoid modifying the original)
+
+         # Parse samplers and split any keys that are tuples into separate keys
+         self.samplers = {}
+         for par_name in sampling_strategy.keys():
+             if isinstance(par_name, tuple):
+                 for par_name_ in par_name:
+                     self.samplers[par_name_] = sampling_strategy[par_name]
+             else:
+                 self.samplers[par_name] = sampling_strategy[par_name]
+
+         # Store parameter names
+         self.par_names = self.target.get_parameter_names()
+
+     # ------------ Public methods ------------
+     def sample(self, Ns, Nb=0):
+         """ Sample from target distribution """
+
+         # Initial points
+         current_samples = self._get_initial_points()
+
+         # Compute how many samples were already taken previously
+         at_Nb = self._Nb
+         at_Ns = self._Ns
+
+         # Allocate memory for samples
+         self._allocate_samples_warmup(Nb)
+         self._allocate_samples(Ns)
+
+         # Sample tuning phase
+         for i in range(at_Nb, at_Nb+Nb):
+             current_samples = self.step_tune(current_samples)
+             self._store_samples(self.samples_warmup, current_samples, i)
+             self._print_progress(i+1, at_Nb+Nb, 'Warmup')
+
+         # Sample phase
+         for i in range(at_Ns, at_Ns+Ns):
+             current_samples = self.step(current_samples)
+             self._store_samples(self.samples, current_samples, i)
+             self._print_progress(i+1, at_Ns+Ns, 'Sample')
+
+         # Convert to Samples objects and return
+         return self._convert_to_Samples(self.samples)
+
+     def step(self, current_samples):
+         """ Sequentially go through all parameters and sample them conditionally on each other """
+
+         # Extract parameter names
+         par_names = self.par_names
+
+         # Sample from each conditional distribution
+         for par_name in par_names:
+
+             # Dict of all other parameters to condition on
+             other_params = {par_name_: current_samples[par_name_] for par_name_ in par_names if par_name_ != par_name}
+
+             # Set up sampler for current conditional distribution
+             sampler = self.samplers[par_name](self.target(**other_params))
+
+             # Take an MCMC step
+             current_samples[par_name] = sampler.step(current_samples[par_name])
+
+             # Ensure even 1-dimensional samples are 1D arrays
+             current_samples[par_name] = current_samples[par_name].reshape(-1)
+
+         return current_samples
+
+     def step_tune(self, current_samples):
+         """ Perform a single MCMC step for each parameter and tune the sampler """
+         # Not implemented. No tuning happening here yet. Requires samplers to be able to be modified after initialization.
+         return self.step(current_samples)
+
+     # ------------ Private methods ------------
+     def _allocate_samples(self, Ns):
+         """ Allocate memory for samples """
+         # Allocate memory for samples
+         samples = {}
+         for par_name in self.par_names:
+             samples[par_name] = np.zeros((self.target.get_density(par_name).dim, Ns))
+
+         # Store samples in self
+         if hasattr(self, 'samples'):
+             # Append to existing samples (This makes a copy)
+             for par_name in self.par_names:
+                 samples[par_name] = np.hstack((self.samples[par_name], samples[par_name]))
+         self.samples = samples
+
+     def _allocate_samples_warmup(self, Nb):
+         """ Allocate memory for warmup samples """
+
+         # If we already have warmup samples and more are requested raise error
+         if hasattr(self, 'samples_warmup') and Nb != 0:
+             raise ValueError('Sampler has already run the warmup phase. Cannot run the warmup phase again.')
+
+         # Allocate memory for samples
+         samples = {}
+         for par_name in self.par_names:
+             samples[par_name] = np.zeros((self.target.get_density(par_name).dim, Nb))
+         self.samples_warmup = samples
+
+     def _get_initial_points(self):
+         """ Get initial points for each parameter """
+         initial_points = {}
+         for par_name in self.par_names:
+             if hasattr(self, 'samples'):
+                 initial_points[par_name] = self.samples[par_name][:, -1]
+             elif hasattr(self, 'samples_warmup'):
+                 initial_points[par_name] = self.samples_warmup[par_name][:, -1]
+             elif hasattr(self.target.get_density(par_name), 'init_point'):
+                 initial_points[par_name] = self.target.get_density(par_name).init_point
+             else:
+                 initial_points[par_name] = np.ones(self.target.get_density(par_name).dim)
+         return initial_points
+
+     def _store_samples(self, samples, current_samples, i):
+         """ Store current samples at index i of samples dict """
+         for par_name in self.par_names:
+             samples[par_name][:, i] = current_samples[par_name]
+
+     def _convert_to_Samples(self, samples):
+         """ Convert each parameter in samples dict to cuqi.samples.Samples object with correct geometry """
+         samples_object = {}
+         for par_name in self.par_names:
+             samples_object[par_name] = Samples(samples[par_name], self.target.get_density(par_name).geometry)
+         return samples_object
+
+     def _print_progress(self, s, Ns, phase):
+         """ Prints sampling progress """
+         if Ns < 2:  # Don't print progress if only one sample
+             return
+         if (s % (max(Ns//100, 1))) == 0:
+             msg = f'{phase} {s} / {Ns}'
+             sys.stdout.write('\r'+msg)
+         if s == Ns:
+             msg = f'{phase} {s} / {Ns}'
+             sys.stdout.write('\r'+msg+'\n')
+
+     # ------------ Private properties ------------
+     @property
+     def _Ns(self):
+         """ Number of samples already taken """
+         if hasattr(self, 'samples'):
+             return self.samples[self.par_names[0]].shape[-1]
+         else:
+             return 0
+
+     @property
+     def _Nb(self):
+         """ Number of samples already taken in warmup phase """
+         if hasattr(self, 'samples_warmup'):
+             return self.samples_warmup[self.par_names[0]].shape[-1]
+         else:
+             return 0
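
The step method above is the core of the algorithm: condition the joint distribution on the current values of all other variables, then take one MCMC step for the remaining one. A minimal, CUQIpy-independent sketch of the same conditional-update pattern on a bivariate Gaussian, where both full conditionals are known in closed form (all names and numbers here are illustrative only):

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)

    # Standard bivariate Gaussian with correlation rho. The conditionals are
    # x1 | x2 ~ N(rho*x2, 1-rho^2) and x2 | x1 ~ N(rho*x1, 1-rho^2).
    rho = 0.8
    cond_std = np.sqrt(1 - rho**2)

    x1, x2 = 0.0, 0.0              # initial point
    samples = np.empty((1000, 2))
    for i in range(1000):
        # Sample each variable from its conditional given the current
        # value of the other variable (the Gibbs update).
        x1 = rng.normal(rho * x2, cond_std)
        x2 = rng.normal(rho * x1, cond_std)
        samples[i] = x1, x2

    print(np.corrcoef(samples.T))  # empirical correlation approaches rho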
cuqi/legacy/sampler/_hmc.py
@@ -0,0 +1,335 @@
+ import numpy as np
+ from cuqi.legacy.sampler import Sampler
+
+
+ # another implementation is in https://github.com/mfouesneau/NUTS
+ class NUTS(Sampler):
+     """No-U-Turn Sampler (Hoffman and Gelman, 2014).
+
+     Samples a distribution given its logpdf and gradient using a Hamiltonian Monte Carlo (HMC) algorithm with automatic parameter tuning.
+
+     For more details see: Hoffman, M. D., & Gelman, A. (2014). The no-U-turn sampler: Adaptively setting path lengths in Hamiltonian Monte Carlo. Journal of Machine Learning Research, 15, 1593-1623.
+
+     Parameters
+     ----------
+
+     target : `cuqi.distribution.Distribution`
+         The target distribution to sample. Must have logpdf and gradient method. Custom logpdfs and gradients are supported by using a :class:`cuqi.distribution.UserDefinedDistribution`.
+
+     x0 : ndarray
+         Initial parameters. *Optional*
+
+     max_depth : int
+         Maximum depth of the tree.
+
+     adapt_step_size : bool or float
+         Whether to adapt the step size.
+         If True, the step size is adapted automatically.
+         If False, the step size is fixed to the initially estimated value.
+         If set to a scalar, that value is used as the step size and no adaptation is performed.
+
+     opt_acc_rate : float
+         The optimal acceptance rate to reach if using adaptive step size.
+         Suggested values are 0.6 (default) or 0.8 (as in Stan).
+
+     callback : callable, *Optional*
+         If set this function will be called after every sample.
+         The signature of the callback function is `callback(sample, sample_index)`,
+         where `sample` is the current sample and `sample_index` is the index of the sample.
+         An example is shown in demos/demo31_callback.py.
+
+     Example
+     -------
+     .. code-block:: python
+
+         # Import cuqi
+         import cuqi
+
+         # Define a target distribution
+         tp = cuqi.testproblem.WangCubic()
+         target = tp.posterior
+
+         # Set up sampler
+         sampler = cuqi.legacy.sampler.NUTS(target)
+
+         # Sample
+         samples = sampler.sample(10000, 5000)
+
+         # Plot samples
+         samples.plot_pair()
+
+     After running the NUTS sampler, run diagnostics can be accessed via the
+     following attributes:
+
+     .. code-block:: python
+
+         # Number of tree nodes created in each NUTS iteration
+         sampler.num_tree_node_list
+
+         # Step size used in each NUTS iteration
+         sampler.epsilon_list
+
+         # Suggested step size during adaptation (the value of this step size is
+         # only used after adaptation). The suggested step size is None if
+         # adaptation is not requested.
+         sampler.epsilon_bar_list
+
+         # Additionally, the iteration numbers can be accessed via
+         sampler.iteration_list
+
+     """
+     def __init__(self, target, x0=None, max_depth=15, adapt_step_size=True, opt_acc_rate=0.6, **kwargs):
+         super().__init__(target, x0=x0, **kwargs)
+         self.max_depth = max_depth
+         self.adapt_step_size = adapt_step_size
+         self.opt_acc_rate = opt_acc_rate
+         # if this flag is True, the samples and the burn-in will be returned
+         # otherwise, the burn-in will be truncated
+         self._return_burnin = False
+
+         # NUTS run diagnostics
+         # number of tree nodes created each NUTS iteration
+         self._num_tree_node = 0
+         # Create lists to store NUTS run diagnostics
+         self._create_run_diagnostic_attributes()
+
+     def _create_run_diagnostic_attributes(self):
+         """A method to create attributes to store NUTS run diagnostics."""
+         self._reset_run_diagnostic_attributes()
+
+     def _reset_run_diagnostic_attributes(self):
+         """A method to reset attributes that store NUTS run diagnostics."""
+         # NUTS iterations
+         self.iteration_list = []
+         # List to store number of tree nodes created each NUTS iteration
+         self.num_tree_node_list = []
+         # List of step size used in each NUTS iteration
+         self.epsilon_list = []
+         # List of burn-in step size suggestions during adaptation
+         # only used when adaptation is done
+         # remains fixed after adaptation (after burn-in)
+         self.epsilon_bar_list = []
+
+     def _update_run_diagnostic_attributes(self, k, n_tree, eps, eps_bar):
+         """A method to update attributes that store NUTS run diagnostics."""
+         # Store the current iteration number k
+         self.iteration_list.append(k)
+         # Store the number of tree nodes created in iteration k
+         self.num_tree_node_list.append(n_tree)
+         # Store the step size used in iteration k
+         self.epsilon_list.append(eps)
+         # Store the step size suggestion during adaptation in iteration k
+         self.epsilon_bar_list.append(eps_bar)
+
+     def _nuts_target(self, x):  # returns logposterior tuple evaluation-gradient
+         return self.target.logd(x), self.target.gradient(x)
+
+     def _sample_adapt(self, N, Nb):
+         return self._sample(N, Nb)
+
+     def _sample(self, N, Nb):
+         # Reset run diagnostic attributes
+         self._reset_run_diagnostic_attributes()
+
+         if self.adapt_step_size is True and Nb == 0:
+             raise ValueError("Adaptive step size is True but number of burn-in steps is 0. Please set Nb > 0.")
+
+         # Allocation
+         Ns = Nb+N  # total number of samples (burn-in + kept)
+         theta = np.empty((self.dim, Ns))
+         joint_eval = np.empty(Ns)
+         step_sizes = np.empty(Ns)
+
+         # Initial state
+         theta[:, 0] = self.x0
+         joint_eval[0], grad = self._nuts_target(self.x0)
+
+         # Step size variables
+         epsilon, epsilon_bar = None, None
+
+         # parameters for dual averaging
+         if (self.adapt_step_size == True):
+             epsilon = self._FindGoodEpsilon(theta[:, 0], joint_eval[0], grad)
+             mu = np.log(10*epsilon)
+             gamma, t_0, kappa = 0.05, 10, 0.75  # kappa in (0.5, 1]
+             epsilon_bar, H_bar = 1, 0
+             delta = self.opt_acc_rate  # https://mc-stan.org/docs/2_18/reference-manual/hmc-algorithm-parameters.html
+             step_sizes[0] = epsilon
+         elif (self.adapt_step_size == False):
+             epsilon = self._FindGoodEpsilon(theta[:, 0], joint_eval[0], grad)
+         else:
+             epsilon = self.adapt_step_size  # if scalar then user specifies the step size
+
+         # run NUTS
+         for k in range(1, Ns):
+             # reset number of tree nodes for each iteration
+             self._num_tree_node = 0
+
+             theta_k, joint_k = theta[:, k-1], joint_eval[k-1]  # initial position (parameters)
+             r_k = self._Kfun(1, 'sample')  # resample momentum vector
+             Ham = joint_k - self._Kfun(r_k, 'eval')  # Hamiltonian
+
+             # slice variable
+             log_u = Ham - np.random.exponential(1, size=1)  # u = np.log(np.random.uniform(0, np.exp(H)))
+
+             # initialization
+             j, s, n = 0, 1, 1
+             theta[:, k], joint_eval[k] = theta_k, joint_k
+             theta_minus, theta_plus = np.copy(theta_k), np.copy(theta_k)
+             grad_minus, grad_plus = np.copy(grad), np.copy(grad)
+             r_minus, r_plus = np.copy(r_k), np.copy(r_k)
+
+             # run NUTS
+             while (s == 1) and (j <= self.max_depth):
+                 # sample a direction
+                 v = int(2*(np.random.rand() < 0.5)-1)
+
+                 # build tree: doubling procedure
+                 if (v == -1):
+                     theta_minus, r_minus, grad_minus, _, _, _, \
+                         theta_prime, joint_prime, grad_prime, n_prime, s_prime, alpha, n_alpha = \
+                         self._BuildTree(theta_minus, r_minus, grad_minus, Ham, log_u, v, j, epsilon)
+                 else:
+                     _, _, _, theta_plus, r_plus, grad_plus, \
+                         theta_prime, joint_prime, grad_prime, n_prime, s_prime, alpha, n_alpha = \
+                         self._BuildTree(theta_plus, r_plus, grad_plus, Ham, log_u, v, j, epsilon)
+
+                 # Metropolis step
+                 alpha2 = min(1, (n_prime/n))  # min(0, np.log(n_p) - np.log(n))
+                 if (s_prime == 1) and (np.random.rand() <= alpha2):
+                     theta[:, k] = theta_prime
+                     joint_eval[k] = joint_prime
+                     grad = np.copy(grad_prime)
+
+                 # update number of particles, tree level, and stopping criterion
+                 n += n_prime
+                 dtheta = theta_plus - theta_minus
+                 s = s_prime * int((dtheta @ r_minus.T) >= 0) * int((dtheta @ r_plus.T) >= 0)
+                 j += 1
+
+             # update run diagnostic attributes
+             self._update_run_diagnostic_attributes(
+                 k, self._num_tree_node, epsilon, epsilon_bar)
+
+             # adapt epsilon during burn-in using dual averaging
+             if (k <= Nb) and (self.adapt_step_size == True):
+                 eta1 = 1/(k + t_0)
+                 H_bar = (1-eta1)*H_bar + eta1*(delta - (alpha/n_alpha))
+                 epsilon = np.exp(mu - (np.sqrt(k)/gamma)*H_bar)
+                 eta = k**(-kappa)
+                 epsilon_bar = np.exp(eta*np.log(epsilon) + (1-eta)*np.log(epsilon_bar))
+             elif (k == Nb+1) and (self.adapt_step_size == True):
+                 epsilon = epsilon_bar  # fix epsilon after burn-in
+             step_sizes[k] = epsilon
+
+             # msg
+             self._print_progress(k+1, Ns)  # k+1 is the sample number, k is index assuming x0 is the first sample
+             self._call_callback(theta[:, k], k)
+
+             if np.isnan(joint_eval[k]):
+                 raise NameError('NaN potential func')
+
+         # apply burn-in
+         if not self._return_burnin:
+             theta = theta[:, Nb:]
+             joint_eval = joint_eval[Nb:]
+         return theta, joint_eval, step_sizes
+
+     #=========================================================================
+     # auxiliary standard Gaussian PDF: kinetic energy function
+     # d_log_2pi = d*np.log(2*np.pi)
+     def _Kfun(self, r, flag):
+         if flag == 'eval':  # evaluate
+             return 0.5*(r.T @ r)  # + d_log_2pi
+         if flag == 'sample':  # sample
+             return np.random.standard_normal(size=self.dim)
+
+     #=========================================================================
+     def _FindGoodEpsilon(self, theta, joint, grad, epsilon=1):
+         r = self._Kfun(1, 'sample')  # resample a momentum
+         Ham = joint - self._Kfun(r, 'eval')  # initial Hamiltonian
+         _, r_prime, joint_prime, grad_prime = self._Leapfrog(theta, r, grad, epsilon)
+
+         # trick to make sure the step is not huge, leading to infinite values of the likelihood
+         k = 1
+         while np.isinf(joint_prime) or np.isinf(grad_prime).any():
+             k *= 0.5
+             _, r_prime, joint_prime, grad_prime = self._Leapfrog(theta, r, grad, epsilon*k)
+         epsilon = 0.5*k*epsilon
+
+         # doubles/halves the value of epsilon until the acceptance probability of the proposal crosses 0.5
+         Ham_prime = joint_prime - self._Kfun(r_prime, 'eval')
+         log_ratio = Ham_prime - Ham
+         a = 1 if log_ratio > np.log(0.5) else -1
+         while (a*log_ratio > -a*np.log(2)):
+             epsilon = (2**a)*epsilon
+             _, r_prime, joint_prime, _ = self._Leapfrog(theta, r, grad, epsilon)
+             Ham_prime = joint_prime - self._Kfun(r_prime, 'eval')
+             log_ratio = Ham_prime - Ham
+         return epsilon
+
+     #=========================================================================
+     def _Leapfrog(self, theta_old, r_old, grad_old, epsilon):
+         # symplectic integrator: trajectories preserve phase space volume
+         r_new = r_old + 0.5*epsilon*grad_old  # half-step
+         theta_new = theta_old + epsilon*r_new  # full-step
+         joint_new, grad_new = self._nuts_target(theta_new)  # new gradient
+         r_new += 0.5*epsilon*grad_new  # half-step
+         return theta_new, r_new, joint_new, grad_new
+
+     #=========================================================================
+     # @functools.lru_cache(maxsize=128)
+     def _BuildTree(self, theta, r, grad, Ham, log_u, v, j, epsilon, Delta_max=1000):
+         # Increment the number of tree nodes counter
+         self._num_tree_node += 1
+
+         if (j == 0):  # base case
+             # single leapfrog step in the direction v
+             theta_prime, r_prime, joint_prime, grad_prime = self._Leapfrog(theta, r, grad, v*epsilon)
+             Ham_prime = joint_prime - self._Kfun(r_prime, 'eval')  # Hamiltonian eval
+             n_prime = int(log_u <= Ham_prime)  # if particle is in the slice
+             s_prime = int(log_u < Delta_max + Ham_prime)  # check for divergence
+             #
+             diff_Ham = Ham_prime - Ham
+
+             # Compute the acceptance probability
+             # alpha_prime = min(1, np.exp(diff_Ham))
+             # written in a stable way to avoid overflow when computing
+             # exp(diff_Ham) for large values of diff_Ham
+             alpha_prime = 1 if diff_Ham > 0 else np.exp(diff_Ham)
+             n_alpha_prime = 1
+             #
+             theta_minus, theta_plus = theta_prime, theta_prime
+             r_minus, r_plus = r_prime, r_prime
+             grad_minus, grad_plus = grad_prime, grad_prime
+         else:
+             # recursion: build the left/right subtrees
+             theta_minus, r_minus, grad_minus, theta_plus, r_plus, grad_plus, \
+                 theta_prime, joint_prime, grad_prime, n_prime, s_prime, alpha_prime, n_alpha_prime = \
+                 self._BuildTree(theta, r, grad, Ham, log_u, v, j-1, epsilon)
+             if (s_prime == 1):  # do only if the stopping criterion does not hold for the first subtree
+                 if (v == -1):
+                     theta_minus, r_minus, grad_minus, _, _, _, \
+                         theta_2prime, joint_2prime, grad_2prime, n_2prime, s_2prime, alpha_2prime, n_alpha_2prime = \
+                         self._BuildTree(theta_minus, r_minus, grad_minus, Ham, log_u, v, j-1, epsilon)
+                 else:
+                     _, _, _, theta_plus, r_plus, grad_plus, \
+                         theta_2prime, joint_2prime, grad_2prime, n_2prime, s_2prime, alpha_2prime, n_alpha_2prime = \
+                         self._BuildTree(theta_plus, r_plus, grad_plus, Ham, log_u, v, j-1, epsilon)
+
+                 # Metropolis step
+                 alpha2 = n_2prime / max(1, (n_prime + n_2prime))
+                 if (np.random.rand() <= alpha2):
+                     theta_prime = np.copy(theta_2prime)
+                     joint_prime = np.copy(joint_2prime)
+                     grad_prime = np.copy(grad_2prime)
+
+                 # update number of particles and stopping criterion
+                 alpha_prime += alpha_2prime
+                 n_alpha_prime += n_alpha_2prime
+                 dtheta = theta_plus - theta_minus
+                 s_prime = s_2prime * int((dtheta@r_minus.T) >= 0) * int((dtheta@r_plus.T) >= 0)
+                 n_prime += n_2prime
+         return theta_minus, r_minus, grad_minus, theta_plus, r_plus, grad_plus, \
+             theta_prime, joint_prime, grad_prime, n_prime, s_prime, alpha_prime, n_alpha_prime
+
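
The burn-in branch of _sample above implements the dual-averaging step-size adaptation from Hoffman & Gelman (2014), Algorithm 6. A standalone sketch of the same recursion, with the per-iteration acceptance statistic (alpha/n_alpha in the sampler) passed in as data; the function name and toy inputs are illustrative, not part of CUQIpy:

.. code-block:: python

    import numpy as np

    def dual_averaging_steps(accept_stats, epsilon0, delta=0.6,
                             gamma=0.05, t_0=10, kappa=0.75):
        """Sketch of the dual-averaging recursion used in _sample above.

        accept_stats : per-iteration mean acceptance probabilities.
        Returns the per-iteration step sizes and the smoothed epsilon_bar
        that the sampler fixes after burn-in.
        """
        mu = np.log(10 * epsilon0)        # shrinkage target, as in the sampler
        H_bar, log_epsilon_bar = 0.0, 0.0  # epsilon_bar initialized to 1
        epsilons = []
        for k, acc in enumerate(accept_stats, start=1):
            eta1 = 1 / (k + t_0)
            H_bar = (1 - eta1) * H_bar + eta1 * (delta - acc)
            log_epsilon = mu - (np.sqrt(k) / gamma) * H_bar
            eta = k ** (-kappa)
            log_epsilon_bar = eta * log_epsilon + (1 - eta) * log_epsilon_bar
            epsilons.append(np.exp(log_epsilon))
        return epsilons, np.exp(log_epsilon_bar)

    # Toy usage: acceptance consistently below the target delta=0.6
    # drives the step size down, which in turn raises acceptance.
    eps, eps_bar = dual_averaging_steps(np.full(200, 0.5), epsilon0=1.0)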