inference-tools 0.14.1__py3-none-any.whl → 0.14.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference/_version.py +16 -3
- inference/approx/__init__.py +6 -2
- inference/approx/conditional.py +43 -8
- inference/gp/covariance.py +8 -16
- inference/gp/inversion.py +12 -24
- inference/gp/optimisation.py +2 -4
- inference/gp/regression.py +26 -52
- inference/likelihoods.py +2 -4
- inference/mcmc/base.py +33 -13
- inference/mcmc/ensemble.py +14 -28
- inference/mcmc/gibbs.py +3 -5
- inference/mcmc/hmc/__init__.py +3 -5
- inference/mcmc/hmc/mass.py +37 -8
- inference/mcmc/parallel.py +2 -4
- inference/mcmc/pca.py +4 -8
- inference/mcmc/utilities.py +10 -20
- inference/pdf/base.py +2 -4
- inference/pdf/hdi.py +12 -24
- inference/pdf/kde.py +2 -4
- inference/plotting.py +68 -43
- inference/priors.py +28 -56
- {inference_tools-0.14.1.dist-info → inference_tools-0.14.3.dist-info}/METADATA +1 -1
- inference_tools-0.14.3.dist-info/RECORD +35 -0
- {inference_tools-0.14.1.dist-info → inference_tools-0.14.3.dist-info}/WHEEL +1 -1
- inference_tools-0.14.1.dist-info/RECORD +0 -35
- {inference_tools-0.14.1.dist-info → inference_tools-0.14.3.dist-info}/licenses/LICENSE +0 -0
- {inference_tools-0.14.1.dist-info → inference_tools-0.14.3.dist-info}/top_level.txt +0 -0
inference/mcmc/gibbs.py
CHANGED
|
@@ -253,7 +253,7 @@ class MetropolisChain(MarkovChain):
|
|
|
253
253
|
|
|
254
254
|
if posterior is not None:
|
|
255
255
|
self.posterior = posterior
|
|
256
|
-
|
|
256
|
+
self._validate_posterior(posterior=posterior, start=start)
|
|
257
257
|
# if widths are not specified, take 5% of the starting values (unless they're zero)
|
|
258
258
|
if widths is None:
|
|
259
259
|
widths = [v * 0.05 if v != 0 else 1.0 for v in start]
|
|
@@ -272,13 +272,11 @@ class MetropolisChain(MarkovChain):
|
|
|
272
272
|
|
|
273
273
|
# check posterior value of chain starting point is finite
|
|
274
274
|
if not isfinite(self.probs[0]):
|
|
275
|
-
ValueError(
|
|
276
|
-
"""\n
|
|
275
|
+
ValueError("""\n
|
|
277
276
|
\r[ MetropolisChain error ]
|
|
278
277
|
\r>> 'posterior' argument callable returns a non-finite value
|
|
279
278
|
\r>> for the starting position given to the 'start' argument.
|
|
280
|
-
"""
|
|
281
|
-
)
|
|
279
|
+
""")
|
|
282
280
|
|
|
283
281
|
self.display_progress = display_progress
|
|
284
282
|
self.ProgressPrinter = ChainProgressPrinter(
|
inference/mcmc/hmc/__init__.py
CHANGED
|
@@ -87,7 +87,7 @@ class HamiltonianChain(MarkovChain):
|
|
|
87
87
|
start = start if isinstance(start, ndarray) else array(start)
|
|
88
88
|
start = start if start.dtype is float64 else start.astype(float64)
|
|
89
89
|
assert start.ndim == 1
|
|
90
|
-
|
|
90
|
+
self._validate_posterior(posterior=posterior, start=start)
|
|
91
91
|
self.theta = [start]
|
|
92
92
|
self.probs = [self.posterior(start) * self.inv_temp]
|
|
93
93
|
self.leapfrog_steps = [0]
|
|
@@ -149,12 +149,10 @@ class HamiltonianChain(MarkovChain):
|
|
|
149
149
|
if (accept_prob >= 1) or (self.rng.random() <= accept_prob):
|
|
150
150
|
break
|
|
151
151
|
else:
|
|
152
|
-
raise ValueError(
|
|
153
|
-
f"""\n
|
|
152
|
+
raise ValueError(f"""\n
|
|
154
153
|
\r[ HamiltonianChain error ]
|
|
155
154
|
\r>> Failed to take step within maximum allowed attempts of {self.max_attempts}
|
|
156
|
-
"""
|
|
157
|
-
)
|
|
155
|
+
""")
|
|
158
156
|
|
|
159
157
|
self.theta.append(t)
|
|
160
158
|
self.probs.append(p)
|
inference/mcmc/hmc/mass.py
CHANGED
|
@@ -3,7 +3,7 @@ from typing import Union
|
|
|
3
3
|
from numpy import ndarray, sqrt, eye, isscalar
|
|
4
4
|
from numpy.random import Generator
|
|
5
5
|
from numpy.linalg import cholesky
|
|
6
|
-
from scipy.linalg import solve_triangular
|
|
6
|
+
from scipy.linalg import solve_triangular, issymmetric
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
class ParticleMass(ABC):
|
|
@@ -37,12 +37,43 @@ class VectorMass(ScalarMass):
|
|
|
37
37
|
assert inv_mass.ndim == 1
|
|
38
38
|
assert inv_mass.size == n_parameters
|
|
39
39
|
|
|
40
|
+
valid_variances = (
|
|
41
|
+
inv_mass.ndim == 1
|
|
42
|
+
and inv_mass.size == n_parameters
|
|
43
|
+
and (inv_mass > 0.0).all()
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
if not valid_variances:
|
|
47
|
+
raise ValueError(f"""\n
|
|
48
|
+
\r[ VectorMass error ]
|
|
49
|
+
\r>> The inverse-mass vector must be a 1D array and have size
|
|
50
|
+
\r>> equal to the given number of model parameters ({n_parameters})
|
|
51
|
+
\r>> and contain only positive values.
|
|
52
|
+
""")
|
|
53
|
+
|
|
40
54
|
|
|
41
55
|
class MatrixMass(ParticleMass):
|
|
42
56
|
def __init__(self, inv_mass: ndarray, n_parameters: int):
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
57
|
+
|
|
58
|
+
valid_covariance = (
|
|
59
|
+
inv_mass.ndim == 2
|
|
60
|
+
and inv_mass.shape[0] == inv_mass.shape[1]
|
|
61
|
+
and issymmetric(inv_mass)
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
if not valid_covariance:
|
|
65
|
+
raise ValueError("""\n
|
|
66
|
+
\r[ MatrixMass error ]
|
|
67
|
+
\r>> The given inverse-mass matrix must be a valid covariance matrix,
|
|
68
|
+
\r>> i.e. 2 dimensional, square and symmetric.
|
|
69
|
+
""")
|
|
70
|
+
|
|
71
|
+
if inv_mass.shape[0] != n_parameters:
|
|
72
|
+
raise ValueError(f"""\n
|
|
73
|
+
\r[ MatrixMass error ]
|
|
74
|
+
\r>> The dimensions of the given inverse-mass matrix {inv_mass.shape}
|
|
75
|
+
\r>> do not match the given number of model parameters ({n_parameters}).
|
|
76
|
+
""")
|
|
46
77
|
|
|
47
78
|
self.inv_mass = inv_mass
|
|
48
79
|
self.n_parameters = n_parameters
|
|
@@ -64,15 +95,13 @@ def get_particle_mass(
|
|
|
64
95
|
return ScalarMass(inverse_mass, n_parameters)
|
|
65
96
|
|
|
66
97
|
if not isinstance(inverse_mass, ndarray):
|
|
67
|
-
raise TypeError(
|
|
68
|
-
f"""\n
|
|
98
|
+
raise TypeError(f"""\n
|
|
69
99
|
\r[ HamiltonianChain error ]
|
|
70
100
|
\r>> The value given to the 'inverse_mass' keyword argument must be either
|
|
71
101
|
\r>> a scalar type (e.g. int or float), or a numpy.ndarray.
|
|
72
102
|
\r>> Instead, the given value has type:
|
|
73
103
|
\r>> {type(inverse_mass)}
|
|
74
|
-
"""
|
|
75
|
-
)
|
|
104
|
+
""")
|
|
76
105
|
|
|
77
106
|
if inverse_mass.ndim == 1:
|
|
78
107
|
return VectorMass(inverse_mass, n_parameters)
|
inference/mcmc/parallel.py
CHANGED
|
@@ -116,12 +116,10 @@ class ParallelTempering:
|
|
|
116
116
|
self.successful_swaps = zeros([self.N_chains, self.N_chains])
|
|
117
117
|
|
|
118
118
|
if sorted(self.temperatures) != self.temperatures:
|
|
119
|
-
warn(
|
|
120
|
-
"""
|
|
119
|
+
warn("""
|
|
121
120
|
The list of Markov-chain objects passed to ParallelTempering
|
|
122
121
|
should be sorted in order of increasing chain temperature.
|
|
123
|
-
"""
|
|
124
|
-
)
|
|
122
|
+
""")
|
|
125
123
|
|
|
126
124
|
# Spawn a separate process for each chain object
|
|
127
125
|
for chn in chains:
|
inference/mcmc/pca.py
CHANGED
|
@@ -278,22 +278,18 @@ class PcaChain(MetropolisChain):
|
|
|
278
278
|
return chain
|
|
279
279
|
|
|
280
280
|
def set_non_negative(self, *args, **kwargs):
|
|
281
|
-
warn(
|
|
282
|
-
"""
|
|
281
|
+
warn("""
|
|
283
282
|
The set_non_negative method is not available for PcaChain:
|
|
284
283
|
Limits on parameters should instead be set using
|
|
285
284
|
the parameter_boundaries keyword argument.
|
|
286
|
-
"""
|
|
287
|
-
)
|
|
285
|
+
""")
|
|
288
286
|
|
|
289
287
|
def set_boundaries(self, *args, **kwargs):
|
|
290
|
-
warn(
|
|
291
|
-
"""
|
|
288
|
+
warn("""
|
|
292
289
|
The set_boundaries method is not available for PcaChain:
|
|
293
290
|
Limits on parameters should instead be set using
|
|
294
291
|
the parameter_boundaries keyword argument.
|
|
295
|
-
"""
|
|
296
|
-
)
|
|
292
|
+
""")
|
|
297
293
|
|
|
298
294
|
def pass_through(self, prop):
|
|
299
295
|
return prop
|
inference/mcmc/utilities.py
CHANGED
|
@@ -101,51 +101,41 @@ class Bounds:
|
|
|
101
101
|
self.upper = upper if isinstance(upper, ndarray) else array(upper).squeeze()
|
|
102
102
|
|
|
103
103
|
if self.lower.ndim > 1 or self.upper.ndim > 1:
|
|
104
|
-
raise ValueError(
|
|
105
|
-
f"""\n
|
|
104
|
+
raise ValueError(f"""\n
|
|
106
105
|
\r[ {error_source} error ]
|
|
107
106
|
\r>> Lower and upper bounds must be one-dimensional arrays, but
|
|
108
107
|
\r>> instead have dimensions {self.lower.ndim} and {self.upper.ndim} respectively.
|
|
109
|
-
"""
|
|
110
|
-
)
|
|
108
|
+
""")
|
|
111
109
|
|
|
112
110
|
if self.lower.size != self.upper.size:
|
|
113
|
-
raise ValueError(
|
|
114
|
-
f"""\n
|
|
111
|
+
raise ValueError(f"""\n
|
|
115
112
|
\r[ {error_source} error ]
|
|
116
113
|
\r>> Lower and upper bounds must be arrays of equal size, but
|
|
117
114
|
\r>> instead have sizes {self.lower.size} and {self.upper.size} respectively.
|
|
118
|
-
"""
|
|
119
|
-
)
|
|
115
|
+
""")
|
|
120
116
|
|
|
121
117
|
if (self.lower >= self.upper).any():
|
|
122
|
-
raise ValueError(
|
|
123
|
-
f"""\n
|
|
118
|
+
raise ValueError(f"""\n
|
|
124
119
|
\r[ {error_source} error ]
|
|
125
120
|
\r>> All given upper bounds must be larger than the corresponding lower bounds.
|
|
126
|
-
"""
|
|
127
|
-
)
|
|
121
|
+
""")
|
|
128
122
|
|
|
129
123
|
self.width = self.upper - self.lower
|
|
130
124
|
self.n_bounds = self.width.size
|
|
131
125
|
|
|
132
126
|
def validate_start_point(self, start: ndarray, error_source="Bounds"):
|
|
133
127
|
if self.n_bounds != start.size:
|
|
134
|
-
raise ValueError(
|
|
135
|
-
f"""\n
|
|
128
|
+
raise ValueError(f"""\n
|
|
136
129
|
\r[ {error_source} error ]
|
|
137
130
|
\r>> The number of parameters ({start.size}) does not
|
|
138
131
|
\r>> match the given number of bounds ({self.n_bounds}).
|
|
139
|
-
"""
|
|
140
|
-
)
|
|
132
|
+
""")
|
|
141
133
|
|
|
142
134
|
if not self.inside(start):
|
|
143
|
-
raise ValueError(
|
|
144
|
-
f"""\n
|
|
135
|
+
raise ValueError(f"""\n
|
|
145
136
|
\r[ {error_source} error ]
|
|
146
137
|
\r>> Starting location for the chain is outside specified bounds.
|
|
147
|
-
"""
|
|
148
|
-
)
|
|
138
|
+
""")
|
|
149
139
|
|
|
150
140
|
def reflect(self, theta: ndarray) -> ndarray:
|
|
151
141
|
q, rem = np_divmod(theta - self.lower, self.width)
|
inference/pdf/base.py
CHANGED
|
@@ -39,13 +39,11 @@ class DensityEstimator(ABC):
|
|
|
39
39
|
in the form ``(lower_limit, upper_limit)``.
|
|
40
40
|
"""
|
|
41
41
|
if not 0.0 < fraction < 1.0:
|
|
42
|
-
raise ValueError(
|
|
43
|
-
f"""\n
|
|
42
|
+
raise ValueError(f"""\n
|
|
44
43
|
\r[ {self.__class__.__name__} error ]
|
|
45
44
|
\r>> The 'fraction' argument must have a value greater than
|
|
46
45
|
\r>> zero and less than one, but the value given was {fraction}.
|
|
47
|
-
"""
|
|
48
|
-
)
|
|
46
|
+
""")
|
|
49
47
|
# use the sample to estimate the HDI
|
|
50
48
|
lwr, upr = sample_hdi(self.sample, fraction=fraction)
|
|
51
49
|
# switch variables to the centre and width of the interval
|
inference/pdf/hdi.py
CHANGED
|
@@ -25,37 +25,31 @@ def sample_hdi(sample: ndarray, fraction: float) -> ndarray:
|
|
|
25
25
|
|
|
26
26
|
# verify inputs are valid
|
|
27
27
|
if not 0.0 < fraction < 1.0:
|
|
28
|
-
raise ValueError(
|
|
29
|
-
f"""\n
|
|
28
|
+
raise ValueError(f"""\n
|
|
30
29
|
\r[ sample_hdi error ]
|
|
31
30
|
\r>> The 'fraction' argument must be a float between 0 and 1,
|
|
32
31
|
\r>> but the value given was {fraction}.
|
|
33
|
-
"""
|
|
34
|
-
)
|
|
32
|
+
""")
|
|
35
33
|
|
|
36
34
|
if isinstance(sample, ndarray):
|
|
37
35
|
s = sample.copy()
|
|
38
36
|
elif isinstance(sample, Sequence):
|
|
39
37
|
s = array(sample)
|
|
40
38
|
else:
|
|
41
|
-
raise ValueError(
|
|
42
|
-
f"""\n
|
|
39
|
+
raise ValueError(f"""\n
|
|
43
40
|
\r[ sample_hdi error ]
|
|
44
41
|
\r>> The 'sample' argument should be a numpy.ndarray or a
|
|
45
42
|
\r>> Sequence which can be converted to an array, but
|
|
46
43
|
\r>> instead has type {type(sample)}.
|
|
47
|
-
"""
|
|
48
|
-
)
|
|
44
|
+
""")
|
|
49
45
|
|
|
50
46
|
if s.ndim > 2 or s.ndim == 0:
|
|
51
|
-
raise ValueError(
|
|
52
|
-
f"""\n
|
|
47
|
+
raise ValueError(f"""\n
|
|
53
48
|
\r[ sample_hdi error ]
|
|
54
49
|
\r>> The 'sample' argument should be a numpy.ndarray
|
|
55
50
|
\r>> with either one or two dimensions, but the given
|
|
56
51
|
\r>> array has dimensionality {s.ndim}.
|
|
57
|
-
"""
|
|
58
|
-
)
|
|
52
|
+
""")
|
|
59
53
|
|
|
60
54
|
if s.ndim == 1:
|
|
61
55
|
s.resize([s.size, 1])
|
|
@@ -64,31 +58,25 @@ def sample_hdi(sample: ndarray, fraction: float) -> ndarray:
|
|
|
64
58
|
L = int(fraction * n_samples)
|
|
65
59
|
|
|
66
60
|
if n_samples < 2:
|
|
67
|
-
raise ValueError(
|
|
68
|
-
f"""\n
|
|
61
|
+
raise ValueError(f"""\n
|
|
69
62
|
\r[ sample_hdi error ]
|
|
70
63
|
\r>> The first dimension of the given 'sample' array must
|
|
71
64
|
\r>> have have a length of at least 2.
|
|
72
|
-
"""
|
|
73
|
-
)
|
|
65
|
+
""")
|
|
74
66
|
|
|
75
67
|
# check that we have enough samples to estimate the HDI for the chosen fraction
|
|
76
68
|
if n_samples <= L:
|
|
77
|
-
warn(
|
|
78
|
-
f"""\n
|
|
69
|
+
warn(f"""\n
|
|
79
70
|
\r[ sample_hdi warning ]
|
|
80
71
|
\r>> The given number of samples is insufficient to estimate the interval
|
|
81
72
|
\r>> for the given fraction.
|
|
82
|
-
"""
|
|
83
|
-
)
|
|
73
|
+
""")
|
|
84
74
|
|
|
85
75
|
elif n_samples - L < 20:
|
|
86
|
-
warn(
|
|
87
|
-
f"""\n
|
|
76
|
+
warn(f"""\n
|
|
88
77
|
\r[ sample_hdi warning ]
|
|
89
78
|
\r>> n_samples * (1 - fraction) is small - calculated interval may be inaccurate.
|
|
90
|
-
"""
|
|
91
|
-
)
|
|
79
|
+
""")
|
|
92
80
|
|
|
93
81
|
# check that we have enough samples to estimate the HDI for the chosen fraction
|
|
94
82
|
s.sort(axis=0)
|
inference/pdf/kde.py
CHANGED
|
@@ -51,13 +51,11 @@ class GaussianKDE(DensityEstimator):
|
|
|
51
51
|
self.max_cvs = max_cv_samples
|
|
52
52
|
|
|
53
53
|
if self.sample.size < 3:
|
|
54
|
-
raise ValueError(
|
|
55
|
-
"""\n
|
|
54
|
+
raise ValueError("""\n
|
|
56
55
|
\r[ GaussianKDE error ]
|
|
57
56
|
\r>> Not enough samples were given to estimate the PDF.
|
|
58
57
|
\r>> At least 3 samples are required.
|
|
59
|
-
"""
|
|
60
|
-
)
|
|
58
|
+
""")
|
|
61
59
|
|
|
62
60
|
if bandwidth is None:
|
|
63
61
|
self.h = self.simple_bandwidth_estimator() # very simple bandwidth estimate
|
inference/plotting.py
CHANGED
|
@@ -89,23 +89,19 @@ def matrix_plot(
|
|
|
89
89
|
labels = [f"param {i}" for i in range(N_par)]
|
|
90
90
|
else:
|
|
91
91
|
if len(labels) != N_par:
|
|
92
|
-
raise ValueError(
|
|
93
|
-
"""\n
|
|
92
|
+
raise ValueError("""\n
|
|
94
93
|
\r[ matrix_plot error ]
|
|
95
94
|
\r>> The number of labels given does not match
|
|
96
95
|
\r>> the number of plotted parameters.
|
|
97
|
-
"""
|
|
98
|
-
)
|
|
96
|
+
""")
|
|
99
97
|
|
|
100
98
|
if reference is not None:
|
|
101
99
|
if len(reference) != N_par:
|
|
102
|
-
raise ValueError(
|
|
103
|
-
"""\n
|
|
100
|
+
raise ValueError("""\n
|
|
104
101
|
\r[ matrix_plot error ]
|
|
105
102
|
\r>> The number of reference values given does not match
|
|
106
103
|
\r>> the number of plotted parameters.
|
|
107
|
-
"""
|
|
108
|
-
)
|
|
104
|
+
""")
|
|
109
105
|
# check that given plot style is valid, else default to a histogram
|
|
110
106
|
if plot_style not in ["contour", "hdi", "histogram", "scatter"]:
|
|
111
107
|
plot_style = "contour"
|
|
@@ -115,13 +111,11 @@ def matrix_plot(
|
|
|
115
111
|
|
|
116
112
|
iterable = hasattr(hdi_fractions, "__iter__")
|
|
117
113
|
if not iterable or not all(0 < f < 1 for f in hdi_fractions):
|
|
118
|
-
raise ValueError(
|
|
119
|
-
"""\n
|
|
114
|
+
raise ValueError("""\n
|
|
120
115
|
\r[ matrix_plot error ]
|
|
121
116
|
\r>> The 'hdi_fractions' argument must be given as an
|
|
122
117
|
\r>> iterable of floats, each in the range [0, 1].
|
|
123
|
-
"""
|
|
124
|
-
)
|
|
118
|
+
""")
|
|
125
119
|
|
|
126
120
|
# by default, we suppress axis ticks if there are 6 parameters or more to keep things tidy
|
|
127
121
|
if show_ticks is None:
|
|
@@ -373,10 +367,11 @@ def hdi_plot(
|
|
|
373
367
|
x: ndarray,
|
|
374
368
|
sample: ndarray,
|
|
375
369
|
intervals: Sequence[float] = (0.65, 0.95),
|
|
376
|
-
|
|
370
|
+
color: str = "C0",
|
|
377
371
|
axis=None,
|
|
378
|
-
|
|
379
|
-
|
|
372
|
+
plot_mean=True,
|
|
373
|
+
labels=True,
|
|
374
|
+
interval_alpha: Sequence[float] = None,
|
|
380
375
|
):
|
|
381
376
|
"""
|
|
382
377
|
Plot highest-density intervals for a given sample of model realisations.
|
|
@@ -389,27 +384,33 @@ def hdi_plot(
|
|
|
389
384
|
where ``n`` is the number of samples.
|
|
390
385
|
|
|
391
386
|
:param intervals: \
|
|
392
|
-
|
|
387
|
+
The fractions of the total probability contained in each interval which is to be
|
|
388
|
+
plotted as a Sequence of floats. All given values must be in the range [0, 1].
|
|
393
389
|
|
|
394
|
-
:param str
|
|
395
|
-
The
|
|
396
|
-
a valid
|
|
390
|
+
:param str color: \
|
|
391
|
+
The color to be used for plotting the intervals. Must be the name of
|
|
392
|
+
a valid ``matplotlib`` color.
|
|
397
393
|
|
|
398
394
|
:param axis: \
|
|
399
395
|
A ``matplotlib.pyplot`` axis object which will be used to plot the intervals.
|
|
400
396
|
|
|
401
|
-
:param bool
|
|
402
|
-
If ``True``,
|
|
403
|
-
|
|
397
|
+
:param bool plot_mean: \
|
|
398
|
+
If ``True``, the mean of the samples is also plotted on top of the
|
|
399
|
+
highest-density intervals.
|
|
404
400
|
|
|
405
|
-
:param
|
|
406
|
-
|
|
407
|
-
|
|
401
|
+
:param bool labels: \
|
|
402
|
+
If ``True``, then labels will be assigned to each plot element such that they
|
|
403
|
+
appear in the legend when using ``matplotlib.pyplot.legend``.
|
|
404
|
+
|
|
405
|
+
:param interval_alpha: \
|
|
406
|
+
A sequence of floats in the range [0, 1] specifying the 'alpha' value (which sets
|
|
407
|
+
the color transparency) which is used when coloring the intervals for each given
|
|
408
|
+
probability fraction.
|
|
408
409
|
"""
|
|
409
410
|
# order the intervals from highest to lowest
|
|
410
411
|
intervals = array(intervals)
|
|
411
|
-
intervals.
|
|
412
|
-
intervals = intervals[
|
|
412
|
+
sorter = intervals.argsort()
|
|
413
|
+
intervals = intervals[sorter]
|
|
413
414
|
|
|
414
415
|
# check that all the intervals are valid:
|
|
415
416
|
if not all((intervals > 0.0) & (intervals < 1.0)):
|
|
@@ -425,32 +426,56 @@ def hdi_plot(
|
|
|
425
426
|
|
|
426
427
|
# sort the sample data
|
|
427
428
|
s.sort(axis=0)
|
|
428
|
-
n = s.shape[0]
|
|
429
|
-
|
|
430
|
-
if colormap in colormaps:
|
|
431
|
-
cmap = colormaps[colormap]
|
|
432
|
-
else:
|
|
433
|
-
cmap = colormaps["Blues"]
|
|
434
|
-
warn(f"'{colormap}' is not a valid colormap from matplotlib.colormaps")
|
|
435
429
|
|
|
436
|
-
if
|
|
430
|
+
if interval_alpha is None:
|
|
437
431
|
# construct the colors for each interval
|
|
438
|
-
lwr = 0.
|
|
439
|
-
upr =
|
|
440
|
-
|
|
432
|
+
lwr = 0.15
|
|
433
|
+
upr = 0.9
|
|
434
|
+
interval_alpha = (upr - lwr) * (1 - intervals) + lwr
|
|
435
|
+
else:
|
|
436
|
+
interval_alpha = array(interval_alpha)
|
|
437
|
+
assert interval_alpha.ndim == 1
|
|
438
|
+
assert interval_alpha.size == interval_alpha.size
|
|
439
|
+
interval_alpha = interval_alpha[sorter]
|
|
441
440
|
|
|
442
|
-
|
|
441
|
+
valid_alpha = (interval_alpha >= 0.0) & (interval_alpha <= 1.0)
|
|
442
|
+
if not valid_alpha:
|
|
443
|
+
raise ValueError("Given 'alpha' values must be in the range [0, 1]")
|
|
443
444
|
|
|
444
445
|
# if not plotting axis is given, then use default pyplot
|
|
445
446
|
if axis is None:
|
|
446
447
|
_, axis = plt.subplots()
|
|
447
448
|
|
|
448
|
-
|
|
449
|
-
|
|
449
|
+
if plot_mean:
|
|
450
|
+
lab = "mean" if labels else None
|
|
451
|
+
axis.plot(x, s.mean(axis=0), color=color, lw=2, label=lab)
|
|
452
|
+
|
|
453
|
+
# Calculate HDI for each interval fraction
|
|
454
|
+
lower = []
|
|
455
|
+
upper = []
|
|
456
|
+
for frac in intervals:
|
|
450
457
|
lwr, upr = sample_hdi(s, fraction=frac)
|
|
451
|
-
|
|
452
|
-
|
|
458
|
+
lower.append(lwr)
|
|
459
|
+
upper.append(upr)
|
|
460
|
+
|
|
461
|
+
lab = f"{int(100 * intervals[0])}% HDI" if labels else None
|
|
462
|
+
axis.fill_between(
|
|
463
|
+
x, lower[0], upper[0], color=color, label=lab, alpha=interval_alpha[0]
|
|
464
|
+
)
|
|
453
465
|
|
|
466
|
+
for i in range(len(intervals) - 1):
|
|
467
|
+
lab = f"{int(100 * intervals[i + 1])}% HDI" if labels else None
|
|
468
|
+
axis.fill_between(
|
|
469
|
+
x,
|
|
470
|
+
lower[i + 1],
|
|
471
|
+
lower[i],
|
|
472
|
+
color=color,
|
|
473
|
+
label=lab,
|
|
474
|
+
alpha=interval_alpha[i + 1],
|
|
475
|
+
)
|
|
476
|
+
axis.fill_between(
|
|
477
|
+
x, upper[i], upper[i + 1], color=color, alpha=interval_alpha[i + 1]
|
|
478
|
+
)
|
|
454
479
|
return axis
|
|
455
480
|
|
|
456
481
|
|