CUQIpy 1.2.0.post0.dev30__py3-none-any.whl → 1.2.0.post0.dev90__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of CUQIpy might be problematic. Click here for more details.
- {CUQIpy-1.2.0.post0.dev30.dist-info → CUQIpy-1.2.0.post0.dev90.dist-info}/METADATA +1 -1
- {CUQIpy-1.2.0.post0.dev30.dist-info → CUQIpy-1.2.0.post0.dev90.dist-info}/RECORD +12 -11
- {CUQIpy-1.2.0.post0.dev30.dist-info → CUQIpy-1.2.0.post0.dev90.dist-info}/WHEEL +1 -1
- cuqi/_version.py +3 -3
- cuqi/distribution/__init__.py +1 -0
- cuqi/distribution/_normal.py +34 -0
- cuqi/distribution/_truncated_normal.py +129 -0
- cuqi/experimental/mcmc/_gibbs.py +1 -15
- cuqi/experimental/mcmc/_hmc.py +28 -32
- cuqi/experimental/mcmc/_sampler.py +1 -7
- {CUQIpy-1.2.0.post0.dev30.dist-info → CUQIpy-1.2.0.post0.dev90.dist-info}/LICENSE +0 -0
- {CUQIpy-1.2.0.post0.dev30.dist-info → CUQIpy-1.2.0.post0.dev90.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: CUQIpy
|
|
3
|
-
Version: 1.2.0.post0.
|
|
3
|
+
Version: 1.2.0.post0.dev90
|
|
4
4
|
Summary: Computational Uncertainty Quantification for Inverse problems in Python
|
|
5
5
|
Maintainer-email: "Nicolai A. B. Riis" <nabr@dtu.dk>, "Jakob S. Jørgensen" <jakj@dtu.dk>, "Amal M. Alghamdi" <amaal@dtu.dk>, Chao Zhang <chaz@dtu.dk>
|
|
6
6
|
License: Apache License
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
cuqi/__init__.py,sha256=LsGilhl-hBLEn6Glt8S_l0OJzAA1sKit_rui8h-D-p0,488
|
|
2
2
|
cuqi/_messages.py,sha256=fzEBrZT2kbmfecBBPm7spVu7yHdxGARQB4QzXhJbCJ0,415
|
|
3
|
-
cuqi/_version.py,sha256=
|
|
3
|
+
cuqi/_version.py,sha256=fadCQ-al0LVIaJsUncww5HgsVEcSF53R-lWs5uar-ow,509
|
|
4
4
|
cuqi/config.py,sha256=wcYvz19wkeKW2EKCGIKJiTpWt5kdaxyt4imyRkvtTRA,526
|
|
5
5
|
cuqi/diagnostics.py,sha256=5OrbJeqpynqRXOe5MtOKKhe7EAVdOEpHIqHnlMW9G_c,3029
|
|
6
6
|
cuqi/array/__init__.py,sha256=-EeiaiWGNsE3twRS4dD814BIlfxEsNkTCZUc5gjOXb0,30
|
|
@@ -14,7 +14,7 @@ cuqi/data/cookie.png,sha256=mr6wUeoIUc5VC2qYj8vafOmTbcRwz0fHz4IIPK9_PnE,984680
|
|
|
14
14
|
cuqi/data/satellite.mat,sha256=a0Nz_Ak-Y0m360dH74pa_rpk-MhaQ91ftGTKhQX7I8g,16373
|
|
15
15
|
cuqi/density/__init__.py,sha256=0zfVcPgqdqiPkss5n_WP_PUt-G3ovHXjokhqEKIlLwA,48
|
|
16
16
|
cuqi/density/_density.py,sha256=BG7gtP0cbFYLVgjYQGkNAhM95PR5ocBVLKRlOVX2PyM,7253
|
|
17
|
-
cuqi/distribution/__init__.py,sha256=
|
|
17
|
+
cuqi/distribution/__init__.py,sha256=Vvw-ge5HAF1now9n4rcwDicCsEUN9_jbbxlKxyzeUuY,761
|
|
18
18
|
cuqi/distribution/_beta.py,sha256=lgN6PGoF9RXQtrMGqSaSBV0hw-LEsOfRTD2Q2L3-Ok4,2903
|
|
19
19
|
cuqi/distribution/_cauchy.py,sha256=UsVXYz8HhagXN5fIWSAIyELqhsJAX_-wk9kkRGgRmA8,3296
|
|
20
20
|
cuqi/distribution/_cmrf.py,sha256=tCbEulM_O7FB3C_W-3IqZp9zGHkTofCdFF0ybHc9UZI,3745
|
|
@@ -29,9 +29,10 @@ cuqi/distribution/_laplace.py,sha256=5exLvlzJm2AgfvZ3KUSkjfwlGwwbsktBxP8z0iLMik8
|
|
|
29
29
|
cuqi/distribution/_lmrf.py,sha256=rdGoQ-fPe1oW6Z29P-l3woq0NX3_RxUQ2rzm1VzemNM,3290
|
|
30
30
|
cuqi/distribution/_lognormal.py,sha256=8_hOFQ3iu88ujX8vxmfVEZ0fdmlhTY98PlG5PasPjEg,2612
|
|
31
31
|
cuqi/distribution/_modifiedhalfnormal.py,sha256=eCg9YhH-zyX25V5WqdBwQykwG_90lm5Qc2901z7jFUE,7390
|
|
32
|
-
cuqi/distribution/_normal.py,sha256=
|
|
32
|
+
cuqi/distribution/_normal.py,sha256=vhIiAseW09IKh1uy0KUq7RP1IuY7hH5aNM1W_R8Gd_Q,2912
|
|
33
33
|
cuqi/distribution/_posterior.py,sha256=zAfL0GECxekZ2lBt1W6_LN0U_xskMwK4VNce5xAF7ig,5018
|
|
34
34
|
cuqi/distribution/_smoothed_laplace.py,sha256=p-1Y23mYA9omwiHGkEuv3T2mwcPAAoNlCr7T8osNkjE,2925
|
|
35
|
+
cuqi/distribution/_truncated_normal.py,sha256=sZkLYgnkGOyS_3ZxY7iw6L62t-Jh6shzsweRsRepN2k,4240
|
|
35
36
|
cuqi/distribution/_uniform.py,sha256=KA8yQ6ZS3nQGS4PYJ4hpDg6Eq8EQKQvPsIpYfR8fj2w,1967
|
|
36
37
|
cuqi/experimental/__init__.py,sha256=vhZvyMX6rl8Y0haqCzGLPz6PSUKyu75XMQbeDHqTTrw,83
|
|
37
38
|
cuqi/experimental/mcmc/__init__.py,sha256=1sn0U6Ep0x5zv2602og2DkV3Bs8hNFOiq7C3VcMimVw,4472
|
|
@@ -39,14 +40,14 @@ cuqi/experimental/mcmc/_conjugate.py,sha256=VNPQkGity0mposcqxrx4UIeXm35EvJvZED4p
|
|
|
39
40
|
cuqi/experimental/mcmc/_conjugate_approx.py,sha256=uEnY2ea9su5ivcNagyRAwpQP2gBY98sXU7N0y5hTADo,3653
|
|
40
41
|
cuqi/experimental/mcmc/_cwmh.py,sha256=50v3uZaWhlVnfrEB5-lB_7pn8QoUVBe-xWxKGKbmNHg,7234
|
|
41
42
|
cuqi/experimental/mcmc/_direct.py,sha256=9pQS_2Qk2-ybt6m8WTfPoKetcxQ00WaTRN85-Z0FrBY,777
|
|
42
|
-
cuqi/experimental/mcmc/_gibbs.py,sha256=
|
|
43
|
-
cuqi/experimental/mcmc/_hmc.py,sha256=
|
|
43
|
+
cuqi/experimental/mcmc/_gibbs.py,sha256=evgxf2tLFLlKB3hN0qz9a9NcZQSES8wdacnn3uNWocQ,12005
|
|
44
|
+
cuqi/experimental/mcmc/_hmc.py,sha256=8p4QxZBRpFLzwamH-DWHSdZE0aXX3FqonBzczz_XkDw,19340
|
|
44
45
|
cuqi/experimental/mcmc/_langevin_algorithm.py,sha256=yNO7ABxmkixzcLG-lv57GOTyeTr7HwFs2DrrhuZW9OI,8398
|
|
45
46
|
cuqi/experimental/mcmc/_laplace_approximation.py,sha256=rdiE3cMQFq6FLQcOQwPpuGIxrTAp3aoGPxMDSdeopV0,5688
|
|
46
47
|
cuqi/experimental/mcmc/_mh.py,sha256=MXo0ahXP4KGFkaY4HtvcBE-TMQzsMlTmLKzSvpz7drU,2941
|
|
47
48
|
cuqi/experimental/mcmc/_pcn.py,sha256=wqJBZLuRFSwxihaI53tumAg6AWVuceLMOmXssTetd1A,3374
|
|
48
49
|
cuqi/experimental/mcmc/_rto.py,sha256=OtzgiYCxDoTdXp7y4mkLa2upj74qadesoqHYpr11ZCg,10061
|
|
49
|
-
cuqi/experimental/mcmc/_sampler.py,sha256=
|
|
50
|
+
cuqi/experimental/mcmc/_sampler.py,sha256=xtoT70T8xe3Ye7yYdIFQD_kivjXlqUImyV3bMt406nk,20106
|
|
50
51
|
cuqi/experimental/mcmc/_utilities.py,sha256=kUzHbhIS3HYZRbneNBK41IogUYX5dS_bJxqEGm7TQBI,525
|
|
51
52
|
cuqi/geometry/__init__.py,sha256=Tz1WGzZBY-QGH3c0GiyKm9XHN8MGGcnU6TUHLZkzB3o,842
|
|
52
53
|
cuqi/geometry/_geometry.py,sha256=SDRZdiN2CIuS591lXxqgFoPWPIpwY-MHk75116QvdYY,46901
|
|
@@ -85,8 +86,8 @@ cuqi/testproblem/_testproblem.py,sha256=x769LwwRdJdzIiZkcQUGb_5-vynNTNALXWKato7s
|
|
|
85
86
|
cuqi/utilities/__init__.py,sha256=H7xpJe2UinjZftKvE2JuXtTi4DqtkR6uIezStAXwfGg,428
|
|
86
87
|
cuqi/utilities/_get_python_variable_name.py,sha256=QwlBVj2koJRA8s8pWd554p7-ElcI7HUwY32HknaR92E,1827
|
|
87
88
|
cuqi/utilities/_utilities.py,sha256=Jc4knn80vLoA7kgw9FzXwKVFGaNBOXiA9kgvltZU3Ao,11777
|
|
88
|
-
CUQIpy-1.2.0.post0.
|
|
89
|
-
CUQIpy-1.2.0.post0.
|
|
90
|
-
CUQIpy-1.2.0.post0.
|
|
91
|
-
CUQIpy-1.2.0.post0.
|
|
92
|
-
CUQIpy-1.2.0.post0.
|
|
89
|
+
CUQIpy-1.2.0.post0.dev90.dist-info/LICENSE,sha256=kJWRPrtRoQoZGXyyvu50Uc91X6_0XRaVfT0YZssicys,10799
|
|
90
|
+
CUQIpy-1.2.0.post0.dev90.dist-info/METADATA,sha256=KBSZdCAb8ZYWIzYvHOZ4iqrog8QGiBynjOw0gbo_sis,18495
|
|
91
|
+
CUQIpy-1.2.0.post0.dev90.dist-info/WHEEL,sha256=P9jw-gEje8ByB7_hXoICnHtVCrEwMQh-630tKvQWehc,91
|
|
92
|
+
CUQIpy-1.2.0.post0.dev90.dist-info/top_level.txt,sha256=AgmgMc6TKfPPqbjV0kvAoCBN334i_Lwwojc7HE3ZwD0,5
|
|
93
|
+
CUQIpy-1.2.0.post0.dev90.dist-info/RECORD,,
|
cuqi/_version.py
CHANGED
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "2024-
|
|
11
|
+
"date": "2024-11-03T22:18:33+0100",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "1.2.0.post0.
|
|
14
|
+
"full-revisionid": "8f8b00804a857370d46fd7bdf26cb9542a6b8f34",
|
|
15
|
+
"version": "1.2.0.post0.dev90"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
cuqi/distribution/__init__.py
CHANGED
|
@@ -12,6 +12,7 @@ from ._laplace import Laplace
|
|
|
12
12
|
from ._smoothed_laplace import SmoothedLaplace
|
|
13
13
|
from ._lognormal import Lognormal
|
|
14
14
|
from ._normal import Normal
|
|
15
|
+
from ._truncated_normal import TruncatedNormal
|
|
15
16
|
from ._posterior import Posterior
|
|
16
17
|
from ._uniform import Uniform
|
|
17
18
|
from ._custom import UserDefinedDistribution, DistributionGallery
|
cuqi/distribution/_normal.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
|
1
1
|
import numpy as np
|
|
2
|
+
import numbers
|
|
2
3
|
from scipy.special import erf
|
|
4
|
+
from cuqi.geometry import _get_identity_geometries
|
|
5
|
+
from cuqi.utilities import force_ndarray
|
|
3
6
|
from cuqi.distribution import Distribution
|
|
4
7
|
|
|
5
8
|
class Normal(Distribution):
|
|
@@ -27,6 +30,24 @@ class Normal(Distribution):
|
|
|
27
30
|
self.mean = mean
|
|
28
31
|
self.std = std
|
|
29
32
|
|
|
33
|
+
@property
|
|
34
|
+
def mean(self):
|
|
35
|
+
""" Mean of the distribution """
|
|
36
|
+
return self._mean
|
|
37
|
+
|
|
38
|
+
@mean.setter
|
|
39
|
+
def mean(self, value):
|
|
40
|
+
self._mean = force_ndarray(value, flatten=True)
|
|
41
|
+
|
|
42
|
+
@property
|
|
43
|
+
def std(self):
|
|
44
|
+
""" Std of the distribution """
|
|
45
|
+
return self._std
|
|
46
|
+
|
|
47
|
+
@std.setter
|
|
48
|
+
def std(self, value):
|
|
49
|
+
self._std = force_ndarray(value, flatten=True)
|
|
50
|
+
|
|
30
51
|
def pdf(self, x):
|
|
31
52
|
return np.prod(1/(self.std*np.sqrt(2*np.pi))*np.exp(-0.5*((x-self.mean)/self.std)**2))
|
|
32
53
|
|
|
@@ -36,6 +57,19 @@ class Normal(Distribution):
|
|
|
36
57
|
def cdf(self, x):
|
|
37
58
|
return np.prod(0.5*(1 + erf((x-self.mean)/(self.std*np.sqrt(2)))))
|
|
38
59
|
|
|
60
|
+
def _gradient(self, val, *args, **kwargs):
|
|
61
|
+
if not type(self.geometry) in _get_identity_geometries():
|
|
62
|
+
raise NotImplementedError("Gradient not implemented for distribution {} with geometry {}".format(self,self.geometry))
|
|
63
|
+
if not callable(self.mean):
|
|
64
|
+
return -(val-self.mean)/(self.std**2)
|
|
65
|
+
elif hasattr(self.mean, "gradient"): # for likelihood
|
|
66
|
+
model = self.mean
|
|
67
|
+
dev = val - model.forward(*args, **kwargs)
|
|
68
|
+
print(dev)
|
|
69
|
+
return model.gradient(1.0/(np.array(self.std)) @ dev, *args, **kwargs)
|
|
70
|
+
else:
|
|
71
|
+
raise NotImplementedError("Gradient not implemented for distribution {} with location {}".format(self,self.mean))
|
|
72
|
+
|
|
39
73
|
def _sample(self,N=1, rng=None):
|
|
40
74
|
|
|
41
75
|
"""
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from scipy.special import erf
|
|
3
|
+
from cuqi.utilities import force_ndarray
|
|
4
|
+
from cuqi.distribution import Distribution
|
|
5
|
+
from cuqi.distribution import Normal
|
|
6
|
+
|
|
7
|
+
class TruncatedNormal(Distribution):
|
|
8
|
+
"""
|
|
9
|
+
Truncated Normal probability distribution.
|
|
10
|
+
|
|
11
|
+
Generates instance of cuqi.distribution.TruncatedNormal.
|
|
12
|
+
It allows the user to specify upper and lower bounds on random variables
|
|
13
|
+
represented by a Normal distribution. This distribution is suitable for a
|
|
14
|
+
small dimension setup (e.g. `dim`=3 or 4). Using TruncatedNormal
|
|
15
|
+
Distribution with a larger dimension can lead to a high rejection rate when
|
|
16
|
+
used within MCMC samplers.
|
|
17
|
+
|
|
18
|
+
The variables of this distribution are iid.
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
Parameters
|
|
22
|
+
------------
|
|
23
|
+
mean : float or array_like of floats
|
|
24
|
+
mean of distribution
|
|
25
|
+
std : float or array_like of floats
|
|
26
|
+
standard deviation
|
|
27
|
+
low : float or array_like of floats
|
|
28
|
+
lower bound of the distribution
|
|
29
|
+
high : float or array_like of floats
|
|
30
|
+
upper bound of the distribution
|
|
31
|
+
|
|
32
|
+
Example
|
|
33
|
+
-----------
|
|
34
|
+
.. code-block:: python
|
|
35
|
+
|
|
36
|
+
#Generate Normal with mean 0, standard deviation 1 and bounds [-2,2]
|
|
37
|
+
p = cuqi.distribution.TruncatedNormal(mean=0, std=1, low=-2, high=2)
|
|
38
|
+
samples = p.sample(5000)
|
|
39
|
+
"""
|
|
40
|
+
def __init__(self, mean=None, std=None, low=-np.Inf, high=np.Inf, is_symmetric=False, **kwargs):
|
|
41
|
+
# Init from abstract distribution class
|
|
42
|
+
super().__init__(is_symmetric=is_symmetric, **kwargs)
|
|
43
|
+
|
|
44
|
+
# Init specific to this distribution
|
|
45
|
+
self.mean = mean
|
|
46
|
+
self.std = std
|
|
47
|
+
self.low = low
|
|
48
|
+
self.high = high
|
|
49
|
+
|
|
50
|
+
# Init underlying normal distribution
|
|
51
|
+
self._normal = Normal(self.mean, self.std, is_symmetric=True, **kwargs)
|
|
52
|
+
|
|
53
|
+
@property
|
|
54
|
+
def mean(self):
|
|
55
|
+
""" Mean of the distribution """
|
|
56
|
+
return self._mean
|
|
57
|
+
|
|
58
|
+
@mean.setter
|
|
59
|
+
def mean(self, value):
|
|
60
|
+
self._mean = force_ndarray(value, flatten=True)
|
|
61
|
+
if hasattr(self, '_normal'):
|
|
62
|
+
self._normal.mean = self._mean
|
|
63
|
+
|
|
64
|
+
@property
|
|
65
|
+
def std(self):
|
|
66
|
+
""" Std of the distribution """
|
|
67
|
+
return self._std
|
|
68
|
+
|
|
69
|
+
@std.setter
|
|
70
|
+
def std(self, value):
|
|
71
|
+
self._std = force_ndarray(value, flatten=True)
|
|
72
|
+
if hasattr(self, '_normal'):
|
|
73
|
+
self._normal.std = self._std
|
|
74
|
+
|
|
75
|
+
@property
|
|
76
|
+
def low(self):
|
|
77
|
+
""" Lower bound of the distribution """
|
|
78
|
+
return self._low
|
|
79
|
+
|
|
80
|
+
@low.setter
|
|
81
|
+
def low(self, value):
|
|
82
|
+
self._low = force_ndarray(value, flatten=True)
|
|
83
|
+
|
|
84
|
+
@property
|
|
85
|
+
def high(self):
|
|
86
|
+
""" Higher bound of the distribution """
|
|
87
|
+
return self._high
|
|
88
|
+
|
|
89
|
+
@high.setter
|
|
90
|
+
def high(self, value):
|
|
91
|
+
self._high = force_ndarray(value, flatten=True)
|
|
92
|
+
|
|
93
|
+
def logpdf(self, x):
|
|
94
|
+
"""
|
|
95
|
+
Computes the unnormalized logpdf at the given values of x.
|
|
96
|
+
"""
|
|
97
|
+
# the unnormalized logpdf
|
|
98
|
+
# check if x falls in the range between np.array a and b
|
|
99
|
+
if np.any(x < self.low) or np.any(x > self.high):
|
|
100
|
+
return -np.Inf
|
|
101
|
+
else:
|
|
102
|
+
return self._normal.logpdf(x)
|
|
103
|
+
|
|
104
|
+
def _gradient(self, x, *args, **kwargs):
|
|
105
|
+
"""
|
|
106
|
+
Computes the gradient of the unnormalized logpdf at the given values of x.
|
|
107
|
+
"""
|
|
108
|
+
# check if x falls in the range between np.array a and b
|
|
109
|
+
if np.any(x < self.low) or np.any(x > self.high):
|
|
110
|
+
return np.NaN*np.ones_like(x)
|
|
111
|
+
else:
|
|
112
|
+
return self._normal.gradient(x, *args, **kwargs)
|
|
113
|
+
|
|
114
|
+
def _sample(self, N=1, rng=None):
|
|
115
|
+
"""
|
|
116
|
+
Generates random samples from the distribution.
|
|
117
|
+
"""
|
|
118
|
+
max_iter = 1e9 # maximum number of trials to avoid infinite loop
|
|
119
|
+
samples = []
|
|
120
|
+
for i in range(int(max_iter)):
|
|
121
|
+
if len(samples) == N:
|
|
122
|
+
break
|
|
123
|
+
sample = self._normal.sample(1,rng)
|
|
124
|
+
if np.all(sample >= self.low) and np.all(sample <= self.high):
|
|
125
|
+
samples.append(sample)
|
|
126
|
+
# raise a error if the number of iterations exceeds max_iter
|
|
127
|
+
if i == max_iter-1:
|
|
128
|
+
raise RuntimeError("Failed to generate {} samples within {} iterations".format(N, max_iter))
|
|
129
|
+
return np.array(samples).T.reshape(-1,N)
|
cuqi/experimental/mcmc/_gibbs.py
CHANGED
|
@@ -136,13 +136,7 @@ class HybridGibbs:
|
|
|
136
136
|
self._set_targets()
|
|
137
137
|
|
|
138
138
|
# Initialize the samplers
|
|
139
|
-
self._initialize_samplers()
|
|
140
|
-
|
|
141
|
-
# Run over pre-sample methods for samplers that have it
|
|
142
|
-
# TODO. Some samplers (NUTS) seem to require to run _pre_warmup before _pre_sample
|
|
143
|
-
# This is not ideal and should be fixed in the future
|
|
144
|
-
for sampler in self.samplers.values():
|
|
145
|
-
self._pre_warmup_and_pre_sample_sampler(sampler)
|
|
139
|
+
self._initialize_samplers()
|
|
146
140
|
|
|
147
141
|
# Validate all targets for samplers.
|
|
148
142
|
self.validate_targets()
|
|
@@ -239,10 +233,6 @@ class HybridGibbs:
|
|
|
239
233
|
sampler.set_state(sampler_state)
|
|
240
234
|
sampler.set_history(sampler_history)
|
|
241
235
|
|
|
242
|
-
# Run pre_warmup and pre_sample methods for sampler
|
|
243
|
-
# TODO. Some samplers (NUTS) seem to require to run _pre_warmup before _pre_sample
|
|
244
|
-
self._pre_warmup_and_pre_sample_sampler(sampler)
|
|
245
|
-
|
|
246
236
|
# Allow for multiple sampling steps in each Gibbs step
|
|
247
237
|
for _ in range(self.num_sampling_steps[par_name]):
|
|
248
238
|
# Sampling step
|
|
@@ -291,10 +281,6 @@ class HybridGibbs:
|
|
|
291
281
|
self.num_sampling_steps[par_name] = 1
|
|
292
282
|
|
|
293
283
|
|
|
294
|
-
def _pre_warmup_and_pre_sample_sampler(self, sampler):
|
|
295
|
-
if hasattr(sampler, '_pre_warmup'): sampler._pre_warmup()
|
|
296
|
-
if hasattr(sampler, '_pre_sample'): sampler._pre_sample()
|
|
297
|
-
|
|
298
284
|
def _set_targets(self):
|
|
299
285
|
""" Set targets for all samplers using the current samples """
|
|
300
286
|
par_names = self.par_names
|
cuqi/experimental/mcmc/_hmc.py
CHANGED
|
@@ -199,6 +199,9 @@ class NUTS(Sampler):
|
|
|
199
199
|
self._reset_run_diagnostic_attributes()
|
|
200
200
|
|
|
201
201
|
def step(self):
|
|
202
|
+
if isinstance(self._epsilon_bar, str) and self._epsilon_bar == "unset":
|
|
203
|
+
self._epsilon_bar = self._epsilon
|
|
204
|
+
|
|
202
205
|
# Convert current_point, logd, and grad to numpy arrays
|
|
203
206
|
# if they are CUQIarray objects
|
|
204
207
|
if isinstance(self.current_point, CUQIarray):
|
|
@@ -212,9 +215,9 @@ class NUTS(Sampler):
|
|
|
212
215
|
self._num_tree_node = 0
|
|
213
216
|
|
|
214
217
|
# copy current point, logd, and grad in local variables
|
|
215
|
-
point_k = self.current_point
|
|
218
|
+
point_k = self.current_point # initial position (parameters)
|
|
216
219
|
logd_k = self.current_target_logd
|
|
217
|
-
grad_k = self.current_target_grad
|
|
220
|
+
grad_k = self.current_target_grad # initial gradient
|
|
218
221
|
|
|
219
222
|
# compute r_k and Hamiltonian
|
|
220
223
|
r_k = self._Kfun(1, 'sample') # resample momentum vector
|
|
@@ -225,9 +228,9 @@ class NUTS(Sampler):
|
|
|
225
228
|
|
|
226
229
|
# initialization
|
|
227
230
|
j, s, n = 0, 1, 1
|
|
228
|
-
point_minus, point_plus =
|
|
229
|
-
grad_minus, grad_plus =
|
|
230
|
-
r_minus, r_plus =
|
|
231
|
+
point_minus, point_plus = point_k.copy(), point_k.copy()
|
|
232
|
+
grad_minus, grad_plus = grad_k.copy(), grad_k.copy()
|
|
233
|
+
r_minus, r_plus = r_k.copy(), r_k.copy()
|
|
231
234
|
|
|
232
235
|
# run NUTS
|
|
233
236
|
acc = 0
|
|
@@ -255,9 +258,14 @@ class NUTS(Sampler):
|
|
|
255
258
|
(np.random.rand() <= alpha2) and \
|
|
256
259
|
(not np.isnan(logd_prime)) and \
|
|
257
260
|
(not np.isinf(logd_prime)):
|
|
258
|
-
self.current_point = point_prime
|
|
259
|
-
|
|
260
|
-
self.
|
|
261
|
+
self.current_point = point_prime.copy()
|
|
262
|
+
# copy if array, else assign if scalar
|
|
263
|
+
self.current_target_logd = (
|
|
264
|
+
logd_prime.copy()
|
|
265
|
+
if isinstance(logd_prime, np.ndarray)
|
|
266
|
+
else logd_prime
|
|
267
|
+
)
|
|
268
|
+
self.current_target_grad = grad_prime.copy()
|
|
261
269
|
acc = 1
|
|
262
270
|
|
|
263
271
|
|
|
@@ -281,6 +289,9 @@ class NUTS(Sampler):
|
|
|
281
289
|
|
|
282
290
|
def tune(self, skip_len, update_count):
|
|
283
291
|
""" adapt epsilon during burn-in using dual averaging"""
|
|
292
|
+
if isinstance(self._epsilon_bar, str) and self._epsilon_bar == "unset":
|
|
293
|
+
self._epsilon_bar = 1
|
|
294
|
+
|
|
284
295
|
k = update_count+1
|
|
285
296
|
|
|
286
297
|
# Fixed parameters that do not change during the run
|
|
@@ -294,26 +305,6 @@ class NUTS(Sampler):
|
|
|
294
305
|
self._epsilon_bar =\
|
|
295
306
|
np.exp(eta*np.log(self._epsilon) +(1-eta)*np.log(self._epsilon_bar))
|
|
296
307
|
|
|
297
|
-
def _pre_warmup(self):
|
|
298
|
-
|
|
299
|
-
# Set up tuning parameters (only first time tuning is called)
|
|
300
|
-
# Note:
|
|
301
|
-
# Parameters changes during the tune run
|
|
302
|
-
# self._epsilon_bar
|
|
303
|
-
# self._H_bar
|
|
304
|
-
# self._epsilon
|
|
305
|
-
# Parameters that does not change during the run
|
|
306
|
-
# self._mu
|
|
307
|
-
self._ensure_initialized()
|
|
308
|
-
if self._epsilon_bar == "unset": # Initial value of epsilon_bar for tuning
|
|
309
|
-
self._epsilon_bar = 1
|
|
310
|
-
|
|
311
|
-
def _pre_sample(self):
|
|
312
|
-
self._ensure_initialized()
|
|
313
|
-
if self._epsilon_bar == "unset": # Initial value of epsilon_bar for sampling
|
|
314
|
-
self._epsilon_bar = self._epsilon
|
|
315
|
-
|
|
316
|
-
|
|
317
308
|
#=========================================================================
|
|
318
309
|
def _nuts_target(self, x): # returns logposterior tuple evaluation-gradient
|
|
319
310
|
return self.target.logd(x), self.target.gradient(x)
|
|
@@ -423,9 +414,14 @@ class NUTS(Sampler):
|
|
|
423
414
|
# Metropolis step
|
|
424
415
|
alpha2 = n_2prime / max(1, (n_prime + n_2prime))
|
|
425
416
|
if (np.random.rand() <= alpha2):
|
|
426
|
-
point_prime =
|
|
427
|
-
|
|
428
|
-
|
|
417
|
+
point_prime = point_2prime.copy()
|
|
418
|
+
# copy if array, else assign if scalar
|
|
419
|
+
logd_prime = (
|
|
420
|
+
logd_2prime.copy()
|
|
421
|
+
if isinstance(logd_2prime, np.ndarray)
|
|
422
|
+
else logd_2prime
|
|
423
|
+
)
|
|
424
|
+
grad_prime = grad_2prime.copy()
|
|
429
425
|
|
|
430
426
|
# update number of particles and stopping criterion
|
|
431
427
|
alpha_prime += alpha_2prime
|
|
@@ -465,4 +461,4 @@ class NUTS(Sampler):
|
|
|
465
461
|
# Store the step size used in iteration k
|
|
466
462
|
self.epsilon_list.append(eps)
|
|
467
463
|
# Store the step size suggestion during adaptation in iteration k
|
|
468
|
-
self.epsilon_bar_list.append(eps_bar)
|
|
464
|
+
self.epsilon_bar_list.append(eps_bar)
|
|
@@ -216,9 +216,6 @@ class Sampler(ABC):
|
|
|
216
216
|
if batch_size > 0:
|
|
217
217
|
batch_handler = _BatchHandler(batch_size, sample_path)
|
|
218
218
|
|
|
219
|
-
# Any code that needs to be run before sampling
|
|
220
|
-
if hasattr(self, "_pre_sample"): self._pre_sample()
|
|
221
|
-
|
|
222
219
|
# Draw samples
|
|
223
220
|
pbar = tqdm(range(Ns), "Sample: ")
|
|
224
221
|
for idx in pbar:
|
|
@@ -260,9 +257,6 @@ class Sampler(ABC):
|
|
|
260
257
|
|
|
261
258
|
tune_interval = max(int(tune_freq * Nb), 1)
|
|
262
259
|
|
|
263
|
-
# Any code that needs to be run before warmup
|
|
264
|
-
if hasattr(self, "_pre_warmup"): self._pre_warmup()
|
|
265
|
-
|
|
266
260
|
# Draw warmup samples with tuning
|
|
267
261
|
pbar = tqdm(range(Nb), "Warmup: ")
|
|
268
262
|
for idx in pbar:
|
|
@@ -566,4 +560,4 @@ class _BatchHandler:
|
|
|
566
560
|
|
|
567
561
|
def finalize(self):
|
|
568
562
|
""" Finalize the batch handler. Flush any remaining samples to disk. """
|
|
569
|
-
self.flush()
|
|
563
|
+
self.flush()
|
|
File without changes
|
|
File without changes
|