CUQIpy 1.0.0.post0.dev127__py3-none-any.whl → 1.0.0.post0.dev180__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of CUQIpy might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: CUQIpy
3
- Version: 1.0.0.post0.dev127
3
+ Version: 1.0.0.post0.dev180
4
4
  Summary: Computational Uncertainty Quantification for Inverse problems in Python
5
5
  Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
6
6
  License: Apache License
@@ -1,6 +1,6 @@
1
1
  cuqi/__init__.py,sha256=LsGilhl-hBLEn6Glt8S_l0OJzAA1sKit_rui8h-D-p0,488
2
2
  cuqi/_messages.py,sha256=fzEBrZT2kbmfecBBPm7spVu7yHdxGARQB4QzXhJbCJ0,415
3
- cuqi/_version.py,sha256=2GhcRA-eAhvzsZ7hWqzmLdIV7UEkRWfrTcDXqFBbrWE,510
3
+ cuqi/_version.py,sha256=nZ9Rz10NtUnQZt09kW9YMygjxV0G-qsipgAL3xBFeAU,510
4
4
  cuqi/config.py,sha256=wcYvz19wkeKW2EKCGIKJiTpWt5kdaxyt4imyRkvtTRA,526
5
5
  cuqi/diagnostics.py,sha256=5OrbJeqpynqRXOe5MtOKKhe7EAVdOEpHIqHnlMW9G_c,3029
6
6
  cuqi/array/__init__.py,sha256=-EeiaiWGNsE3twRS4dD814BIlfxEsNkTCZUc5gjOXb0,30
@@ -32,13 +32,14 @@ cuqi/distribution/_normal.py,sha256=UeoTtGDT7YSf4ZNo2amlVF9K-YQpYbf8q76jcRJTVFw,
32
32
  cuqi/distribution/_posterior.py,sha256=zAfL0GECxekZ2lBt1W6_LN0U_xskMwK4VNce5xAF7ig,5018
33
33
  cuqi/distribution/_uniform.py,sha256=7xJmCZH_LPhuGkwEDGh-_CTtzcWKrXMOxtTJUFb7Ydo,1607
34
34
  cuqi/experimental/__init__.py,sha256=vhZvyMX6rl8Y0haqCzGLPz6PSUKyu75XMQbeDHqTTrw,83
35
- cuqi/experimental/mcmc/__init__.py,sha256=vVnohcm4EIUwbp1sr3LbB0BkXO8jyZsbiKMJmIgetYY,314
36
- cuqi/experimental/mcmc/_cwmh.py,sha256=G-8YjMqPraZm1Pm3n6scFkpa65gdtI1WTQxlL21etEI,8066
37
- cuqi/experimental/mcmc/_langevin_algorithm.py,sha256=ckVHDXLaw8hsUaOAFAEs7bL2Ny7W1QBKSc4AAC-TCis,9986
38
- cuqi/experimental/mcmc/_mh.py,sha256=AslackZJ3hPUNQfy70Fh9WoRPtcsHGqulI0LQgGHBus,3477
39
- cuqi/experimental/mcmc/_pcn.py,sha256=ma3pFqFgOmE7woZ41B5CGccKEuaacJPTmKvSEQhvtzs,3981
40
- cuqi/experimental/mcmc/_rto.py,sha256=nQHpSnUlE65TXBnGFk88JLawR44af2Rtle0oFCGYgaQ,11540
41
- cuqi/experimental/mcmc/_sampler.py,sha256=s-15bElbSZFHJlaV9gmiwk-UAneQEU9W-y5tLy_NMCU,11197
35
+ cuqi/experimental/mcmc/__init__.py,sha256=S4aXYpnO75HQcwDYfr1-ki8UvlenPDXxshES5avtBF0,340
36
+ cuqi/experimental/mcmc/_cwmh.py,sha256=yRlTk5a1QYfH3JyCecfOOTeDf-4-tmJ3Tl2Bc3pyp1Y,7336
37
+ cuqi/experimental/mcmc/_hmc.py,sha256=qqAyoAajLE_JenYMgAbD3tknuEf75AJu-ufF69GKGk4,19384
38
+ cuqi/experimental/mcmc/_langevin_algorithm.py,sha256=MX48u3GYgCckB6Q5h5kXr_qdIaLQH2toOG5u29OY7gk,8245
39
+ cuqi/experimental/mcmc/_mh.py,sha256=aIV1Ntq0EAq3QJ1_X-DbP7eDAL-d_Or7d3RUO-R48I4,3090
40
+ cuqi/experimental/mcmc/_pcn.py,sha256=3M8zhQGQa53Gz04AkC8wJM61_5rIjGVnhPefi8m4dbY,3531
41
+ cuqi/experimental/mcmc/_rto.py,sha256=jSPznr34XPfWM6LysWIiN4hE-vtyti3cHyvzy9ruykg,11349
42
+ cuqi/experimental/mcmc/_sampler.py,sha256=_5Uo2F-Mta46w3lo7WBVNwvTLYhES_BzMTJrKxA00c8,14861
42
43
  cuqi/geometry/__init__.py,sha256=Tz1WGzZBY-QGH3c0GiyKm9XHN8MGGcnU6TUHLZkzB3o,842
43
44
  cuqi/geometry/_geometry.py,sha256=WYFC-4_VBTW73b2ldsnfGYKvdSiCE8plr89xTSmkadg,46804
44
45
  cuqi/implicitprior/__init__.py,sha256=ZRZ9fgxgEl5n0A9F7WCl1_jid-GUiC8ZLkyTmGQmFlY,100
@@ -59,7 +60,7 @@ cuqi/sampler/_conjugate.py,sha256=ztmUR3V3qZk9zelKx48ULnmMs_zKTDUfohc256VOIe8,27
59
60
  cuqi/sampler/_conjugate_approx.py,sha256=xX-X71EgxGnZooOY6CIBhuJTs3dhcKfoLnoFxX3CO2g,1938
60
61
  cuqi/sampler/_cwmh.py,sha256=VlAVT1SXQU0yD5ZeR-_ckWvX-ifJrMweFFdFbxdfB_k,7775
61
62
  cuqi/sampler/_gibbs.py,sha256=N7qcePwMkRtxINN5JF0FaMIdDCXZGqsfKjfha_KHCck,8627
62
- cuqi/sampler/_hmc.py,sha256=76nPkvNU0wLSg4qvm-1s048MzQasl5Qk94sHpyeJ5hM,14819
63
+ cuqi/sampler/_hmc.py,sha256=EUTefZir-wapoZ7OZFb5M5vayL8z6XksZRMY1BpbuXc,15027
63
64
  cuqi/sampler/_langevin_algorithm.py,sha256=o5EyvaR6QGAD7LKwXVRC3WwAP5IYJf5GoMVWl9DrfOA,7861
64
65
  cuqi/sampler/_laplace_approximation.py,sha256=u018Z5eqlcq_cIwD9yNOaA15dLQE_vUWaee5Xp8bcjg,6454
65
66
  cuqi/sampler/_mh.py,sha256=V5tIdn-KdfWo4J_Nbf-AH6XwKWblWUyc4BeuSikUHsE,7062
@@ -75,8 +76,8 @@ cuqi/testproblem/_testproblem.py,sha256=x769LwwRdJdzIiZkcQUGb_5-vynNTNALXWKato7s
75
76
  cuqi/utilities/__init__.py,sha256=EfxHLdsyDNugbmbzs43nV_AeKcycM9sVBjG9WZydagA,351
76
77
  cuqi/utilities/_get_python_variable_name.py,sha256=QwlBVj2koJRA8s8pWd554p7-ElcI7HUwY32HknaR92E,1827
77
78
  cuqi/utilities/_utilities.py,sha256=At3DOXRdF3GwLkVcM2FXooGyjAGfPkIM0bRzhTfLmWk,8046
78
- CUQIpy-1.0.0.post0.dev127.dist-info/LICENSE,sha256=kJWRPrtRoQoZGXyyvu50Uc91X6_0XRaVfT0YZssicys,10799
79
- CUQIpy-1.0.0.post0.dev127.dist-info/METADATA,sha256=bHtJEgkhpP_C50jwE7SwIpOY0Hm48V2qMKy_v6k0SiU,18393
80
- CUQIpy-1.0.0.post0.dev127.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
81
- CUQIpy-1.0.0.post0.dev127.dist-info/top_level.txt,sha256=AgmgMc6TKfPPqbjV0kvAoCBN334i_Lwwojc7HE3ZwD0,5
82
- CUQIpy-1.0.0.post0.dev127.dist-info/RECORD,,
79
+ CUQIpy-1.0.0.post0.dev180.dist-info/LICENSE,sha256=kJWRPrtRoQoZGXyyvu50Uc91X6_0XRaVfT0YZssicys,10799
80
+ CUQIpy-1.0.0.post0.dev180.dist-info/METADATA,sha256=YSVheKM46Y6Al3zXEu11wlCkJxAasZw2onXB0nTLN08,18393
81
+ CUQIpy-1.0.0.post0.dev180.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
82
+ CUQIpy-1.0.0.post0.dev180.dist-info/top_level.txt,sha256=AgmgMc6TKfPPqbjV0kvAoCBN334i_Lwwojc7HE3ZwD0,5
83
+ CUQIpy-1.0.0.post0.dev180.dist-info/RECORD,,
cuqi/_version.py CHANGED
@@ -8,11 +8,11 @@ import json
8
8
 
9
9
  version_json = '''
10
10
  {
11
- "date": "2024-04-05T11:26:24+0200",
11
+ "date": "2024-04-22T18:01:04+0200",
12
12
  "dirty": false,
13
13
  "error": null,
14
- "full-revisionid": "476cc08200034336b46144fac4c819f8298fa587",
15
- "version": "1.0.0.post0.dev127"
14
+ "full-revisionid": "be4b485322b1be52f78dfe6d03694d6ba2000b11",
15
+ "version": "1.0.0.post0.dev180"
16
16
  }
17
17
  ''' # END VERSION_JSON
18
18
 
@@ -6,3 +6,4 @@ from ._mh import MHNew
6
6
  from ._pcn import pCNNew
7
7
  from ._rto import LinearRTONew, RegularizedLinearRTONew
8
8
  from ._cwmh import CWMHNew
9
+ from ._hmc import NUTSNew
@@ -144,7 +144,7 @@ class CWMHNew(ProposalBasedSamplerNew):
144
144
 
145
145
  # Propose a sample x_all_components from the proposal distribution
146
146
  # for all the components
147
- target_eval_t = self.current_target
147
+ target_eval_t = self.current_target_logd
148
148
  if isinstance(self.proposal,cuqi.distribution.Distribution):
149
149
  x_all_components = self.proposal(
150
150
  location= self.current_point, scale=self.scale).sample()
@@ -175,7 +175,7 @@ class CWMHNew(ProposalBasedSamplerNew):
175
175
 
176
176
  x_star = x_t.copy()
177
177
 
178
- self.current_target = target_eval_t
178
+ self.current_target_logd = target_eval_t
179
179
  self.current_point = x_t
180
180
 
181
181
  return acc
@@ -199,23 +199,3 @@ class CWMHNew(ProposalBasedSamplerNew):
199
199
  # Update the scale parameter
200
200
  self.scale = np.minimum(scale_temp, np.ones(self.dim))
201
201
  self._scale_temp = scale_temp
202
-
203
- def get_state(self):
204
- current_point = self.current_point
205
- if isinstance(current_point, CUQIarray):
206
- current_point = current_point.to_numpy()
207
-
208
- return {'sampler_type': 'CWMH',
209
- 'current_point': current_point,
210
- 'current_target': self.current_target,
211
- 'scale': self.scale}
212
-
213
- def set_state(self, state):
214
- current_point = state['current_point']
215
- if not isinstance(current_point, CUQIarray):
216
- current_point = CUQIarray(current_point,
217
- geometry=self.target.geometry)
218
-
219
- self.current_point = current_point
220
- self.current_target = state['current_target']
221
- self.scale = state['scale']
@@ -0,0 +1,470 @@
1
+ import numpy as np
2
+ import numpy as np
3
+ from cuqi.experimental.mcmc import SamplerNew
4
+ from cuqi.array import CUQIarray
5
+ from numbers import Number
6
+
7
class NUTSNew(SamplerNew):
    """No-U-Turn Sampler (Hoffman and Gelman, 2014).

    Samples a distribution given its logpdf and gradient using a Hamiltonian
    Monte Carlo (HMC) algorithm with automatic parameter tuning.

    For more details see: Hoffman, M. D., & Gelman, A. (2014). The no-U-turn
    sampler: Adaptively setting path lengths in Hamiltonian Monte Carlo. Journal
    of Machine Learning Research, 15, 1593-1623.

    Parameters
    ----------
    target : `cuqi.distribution.Distribution`
        The target distribution to sample. Must have logpdf and gradient method.
        Custom logpdfs and gradients are supported by using a
        :class:`cuqi.distribution.UserDefinedDistribution`.

    initial_point : ndarray
        Initial parameters. *Optional*. If not provided, the initial point is
        an array of ones.

    max_depth : int
        Maximum depth of the tree >=0 and the default is 15.

    step_size : None or float
        If step_size is provided (as positive float), it will be used as initial
        step size, both when sampling directly and as the starting value of
        the step-size adaptation during warmup. If None, the step size will be
        estimated by the sampler.

    opt_acc_rate : float
        The optimal acceptance rate to reach if using adaptive step size.
        Suggested values are 0.6 (default) or 0.8 (as in stan). In principle,
        opt_acc_rate should be in (0, 1), however, choosing a value that is very
        close to 1 or 0 might lead to poor performance of the sampler.

    callback : callable, *Optional*
        If set this function will be called after every sample.
        The signature of the callback function is
        `callback(sample, sample_index)`,
        where `sample` is the current sample and `sample_index` is the index of
        the sample.
        An example is shown in demos/demo31_callback.py.

    Example
    -------
    .. code-block:: python

        # Import cuqi
        import cuqi

        # Define a target distribution
        tp = cuqi.testproblem.WangCubic()
        target = tp.posterior

        # Set up sampler
        sampler = cuqi.experimental.mcmc.NUTSNew(target)

        # Sample
        sampler.warmup(5000)
        sampler.sample(10000)

        # Get samples
        samples = sampler.get_samples()

        # Plot samples
        samples.plot_pair()

    After running the NUTS sampler, run diagnostics can be accessed via the
    following attributes:

    .. code-block:: python

        # Number of tree nodes created each NUTS iteration
        sampler.num_tree_node_list

        # Step size used in each NUTS iteration
        sampler.epsilon_list

        # Suggested step size during adaptation (the value of this step size is
        # only used after adaptation).
        sampler.epsilon_bar_list

    """

    # Extra attributes (on top of the base sampler's) that make up the
    # sampler state, i.e. the step-size adaptation variables.
    _STATE_KEYS = SamplerNew._STATE_KEYS.union({'_epsilon', '_epsilon_bar',
                                                '_H_bar', '_mu',
                                                '_alpha', '_n_alpha'})

    # Extra per-iteration run diagnostics recorded by this sampler.
    _HISTORY_KEYS = SamplerNew._HISTORY_KEYS.union({'num_tree_node_list',
                                                    'epsilon_list',
                                                    'epsilon_bar_list'})

    def __init__(self, target, initial_point=None, max_depth=15,
                 step_size=None, opt_acc_rate=0.6, **kwargs):
        super().__init__(target, initial_point=initial_point, **kwargs)

        # Assign parameters as attributes (validated by the property setters)
        self.max_depth = max_depth
        self.step_size = step_size
        self.opt_acc_rate = opt_acc_rate

        # Set current point
        self.current_point = self.initial_point

        # Initialize epsilon and epsilon_bar
        # epsilon is the step size used in the current iteration
        # after warm up and one sampling step, epsilon is updated
        # to epsilon_bar for the remaining sampling steps.
        self._epsilon = None
        self._epsilon_bar = None
        self._H_bar = None

        # Arrays to store acceptance rate
        self._acc = [None]

        # NUTS run diagnostic:
        # number of tree nodes created each NUTS iteration
        self._num_tree_node = 0
        # Create lists to store NUTS run diagnostics
        self._create_run_diagnostic_attributes()

    #=========================================================================
    #============================== Properties ===============================
    #=========================================================================
    @property
    def max_depth(self):
        """Maximum depth (number of doublings) of the NUTS tree."""
        return self._max_depth

    @max_depth.setter
    def max_depth(self, value):
        # bool is a subclass of int; reject it explicitly (as the step_size
        # setter does) so that e.g. max_depth=True is not silently taken as 1.
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError('max_depth must be an integer.')
        if value < 0:
            raise ValueError('max_depth must be >= 0.')
        self._max_depth = value

    @property
    def step_size(self):
        """User-supplied initial step size, or None to estimate it."""
        return self._step_size

    @step_size.setter
    def step_size(self, value):
        if value is None:
            pass # NUTS will adapt the step size

        # step_size must be a positive float, raise error otherwise
        # (TypeError is kept for all invalid values for backward
        # compatibility with callers that catch it.)
        elif isinstance(value, bool)\
            or not isinstance(value, Number)\
            or value <= 0:
            raise TypeError('step_size must be a positive float or None.')
        self._step_size = value

    @property
    def opt_acc_rate(self):
        """Target acceptance rate for the dual-averaging adaptation."""
        return self._opt_acc_rate

    @opt_acc_rate.setter
    def opt_acc_rate(self, value):
        if not isinstance(value, Number) or value <= 0 or value >= 1:
            raise ValueError('opt_acc_rate must be a float in (0, 1).')
        self._opt_acc_rate = value

    #=========================================================================
    #================== Implement methods required by SamplerNew =============
    #=========================================================================
    def validate_target(self):
        """Check that the target provides both a logd and a gradient.

        Raises
        ------
        ValueError
            If evaluating ``target.logd`` or ``target.gradient`` at a vector
            of ones fails. The original error is chained as the cause.
        """
        try:
            self._nuts_target(np.ones(self.dim))
        except Exception as err:
            # Narrowed from a bare ``except:`` so that KeyboardInterrupt /
            # SystemExit propagate, and chained so the real cause is visible.
            raise ValueError(
                'Target must have logd and gradient methods.') from err

    def reset(self):
        """Reset the sampler and clear the NUTS run diagnostics."""
        # Call the parent reset method
        super().reset()
        # Reset NUTS run diagnostic attributes
        self._reset_run_diagnostic_attributes()

    def step(self):
        """Perform one NUTS iteration.

        Builds a trajectory by repeated tree doubling in random directions
        until a U-turn / divergence is detected or ``max_depth`` is exceeded,
        and updates ``current_point``, ``current_target_logd`` and
        ``current_target_grad`` with the selected candidate.

        Raises
        ------
        NameError
            If the log-density at the new current point is NaN.
        """
        # Convert current_point, logd, and grad to numpy arrays
        # if they are CUQIarray objects
        if isinstance(self.current_point, CUQIarray):
            self.current_point = self.current_point.to_numpy()
        if isinstance(self.current_target_logd, CUQIarray):
            self.current_target_logd = self.current_target_logd.to_numpy()
        if isinstance(self.current_target_grad, CUQIarray):
            self.current_target_grad = self.current_target_grad.to_numpy()

        # reset number of tree nodes for each iteration
        self._num_tree_node = 0

        # copy current point, logd, and grad in local variables
        point_k = self.current_point.copy() # initial position (parameters)
        logd_k = self.current_target_logd
        grad_k = self.current_target_grad.copy() # initial gradient

        # compute r_k and Hamiltonian
        r_k = self._Kfun(1, 'sample') # resample momentum vector
        Ham = logd_k - self._Kfun(r_k, 'eval') # Hamiltonian

        # slice variable: log(u) = Ham - e with e ~ Exp(1), i.e.
        # u ~ Uniform(0, exp(Ham)). Drawn as a scalar (same single draw as
        # size=1) so that the int(...) conversions of comparisons against
        # log_u do not rely on NumPy's deprecated ndim>0 -> scalar conversion.
        log_u = Ham - np.random.exponential(1)

        # initialization
        j, s, n = 0, 1, 1
        point_minus, point_plus = np.copy(point_k), np.copy(point_k)
        grad_minus, grad_plus = np.copy(grad_k), np.copy(grad_k)
        r_minus, r_plus = np.copy(r_k), np.copy(r_k)

        # run NUTS
        while (s == 1) and (j <= self.max_depth):
            # sample a direction
            v = int(2*(np.random.rand() < 0.5)-1)

            # build tree: doubling procedure
            if (v == -1):
                point_minus, r_minus, grad_minus, _, _, _, \
                point_prime, logd_prime, grad_prime,\
                n_prime, s_prime, alpha, n_alpha = \
                    self._BuildTree(point_minus, r_minus, grad_minus,
                                    Ham, log_u, v, j, self._epsilon)
            else:
                _, _, _, point_plus, r_plus, grad_plus, \
                point_prime, logd_prime, grad_prime,\
                n_prime, s_prime, alpha, n_alpha = \
                    self._BuildTree(point_plus, r_plus, grad_plus,
                                    Ham, log_u, v, j, self._epsilon)

            # Metropolis step: accept the new subtree's candidate with
            # probability proportional to its number of valid points
            alpha2 = min(1, (n_prime/n)) #min(0, np.log(n_p) - np.log(n))
            if (s_prime == 1) and (np.random.rand() <= alpha2):
                self.current_point = point_prime
                self.current_target_logd = logd_prime
                self.current_target_grad = np.copy(grad_prime)
                self._acc.append(1)
            else:
                self._acc.append(0)

            # update number of particles, tree level, and stopping criterion
            n += n_prime
            dpoints = point_plus - point_minus
            s = s_prime *\
                int((dpoints @ r_minus.T) >= 0) * int((dpoints @ r_plus.T) >= 0)
            j += 1
            # acceptance statistics of the last subtree, used by tune()
            self._alpha = alpha
            self._n_alpha = n_alpha

        # update run diagnostic attributes
        self._update_run_diagnostic_attributes(
            self._num_tree_node, self._epsilon, self._epsilon_bar)

        self._epsilon = self._epsilon_bar
        if np.isnan(self.current_target_logd):
            raise NameError('NaN potential func')

    def tune(self, skip_len, update_count):
        """Adapt epsilon during burn-in using dual averaging.

        Parameters
        ----------
        skip_len : int
            Unused by this sampler (part of the SamplerNew interface).
        update_count : int
            Zero-based index of the current adaptation step.
        """
        k = update_count+1

        # Fixed parameters that do not change during the run
        gamma, t_0, kappa = 0.05, 10, 0.75 # kappa in (0.5, 1]

        eta1 = 1/(k + t_0)
        self._H_bar = (1-eta1)*self._H_bar +\
            eta1*(self.opt_acc_rate - (self._alpha/self._n_alpha))
        self._epsilon = np.exp(self._mu - (np.sqrt(k)/gamma)*self._H_bar)
        eta = k**(-kappa)
        self._epsilon_bar =\
            np.exp(eta*np.log(self._epsilon) +(1-eta)*np.log(self._epsilon_bar))

    def _pre_warmup(self):
        """Evaluate the target at the current point and set up adaptation."""
        super()._pre_warmup()

        self.current_target_logd, self.current_target_grad =\
            self._nuts_target(self.current_point)

        # Set up tuning parameters (only first time tuning is called)
        # Note:
        # Parameters changes during the tune run
            # self._epsilon_bar
            # self._H_bar
            # self._epsilon
        # Parameters that does not change during the run
            # self._mu

        if self._epsilon is None:
            # Initial step size for dual averaging: honour a user-provided
            # step_size (as documented), otherwise estimate one heuristically.
            if self.step_size is None:
                self._epsilon = self._FindGoodEpsilon()
            else:
                self._epsilon = self.step_size
            # Parameter mu, does not change during the run
            self._mu = np.log(10*self._epsilon)

        if self._epsilon_bar is None: # Initial value of epsilon_bar
            self._epsilon_bar = 1

        if self._H_bar is None: # Initial value of H_bar
            self._H_bar = 0

    def _pre_sample(self):
        """Evaluate the target at the current point and fix the step size
        if no warmup has set it."""
        super()._pre_sample()

        self.current_target_logd, self.current_target_grad =\
            self._nuts_target(self.current_point)

        # Set up epsilon and epsilon_bar if not set
        if self._epsilon is None:
            if self.step_size is None:
                step_size = self._FindGoodEpsilon()
            else:
                step_size = self.step_size
            self._epsilon = step_size
            self._epsilon_bar = step_size

    #=========================================================================
    def _nuts_target(self, x): # returns logposterior tuple evaluation-gradient
        """Return ``(target.logd(x), target.gradient(x))``."""
        return self.target.logd(x), self.target.gradient(x)

    #=========================================================================
    # auxiliary standard Gaussian PDF: kinetic energy function
    # d_log_2pi = d*np.log(2*np.pi)
    def _Kfun(self, r, flag):
        """Kinetic energy helper.

        ``flag == 'eval'`` returns ``0.5 * r.T @ r``; ``flag == 'sample'``
        returns a standard-normal momentum vector of length ``self.dim``
        (``r`` is ignored). Any other flag returns None.
        """
        if flag == 'eval': # evaluate
            return 0.5*(r.T @ r) #+ d_log_2pi
        if flag == 'sample': # sample
            return np.random.standard_normal(size=self.dim)

    #=========================================================================
    def _FindGoodEpsilon(self, epsilon=1):
        """Heuristically choose an initial step size.

        Starting from ``epsilon``, first shrinks the step until a single
        leapfrog step gives finite logd and gradient, then repeatedly doubles
        or halves it until the log acceptance ratio of one leapfrog step
        crosses log(0.5). Also refreshes ``current_target_logd`` and
        ``current_target_grad`` at the current point.
        """
        point_k = self.current_point
        self.current_target_logd, self.current_target_grad = self._nuts_target(
            point_k)
        logd = self.current_target_logd
        grad = self.current_target_grad

        r = self._Kfun(1, 'sample') # resample a momentum
        Ham = logd - self._Kfun(r, 'eval') # initial Hamiltonian
        _, r_prime, logd_prime, grad_prime = self._Leapfrog(
            point_k, r, grad, epsilon)

        # trick to make sure the step is not huge, leading to infinite values of
        # the likelihood
        k = 1
        while np.isinf(logd_prime) or np.isinf(grad_prime).any():
            k *= 0.5
            _, r_prime, logd_prime, grad_prime = self._Leapfrog(
                point_k, r, grad, epsilon*k)
        epsilon = 0.5*k*epsilon

        # doubles/halves the value of epsilon until the accprob of the Langevin
        # proposal crosses 0.5
        Ham_prime = logd_prime - self._Kfun(r_prime, 'eval')
        log_ratio = Ham_prime - Ham
        a = 1 if log_ratio > np.log(0.5) else -1
        while (a*log_ratio > -a*np.log(2)):
            epsilon = (2**a)*epsilon
            _, r_prime, logd_prime, _ = self._Leapfrog(
                point_k, r, grad, epsilon)
            Ham_prime = logd_prime - self._Kfun(r_prime, 'eval')
            log_ratio = Ham_prime - Ham
        return epsilon

    #=========================================================================
    def _Leapfrog(self, point_old, r_old, grad_old, epsilon):
        """One leapfrog step of size ``epsilon``.

        Returns ``(point_new, r_new, logd_new, grad_new)``. Inputs are not
        modified (``r_new`` is a freshly allocated array).
        """
        # symplectic integrator: trajectories preserve phase space volumen
        r_new = r_old + 0.5*epsilon*grad_old # half-step
        point_new = point_old + epsilon*r_new # full-step
        logd_new, grad_new = self._nuts_target(point_new) # new gradient
        r_new += 0.5*epsilon*grad_new # half-step
        return point_new, r_new, logd_new, grad_new

    #=========================================================================
    def _BuildTree(
            self, point_k, r, grad, Ham, log_u, v, j, epsilon, Delta_max=1000):
        """Recursively build a NUTS subtree of depth ``j`` in direction ``v``.

        Returns the leftmost and rightmost states (point, momentum, gradient),
        a candidate state (point, logd, gradient), the number of valid points
        ``n_prime``, the continuation flag ``s_prime`` (0 on U-turn or
        divergence), and the summed acceptance probabilities ``alpha_prime``
        over ``n_alpha_prime`` leapfrog steps (used for step-size tuning).
        """
        # Increment the number of tree nodes counter
        self._num_tree_node += 1

        if (j == 0): # base case
            # single leapfrog step in the direction v
            point_prime, r_prime, logd_prime, grad_prime = self._Leapfrog(
                point_k, r, grad, v*epsilon)
            Ham_prime = logd_prime - self._Kfun(r_prime, 'eval') # Hamiltonian
            # eval
            n_prime = int(log_u <= Ham_prime) # if particle is in the slice
            # continue unless the energy error exceeds Delta_max (divergence)
            s_prime = int(log_u < Delta_max + Ham_prime)
            #
            diff_Ham = Ham_prime - Ham

            # Compute the acceptance probability
            # alpha_prime = min(1, np.exp(diff_Ham))
            # written in a stable way to avoid overflow when computing
            # exp(diff_Ham) for large values of diff_Ham
            alpha_prime = 1 if diff_Ham > 0 else np.exp(diff_Ham)
            n_alpha_prime = 1
            #
            point_minus, point_plus = point_prime, point_prime
            r_minus, r_plus = r_prime, r_prime
            grad_minus, grad_plus = grad_prime, grad_prime
        else:
            # recursion: build the left/right subtrees
            point_minus, r_minus, grad_minus, point_plus, r_plus, grad_plus, \
            point_prime, logd_prime, grad_prime,\
            n_prime, s_prime, alpha_prime, n_alpha_prime = \
                self._BuildTree(point_k, r, grad,
                                Ham, log_u, v, j-1, epsilon)
            if (s_prime == 1): # do only if the stopping criteria does not
                # verify at the first subtree
                if (v == -1):
                    point_minus, r_minus, grad_minus, _, _, _, \
                    point_2prime, logd_2prime, grad_2prime,\
                    n_2prime, s_2prime, alpha_2prime, n_alpha_2prime = \
                        self._BuildTree(point_minus, r_minus, grad_minus,
                                        Ham, log_u, v, j-1, epsilon)
                else:
                    _, _, _, point_plus, r_plus, grad_plus, \
                    point_2prime, logd_2prime, grad_2prime,\
                    n_2prime, s_2prime, alpha_2prime, n_alpha_2prime = \
                        self._BuildTree(point_plus, r_plus, grad_plus,
                                        Ham, log_u, v, j-1, epsilon)

                # Metropolis step
                alpha2 = n_2prime / max(1, (n_prime + n_2prime))
                if (np.random.rand() <= alpha2):
                    point_prime = np.copy(point_2prime)
                    logd_prime = np.copy(logd_2prime)
                    grad_prime = np.copy(grad_2prime)

                # update number of particles and stopping criterion
                alpha_prime += alpha_2prime
                n_alpha_prime += n_alpha_2prime
                dpoints = point_plus - point_minus
                s_prime = s_2prime *\
                    int((dpoints@r_minus.T)>=0) * int((dpoints@r_plus.T)>=0)
                n_prime += n_2prime

        return point_minus, r_minus, grad_minus, point_plus, r_plus, grad_plus,\
            point_prime, logd_prime, grad_prime,\
            n_prime, s_prime, alpha_prime, n_alpha_prime

    #=========================================================================
    #======================== Diagnostic methods =============================
    #=========================================================================

    def _create_run_diagnostic_attributes(self):
        """A method to create attributes to store NUTS run diagnostic."""
        self._reset_run_diagnostic_attributes()

    def _reset_run_diagnostic_attributes(self):
        """A method to reset attributes to store NUTS run diagnostic."""
        # List to store number of tree nodes created each NUTS iteration
        self.num_tree_node_list = []
        # List of step size used in each NUTS iteration
        self.epsilon_list = []
        # List of burn-in step size suggestion during adaptation
        # only used when adaptation is done
        # remains fixed after adaptation (after burn-in)
        self.epsilon_bar_list = []

    def _update_run_diagnostic_attributes(self, n_tree, eps, eps_bar):
        """A method to update attributes to store NUTS run diagnostic."""
        # Store the number of tree nodes created in iteration k
        self.num_tree_node_list.append(n_tree)
        # Store the step size used in iteration k
        self.epsilon_list.append(eps)
        # Store the step size suggestion during adaptation in iteration k
        self.epsilon_bar_list.append(eps_bar)
@@ -60,14 +60,17 @@ class ULANew(SamplerNew): # Refactor to Proposal-based sampler?
60
60
  A Deblur example can be found in demos/demo27_ULA.py
61
61
  # TODO: update demo once sampler merged
62
62
  """
63
+
64
+ _STATE_KEYS = SamplerNew._STATE_KEYS.union({'current_target_logd', 'scale', 'current_target_grad'})
65
+
63
66
  def __init__(self, target, scale=1.0, **kwargs):
64
67
 
65
68
  super().__init__(target, **kwargs)
66
69
 
67
70
  self.scale = scale
68
71
  self.current_point = self.initial_point
69
- self.current_target_eval = self.target.logd(self.current_point)
70
- self.current_target_grad_eval = self.target.gradient(self.current_point)
72
+ self.current_target_logd = self.target.logd(self.current_point)
73
+ self.current_target_grad = self.target.gradient(self.current_point)
71
74
  self._acc = [1] # TODO. Check if we need this
72
75
 
73
76
  def validate_target(self):
@@ -99,15 +102,15 @@ class ULANew(SamplerNew): # Refactor to Proposal-based sampler?
99
102
  1 (accepted)
100
103
  """
101
104
  self.current_point = x_star
102
- self.current_target_eval = target_eval_star
103
- self.current_target_grad_eval = target_grad_star
105
+ self.current_target_logd = target_eval_star
106
+ self.current_target_grad = target_grad_star
104
107
  acc = 1
105
108
  return acc
106
109
 
107
110
  def step(self):
108
111
  # propose state
109
112
  xi = cuqi.distribution.Normal(mean=np.zeros(self.dim), std=np.sqrt(self.scale)).sample()
110
- x_star = self.current_point + 0.5*self.scale*self.current_target_grad_eval + xi
113
+ x_star = self.current_point + 0.5*self.scale*self.current_target_grad + xi
111
114
 
112
115
  # evaluate target
113
116
  target_eval_star, target_grad_star = self.target.logd(x_star), self.target.gradient(x_star)
@@ -120,26 +123,6 @@ class ULANew(SamplerNew): # Refactor to Proposal-based sampler?
120
123
  def tune(self, skip_len, update_count):
121
124
  pass
122
125
 
123
- def get_state(self):
124
- if isinstance(self.current_point, CUQIarray):
125
- self.current_point = self.current_point.to_numpy()
126
- if isinstance(self.current_target_eval, CUQIarray):
127
- self.current_target_eval = self.current_target_eval.to_numpy()
128
- if isinstance(self.current_target_grad_eval, CUQIarray):
129
- self.current_target_grad_eval = self.current_target_grad_eval.to_numpy()
130
- return {'sampler_type': 'ULA', 'current_point': self.current_point, \
131
- 'current_target_eval': self.current_target_eval, \
132
- 'current_target_grad_eval': self.current_target_grad_eval, \
133
- 'scale': self.scale}
134
-
135
- def set_state(self, state):
136
- temp = CUQIarray(state['current_point'] , geometry=self.target.geometry)
137
- self.current_point = temp
138
- temp = CUQIarray(state['current_target_eval'] , geometry=self.target.geometry)
139
- self.current_target_eval = temp
140
- temp = CUQIarray(state['current_target_grad_eval'] , geometry=self.target.geometry)
141
- self.current_target_grad_eval = temp
142
- self.scale = state['scale']
143
126
 
144
127
  class MALANew(ULANew): # Refactor to Proposal-based sampler?
145
128
  """ Metropolis-adjusted Langevin algorithm (MALA) (Roberts and Tweedie, 1996)
@@ -219,9 +202,9 @@ class MALANew(ULANew): # Refactor to Proposal-based sampler?
219
202
  scaler
220
203
  1 if accepted, 0 otherwise
221
204
  """
222
- log_target_ratio = target_eval_star - self.current_target_eval
205
+ log_target_ratio = target_eval_star - self.current_target_logd
223
206
  log_prop_ratio = self._log_proposal(self.current_point, x_star, target_grad_star) \
224
- - self._log_proposal(x_star, self.current_point, self.current_target_grad_eval)
207
+ - self._log_proposal(x_star, self.current_point, self.current_target_grad)
225
208
  log_alpha = min(0, log_target_ratio + log_prop_ratio)
226
209
 
227
210
  # accept/reject with Metropolis
@@ -229,8 +212,8 @@ class MALANew(ULANew): # Refactor to Proposal-based sampler?
229
212
  log_u = np.log(np.random.rand())
230
213
  if (log_u <= log_alpha) and (np.isnan(target_eval_star) == False):
231
214
  self.current_point = x_star
232
- self.current_target_eval = target_eval_star
233
- self.current_target_grad_eval = target_grad_star
215
+ self.current_target_logd = target_eval_star
216
+ self.current_target_grad = target_grad_star
234
217
  acc = 1
235
218
  return acc
236
219
 
@@ -241,15 +224,3 @@ class MALANew(ULANew): # Refactor to Proposal-based sampler?
241
224
  mu = theta_k + ((self.scale)/2)*g_logpi_k
242
225
  misfit = theta_star - mu
243
226
  return -0.5*((1/(self.scale))*(misfit.T @ misfit))
244
-
245
- def get_state(self):
246
- if isinstance(self.current_point, CUQIarray):
247
- self.current_point = self.current_point.to_numpy()
248
- if isinstance(self.current_target_eval, CUQIarray):
249
- self.current_target_eval = self.current_target_eval.to_numpy()
250
- if isinstance(self.current_target_grad_eval, CUQIarray):
251
- self.current_target_grad_eval = self.current_target_grad_eval.to_numpy()
252
- return {'sampler_type': 'MALA', 'current_point': self.current_point, \
253
- 'current_target_eval': self.current_target_eval, \
254
- 'current_target_grad_eval': self.current_target_grad_eval, \
255
- 'scale': self.scale}
@@ -23,6 +23,8 @@ class MHNew(ProposalBasedSamplerNew):
23
23
 
24
24
  """
25
25
 
26
+ _STATE_KEYS = ProposalBasedSamplerNew._STATE_KEYS.union({'scale', '_scale_temp'})
27
+
26
28
  def __init__(self, target, proposal=None, scale=1, **kwargs):
27
29
  super().__init__(target, proposal=proposal, scale=scale, **kwargs)
28
30
  # Due to a bug? in old MH, we must keep track of this extra variable to match behavior.
@@ -54,7 +56,7 @@ class MHNew(ProposalBasedSamplerNew):
54
56
  target_eval_star = self.target.logd(x_star)
55
57
 
56
58
  # ratio and acceptance probability
57
- ratio = target_eval_star - self.current_target # proposal is symmetric
59
+ ratio = target_eval_star - self.current_target_logd # proposal is symmetric
58
60
  alpha = min(0, ratio)
59
61
 
60
62
  # accept/reject
@@ -62,7 +64,7 @@ class MHNew(ProposalBasedSamplerNew):
62
64
  acc = 0
63
65
  if (u_theta <= alpha):
64
66
  self.current_point = x_star
65
- self.current_target = target_eval_star
67
+ self.current_target_logd = target_eval_star
66
68
  acc = 1
67
69
 
68
70
  return acc
@@ -79,13 +81,3 @@ class MHNew(ProposalBasedSamplerNew):
79
81
 
80
82
  # update parameters
81
83
  self.scale = min(self._scale_temp, 1)
82
-
83
- def get_state(self):
84
- return {'sampler_type': 'MH', 'current_point': self.current_point.to_numpy(), 'current_target': self.current_target.to_numpy(), 'scale': self.scale}
85
-
86
- def set_state(self, state):
87
- temp = CUQIarray(state['current_point'] , geometry=self.target.geometry)
88
- self.current_point = temp
89
- temp = CUQIarray(state['current_target'] , geometry=self.target.geometry)
90
- self.current_target = temp
91
- self.scale = state['scale']
@@ -5,13 +5,15 @@ from cuqi.array import CUQIarray
5
5
 
6
6
  class pCNNew(SamplerNew): # Refactor to Proposal-based sampler?
7
7
 
8
+ _STATE_KEYS = SamplerNew._STATE_KEYS.union({'scale', 'current_likelihood_logd'})
9
+
8
10
  def __init__(self, target, scale=1.0, **kwargs):
9
11
 
10
12
  super().__init__(target, **kwargs)
11
13
 
12
14
  self.scale = scale
13
15
  self.current_point = self.initial_point
14
- self.current_loglike_eval = self._loglikelihood(self.current_point)
16
+ self.current_likelihood_logd = self._loglikelihood(self.current_point)
15
17
 
16
18
  self._acc = [1] # TODO. Check if we need this
17
19
 
@@ -33,7 +35,7 @@ class pCNNew(SamplerNew): # Refactor to Proposal-based sampler?
33
35
  loglike_eval_star = self._loglikelihood(x_star)
34
36
 
35
37
  # ratio and acceptance probability
36
- ratio = loglike_eval_star - self.current_loglike_eval # proposal is symmetric
38
+ ratio = loglike_eval_star - self.current_likelihood_logd # proposal is symmetric
37
39
  alpha = min(0, ratio)
38
40
 
39
41
  # accept/reject
@@ -41,7 +43,7 @@ class pCNNew(SamplerNew): # Refactor to Proposal-based sampler?
41
43
  u_theta = np.log(np.random.rand())
42
44
  if (u_theta <= alpha):
43
45
  self.current_point = x_star
44
- self.current_loglike_eval = loglike_eval_star
46
+ self.current_likelihood_logd = loglike_eval_star
45
47
  acc = 1
46
48
 
47
49
  return acc
@@ -87,15 +89,3 @@ class pCNNew(SamplerNew): # Refactor to Proposal-based sampler?
87
89
 
88
90
  def tune(self, skip_len, update_count):
89
91
  pass
90
-
91
- def get_state(self):
92
- return {'sampler_type': 'PCN', 'current_point': self.current_point.to_numpy(), \
93
- 'current_loglike_eval': self.current_loglike_eval.to_numpy(), \
94
- 'scale': self.scale}
95
-
96
- def set_state(self, state):
97
- temp = CUQIarray(state['current_point'] , geometry=self.target.geometry)
98
- self.current_point = temp
99
- temp = CUQIarray(state['current_loglike_eval'] , geometry=self.target.geometry)
100
- self.current_loglike_eval = temp
101
- self.scale = state['scale']
@@ -50,6 +50,7 @@ class LinearRTONew(SamplerNew):
50
50
 
51
51
  if initial_point is None: #TODO: Replace later with a getter
52
52
  self.initial_point = np.zeros(self.dim)
53
+ self._samples = [self.initial_point]
53
54
 
54
55
  self.current_point = self.initial_point
55
56
  self._acc = [1] # TODO. Check if we need this
@@ -188,12 +189,6 @@ class LinearRTONew(SamplerNew):
188
189
  if not hasattr(self.prior, "sqrtprecTimesMean"):
189
190
  raise TypeError("Prior must contain a sqrtprecTimesMean attribute")
190
191
 
191
- def get_state(self): #TODO: LinearRTO only need initial_point for reproducibility?
192
- return {'sampler_type': 'LinearRTO'}
193
-
194
- def set_state(self, state): #TODO: LinearRTO only need initial_point for reproducibility?
195
- pass
196
-
197
192
  class RegularizedLinearRTONew(LinearRTONew):
198
193
  """
199
194
  Regularized Linear RTO (Randomize-Then-Optimize) sampler.
@@ -272,4 +267,4 @@ class RegularizedLinearRTONew(LinearRTONew):
272
267
  maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
273
268
  self.current_point, _ = sim.solve()
274
269
  acc = 1
275
- return acc
270
+ return acc
@@ -21,6 +21,11 @@ class SamplerNew(ABC):
21
21
  Samples are stored in a list to allow for dynamic growth of the sample set. Returning samples is done by creating a new Samples object from the list of samples.
22
22
 
23
23
  """
24
+ _STATE_KEYS = {'current_point'}
25
+ """ Set of keys for the state dictionary. """
26
+
27
+ _HISTORY_KEYS = {'_samples', '_acc'}
28
+ """ Set of keys for the history dictionary. """
24
29
 
25
30
  def __init__(self, target: cuqi.density.Density, initial_point=None, callback=None):
26
31
  """ Initializer for abstract base class for all samplers.
@@ -45,8 +50,10 @@ class SamplerNew(ABC):
45
50
  # Choose initial point if not given
46
51
  if initial_point is None:
47
52
  initial_point = np.ones(self.dim)
53
+
54
+ self.initial_point = initial_point
48
55
 
49
- self._samples = [initial_point]
56
+ self._samples = [initial_point] # Remove. See #324.
50
57
 
51
58
  # ------------ Abstract methods to be implemented by subclasses ------------
52
59
 
@@ -65,29 +72,16 @@ class SamplerNew(ABC):
65
72
  """ Validate the target is compatible with the sampler. Called when the target is set. Should raise an error if the target is not compatible. """
66
73
  pass
67
74
 
68
- @abstractmethod
69
- def get_state(self):
70
- """ Return the state of the sampler. """
75
+ # -- _pre_sample and _pre_warmup methods: can be overridden by subclasses --
76
+ def _pre_sample(self):
77
+ """ Any code that needs to be run before sampling. """
71
78
  pass
72
79
 
73
- @abstractmethod
74
- def set_state(self, state):
75
- """ Set the state of the sampler. """
80
+ def _pre_warmup(self):
81
+ """ Any code that needs to be run before warmup. """
76
82
  pass
77
83
 
78
-
79
84
  # ------------ Public attributes ------------
80
-
81
- @property
82
- def initial_point(self):
83
- """ Return the initial point of the sampler. This is always the first sample. """
84
- return self._samples[0]
85
-
86
- @initial_point.setter
87
- def initial_point(self, value):
88
- """ Set the initial point of the sampler. """
89
- self._samples[0] = value
90
-
91
85
  @property
92
86
  def dim(self):
93
87
  """ Dimension of the target density. """
@@ -109,6 +103,15 @@ class SamplerNew(ABC):
109
103
  self._target = value
110
104
  self.validate_target()
111
105
 
106
+ @property
107
+ def current_point(self):
108
+ """ The current point of the sampler. """
109
+ return self._current_point
110
+
111
+ @current_point.setter
112
+ def current_point(self, value):
113
+ """ Set the current point of the sampler. """
114
+ self._current_point = value
112
115
 
113
116
  # ------------ Public methods ------------
114
117
 
@@ -156,6 +159,9 @@ class SamplerNew(ABC):
156
159
  if batch_size > 0:
157
160
  batch_handler = _BatchHandler(batch_size, sample_path)
158
161
 
162
+ # Any code that needs to be run before sampling
163
+ self._pre_sample()
164
+
159
165
  # Draw samples
160
166
  for _ in progressbar( range(Ns) ):
161
167
 
@@ -191,6 +197,9 @@ class SamplerNew(ABC):
191
197
 
192
198
  tune_interval = max(int(tune_freq * Nb), 1)
193
199
 
200
+ # Any code that needs to be run before warmup
201
+ self._pre_warmup()
202
+
194
203
  # Draw warmup samples with tuning
195
204
  for idx in progressbar(range(Nb)):
196
205
 
@@ -209,6 +218,94 @@ class SamplerNew(ABC):
209
218
  self._call_callback(self.current_point, len(self._samples)-1)
210
219
 
211
220
  return self
221
+
222
+ def get_state(self) -> dict:
223
+ """ Return the state of the sampler.
224
+
225
+ The state is used when checkpointing the sampler.
226
+
227
+ The state of the sampler is a dictionary with keys 'metadata' and 'state'.
228
+ The 'metadata' key contains information about the sampler type.
229
+ The 'state' key contains the state of the sampler.
230
+
231
+ For example, the state of a "MH" sampler could be:
232
+
233
+ state = {
234
+ 'metadata': {
235
+ 'sampler_type': 'MH'
236
+ },
237
+ 'state': {
238
+ 'current_point': np.array([...]),
239
+ 'current_target_logd': -123.45,
240
+ 'scale': 1.0,
241
+ ...
242
+ }
243
+ }
244
+ """
245
+ state = {
246
+ 'metadata': {
247
+ 'sampler_type': self.__class__.__name__
248
+ },
249
+ 'state': {
250
+ key: getattr(self, key) for key in self._STATE_KEYS
251
+ }
252
+ }
253
+ return state
254
+
255
+ def set_state(self, state: dict):
256
+ """ Set the state of the sampler.
257
+
258
+ The state is used when loading the sampler from a checkpoint.
259
+
260
+ The state of the sampler is a dictionary with keys 'metadata' and 'state'.
261
+
262
+ For example, the state of a "MH" sampler could be:
263
+
264
+ state = {
265
+ 'metadata': {
266
+ 'sampler_type': 'MH'
267
+ },
268
+ 'state': {
269
+ 'current_point': np.array([...]),
270
+ 'current_target_logd': -123.45,
271
+ 'scale': 1.0,
272
+ ...
273
+ }
274
+ }
275
+ """
276
+ if state['metadata']['sampler_type'] != self.__class__.__name__:
277
+ raise ValueError(f"Sampler type in state dictionary ({state['metadata']['sampler_type']}) does not match the type of the sampler ({self.__class__.__name__}).")
278
+
279
+ for key, value in state['state'].items():
280
+ if key in self._STATE_KEYS:
281
+ setattr(self, key, value)
282
+ else:
283
+ raise ValueError(f"Key {key} not recognized in state dictionary of sampler {self.__class__.__name__}.")
284
+
285
+ def get_history(self) -> dict:
286
+ """ Return the history of the sampler. """
287
+ history = {
288
+ 'metadata': {
289
+ 'sampler_type': self.__class__.__name__
290
+ },
291
+ 'history': {
292
+ key: getattr(self, key) for key in self._HISTORY_KEYS
293
+ }
294
+ }
295
+ return history
296
+
297
+ def set_history(self, history: dict):
298
+ """ Set the history of the sampler. """
299
+ if history['metadata']['sampler_type'] != self.__class__.__name__:
300
+ raise ValueError(f"Sampler type in history dictionary ({history['metadata']['sampler_type']}) does not match the type of the sampler ({self.__class__.__name__}).")
301
+
302
+ for key, value in history['history'].items():
303
+ if key in self._HISTORY_KEYS:
304
+ setattr(self, key, value)
305
+ else:
306
+ raise ValueError(f"Key {key} not recognized in history dictionary of sampler {self.__class__.__name__}.")
307
+
308
+ # ------------ Private methods ------------
212
309
 
213
310
  def _call_callback(self, sample, sample_index):
214
311
  """ Calls the callback function. Assumes input is sample and sample index"""
@@ -218,6 +315,9 @@ class SamplerNew(ABC):
218
315
 
219
316
  class ProposalBasedSamplerNew(SamplerNew, ABC):
220
317
  """ Abstract base class for samplers that use a proposal distribution. """
318
+
319
+ _STATE_KEYS = SamplerNew._STATE_KEYS.union({'current_target_logd', 'scale'})
320
+
221
321
  def __init__(self, target, proposal=None, scale=1, **kwargs):
222
322
  """ Initializer for proposal based samplers.
223
323
 
@@ -240,7 +340,7 @@ class ProposalBasedSamplerNew(SamplerNew, ABC):
240
340
  super().__init__(target, **kwargs)
241
341
 
242
342
  self.current_point = self.initial_point
243
- self.current_target = self.target.logd(self.current_point)
343
+ self.current_target_logd = self.target.logd(self.current_point)
244
344
  self.proposal = proposal
245
345
  self.scale = scale
246
346
 
cuqi/sampler/_hmc.py CHANGED
@@ -83,6 +83,9 @@ class NUTS(Sampler):
83
83
  self.max_depth = max_depth
84
84
  self.adapt_step_size = adapt_step_size
85
85
  self.opt_acc_rate = opt_acc_rate
86
+ # if this flag is True, the samples and the burn-in will be returned
87
+ # otherwise, the burn-in will be truncated
88
+ self._return_burnin = False
86
89
 
87
90
  # NUTS run diagnostic
88
91
  # number of tree nodes created each NUTS iteration
@@ -226,9 +229,10 @@ class NUTS(Sampler):
226
229
  if np.isnan(joint_eval[k]):
227
230
  raise NameError('NaN potential func')
228
231
 
229
- # apply burn-in
230
- theta = theta[:, Nb:]
231
- joint_eval = joint_eval[Nb:]
232
+ # apply burn-in
233
+ if not self._return_burnin:
234
+ theta = theta[:, Nb:]
235
+ joint_eval = joint_eval[Nb:]
232
236
  return theta, joint_eval, step_sizes
233
237
 
234
238
  #=========================================================================