inference-tools 0.14.2__py3-none-any.whl → 0.14.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference/_version.py +16 -3
- inference/approx/conditional.py +4 -8
- inference/gp/covariance.py +8 -16
- inference/gp/inversion.py +12 -24
- inference/gp/optimisation.py +2 -4
- inference/gp/regression.py +26 -52
- inference/likelihoods.py +2 -4
- inference/mcmc/base.py +12 -24
- inference/mcmc/ensemble.py +14 -28
- inference/mcmc/gibbs.py +2 -4
- inference/mcmc/hmc/__init__.py +2 -4
- inference/mcmc/hmc/mass.py +8 -16
- inference/mcmc/parallel.py +2 -4
- inference/mcmc/pca.py +4 -8
- inference/mcmc/utilities.py +10 -20
- inference/pdf/base.py +2 -4
- inference/pdf/hdi.py +12 -24
- inference/pdf/kde.py +2 -4
- inference/plotting.py +68 -43
- inference/priors.py +28 -56
- {inference_tools-0.14.2.dist-info → inference_tools-0.14.3.dist-info}/METADATA +1 -1
- inference_tools-0.14.3.dist-info/RECORD +35 -0
- {inference_tools-0.14.2.dist-info → inference_tools-0.14.3.dist-info}/WHEEL +1 -1
- inference_tools-0.14.2.dist-info/RECORD +0 -35
- {inference_tools-0.14.2.dist-info → inference_tools-0.14.3.dist-info}/licenses/LICENSE +0 -0
- {inference_tools-0.14.2.dist-info → inference_tools-0.14.3.dist-info}/top_level.txt +0 -0
inference/mcmc/hmc/mass.py
CHANGED
|
@@ -44,14 +44,12 @@ class VectorMass(ScalarMass):
|
|
|
44
44
|
)
|
|
45
45
|
|
|
46
46
|
if not valid_variances:
|
|
47
|
-
raise ValueError(
|
|
48
|
-
f"""\n
|
|
47
|
+
raise ValueError(f"""\n
|
|
49
48
|
\r[ VectorMass error ]
|
|
50
49
|
\r>> The inverse-mass vector must be a 1D array and have size
|
|
51
50
|
\r>> equal to the given number of model parameters ({n_parameters})
|
|
52
51
|
\r>> and contain only positive values.
|
|
53
|
-
"""
|
|
54
|
-
)
|
|
52
|
+
""")
|
|
55
53
|
|
|
56
54
|
|
|
57
55
|
class MatrixMass(ParticleMass):
|
|
@@ -64,22 +62,18 @@ class MatrixMass(ParticleMass):
|
|
|
64
62
|
)
|
|
65
63
|
|
|
66
64
|
if not valid_covariance:
|
|
67
|
-
raise ValueError(
|
|
68
|
-
"""\n
|
|
65
|
+
raise ValueError("""\n
|
|
69
66
|
\r[ MatrixMass error ]
|
|
70
67
|
\r>> The given inverse-mass matrix must be a valid covariance matrix,
|
|
71
68
|
\r>> i.e. 2 dimensional, square and symmetric.
|
|
72
|
-
"""
|
|
73
|
-
)
|
|
69
|
+
""")
|
|
74
70
|
|
|
75
71
|
if inv_mass.shape[0] != n_parameters:
|
|
76
|
-
raise ValueError(
|
|
77
|
-
f"""\n
|
|
72
|
+
raise ValueError(f"""\n
|
|
78
73
|
\r[ MatrixMass error ]
|
|
79
74
|
\r>> The dimensions of the given inverse-mass matrix {inv_mass.shape}
|
|
80
75
|
\r>> do not match the given number of model parameters ({n_parameters}).
|
|
81
|
-
"""
|
|
82
|
-
)
|
|
76
|
+
""")
|
|
83
77
|
|
|
84
78
|
self.inv_mass = inv_mass
|
|
85
79
|
self.n_parameters = n_parameters
|
|
@@ -101,15 +95,13 @@ def get_particle_mass(
|
|
|
101
95
|
return ScalarMass(inverse_mass, n_parameters)
|
|
102
96
|
|
|
103
97
|
if not isinstance(inverse_mass, ndarray):
|
|
104
|
-
raise TypeError(
|
|
105
|
-
f"""\n
|
|
98
|
+
raise TypeError(f"""\n
|
|
106
99
|
\r[ HamiltonianChain error ]
|
|
107
100
|
\r>> The value given to the 'inverse_mass' keyword argument must be either
|
|
108
101
|
\r>> a scalar type (e.g. int or float), or a numpy.ndarray.
|
|
109
102
|
\r>> Instead, the given value has type:
|
|
110
103
|
\r>> {type(inverse_mass)}
|
|
111
|
-
"""
|
|
112
|
-
)
|
|
104
|
+
""")
|
|
113
105
|
|
|
114
106
|
if inverse_mass.ndim == 1:
|
|
115
107
|
return VectorMass(inverse_mass, n_parameters)
|
inference/mcmc/parallel.py
CHANGED
|
@@ -116,12 +116,10 @@ class ParallelTempering:
|
|
|
116
116
|
self.successful_swaps = zeros([self.N_chains, self.N_chains])
|
|
117
117
|
|
|
118
118
|
if sorted(self.temperatures) != self.temperatures:
|
|
119
|
-
warn(
|
|
120
|
-
"""
|
|
119
|
+
warn("""
|
|
121
120
|
The list of Markov-chain objects passed to ParallelTempering
|
|
122
121
|
should be sorted in order of increasing chain temperature.
|
|
123
|
-
"""
|
|
124
|
-
)
|
|
122
|
+
""")
|
|
125
123
|
|
|
126
124
|
# Spawn a separate process for each chain object
|
|
127
125
|
for chn in chains:
|
inference/mcmc/pca.py
CHANGED
|
@@ -278,22 +278,18 @@ class PcaChain(MetropolisChain):
|
|
|
278
278
|
return chain
|
|
279
279
|
|
|
280
280
|
def set_non_negative(self, *args, **kwargs):
|
|
281
|
-
warn(
|
|
282
|
-
"""
|
|
281
|
+
warn("""
|
|
283
282
|
The set_non_negative method is not available for PcaChain:
|
|
284
283
|
Limits on parameters should instead be set using
|
|
285
284
|
the parameter_boundaries keyword argument.
|
|
286
|
-
"""
|
|
287
|
-
)
|
|
285
|
+
""")
|
|
288
286
|
|
|
289
287
|
def set_boundaries(self, *args, **kwargs):
|
|
290
|
-
warn(
|
|
291
|
-
"""
|
|
288
|
+
warn("""
|
|
292
289
|
The set_boundaries method is not available for PcaChain:
|
|
293
290
|
Limits on parameters should instead be set using
|
|
294
291
|
the parameter_boundaries keyword argument.
|
|
295
|
-
"""
|
|
296
|
-
)
|
|
292
|
+
""")
|
|
297
293
|
|
|
298
294
|
def pass_through(self, prop):
|
|
299
295
|
return prop
|
inference/mcmc/utilities.py
CHANGED
|
@@ -101,51 +101,41 @@ class Bounds:
|
|
|
101
101
|
self.upper = upper if isinstance(upper, ndarray) else array(upper).squeeze()
|
|
102
102
|
|
|
103
103
|
if self.lower.ndim > 1 or self.upper.ndim > 1:
|
|
104
|
-
raise ValueError(
|
|
105
|
-
f"""\n
|
|
104
|
+
raise ValueError(f"""\n
|
|
106
105
|
\r[ {error_source} error ]
|
|
107
106
|
\r>> Lower and upper bounds must be one-dimensional arrays, but
|
|
108
107
|
\r>> instead have dimensions {self.lower.ndim} and {self.upper.ndim} respectively.
|
|
109
|
-
"""
|
|
110
|
-
)
|
|
108
|
+
""")
|
|
111
109
|
|
|
112
110
|
if self.lower.size != self.upper.size:
|
|
113
|
-
raise ValueError(
|
|
114
|
-
f"""\n
|
|
111
|
+
raise ValueError(f"""\n
|
|
115
112
|
\r[ {error_source} error ]
|
|
116
113
|
\r>> Lower and upper bounds must be arrays of equal size, but
|
|
117
114
|
\r>> instead have sizes {self.lower.size} and {self.upper.size} respectively.
|
|
118
|
-
"""
|
|
119
|
-
)
|
|
115
|
+
""")
|
|
120
116
|
|
|
121
117
|
if (self.lower >= self.upper).any():
|
|
122
|
-
raise ValueError(
|
|
123
|
-
f"""\n
|
|
118
|
+
raise ValueError(f"""\n
|
|
124
119
|
\r[ {error_source} error ]
|
|
125
120
|
\r>> All given upper bounds must be larger than the corresponding lower bounds.
|
|
126
|
-
"""
|
|
127
|
-
)
|
|
121
|
+
""")
|
|
128
122
|
|
|
129
123
|
self.width = self.upper - self.lower
|
|
130
124
|
self.n_bounds = self.width.size
|
|
131
125
|
|
|
132
126
|
def validate_start_point(self, start: ndarray, error_source="Bounds"):
|
|
133
127
|
if self.n_bounds != start.size:
|
|
134
|
-
raise ValueError(
|
|
135
|
-
f"""\n
|
|
128
|
+
raise ValueError(f"""\n
|
|
136
129
|
\r[ {error_source} error ]
|
|
137
130
|
\r>> The number of parameters ({start.size}) does not
|
|
138
131
|
\r>> match the given number of bounds ({self.n_bounds}).
|
|
139
|
-
"""
|
|
140
|
-
)
|
|
132
|
+
""")
|
|
141
133
|
|
|
142
134
|
if not self.inside(start):
|
|
143
|
-
raise ValueError(
|
|
144
|
-
f"""\n
|
|
135
|
+
raise ValueError(f"""\n
|
|
145
136
|
\r[ {error_source} error ]
|
|
146
137
|
\r>> Starting location for the chain is outside specified bounds.
|
|
147
|
-
"""
|
|
148
|
-
)
|
|
138
|
+
""")
|
|
149
139
|
|
|
150
140
|
def reflect(self, theta: ndarray) -> ndarray:
|
|
151
141
|
q, rem = np_divmod(theta - self.lower, self.width)
|
inference/pdf/base.py
CHANGED
|
@@ -39,13 +39,11 @@ class DensityEstimator(ABC):
|
|
|
39
39
|
in the form ``(lower_limit, upper_limit)``.
|
|
40
40
|
"""
|
|
41
41
|
if not 0.0 < fraction < 1.0:
|
|
42
|
-
raise ValueError(
|
|
43
|
-
f"""\n
|
|
42
|
+
raise ValueError(f"""\n
|
|
44
43
|
\r[ {self.__class__.__name__} error ]
|
|
45
44
|
\r>> The 'fraction' argument must have a value greater than
|
|
46
45
|
\r>> zero and less than one, but the value given was {fraction}.
|
|
47
|
-
"""
|
|
48
|
-
)
|
|
46
|
+
""")
|
|
49
47
|
# use the sample to estimate the HDI
|
|
50
48
|
lwr, upr = sample_hdi(self.sample, fraction=fraction)
|
|
51
49
|
# switch variables to the centre and width of the interval
|
inference/pdf/hdi.py
CHANGED
|
@@ -25,37 +25,31 @@ def sample_hdi(sample: ndarray, fraction: float) -> ndarray:
|
|
|
25
25
|
|
|
26
26
|
# verify inputs are valid
|
|
27
27
|
if not 0.0 < fraction < 1.0:
|
|
28
|
-
raise ValueError(
|
|
29
|
-
f"""\n
|
|
28
|
+
raise ValueError(f"""\n
|
|
30
29
|
\r[ sample_hdi error ]
|
|
31
30
|
\r>> The 'fraction' argument must be a float between 0 and 1,
|
|
32
31
|
\r>> but the value given was {fraction}.
|
|
33
|
-
"""
|
|
34
|
-
)
|
|
32
|
+
""")
|
|
35
33
|
|
|
36
34
|
if isinstance(sample, ndarray):
|
|
37
35
|
s = sample.copy()
|
|
38
36
|
elif isinstance(sample, Sequence):
|
|
39
37
|
s = array(sample)
|
|
40
38
|
else:
|
|
41
|
-
raise ValueError(
|
|
42
|
-
f"""\n
|
|
39
|
+
raise ValueError(f"""\n
|
|
43
40
|
\r[ sample_hdi error ]
|
|
44
41
|
\r>> The 'sample' argument should be a numpy.ndarray or a
|
|
45
42
|
\r>> Sequence which can be converted to an array, but
|
|
46
43
|
\r>> instead has type {type(sample)}.
|
|
47
|
-
"""
|
|
48
|
-
)
|
|
44
|
+
""")
|
|
49
45
|
|
|
50
46
|
if s.ndim > 2 or s.ndim == 0:
|
|
51
|
-
raise ValueError(
|
|
52
|
-
f"""\n
|
|
47
|
+
raise ValueError(f"""\n
|
|
53
48
|
\r[ sample_hdi error ]
|
|
54
49
|
\r>> The 'sample' argument should be a numpy.ndarray
|
|
55
50
|
\r>> with either one or two dimensions, but the given
|
|
56
51
|
\r>> array has dimensionality {s.ndim}.
|
|
57
|
-
"""
|
|
58
|
-
)
|
|
52
|
+
""")
|
|
59
53
|
|
|
60
54
|
if s.ndim == 1:
|
|
61
55
|
s.resize([s.size, 1])
|
|
@@ -64,31 +58,25 @@ def sample_hdi(sample: ndarray, fraction: float) -> ndarray:
|
|
|
64
58
|
L = int(fraction * n_samples)
|
|
65
59
|
|
|
66
60
|
if n_samples < 2:
|
|
67
|
-
raise ValueError(
|
|
68
|
-
f"""\n
|
|
61
|
+
raise ValueError(f"""\n
|
|
69
62
|
\r[ sample_hdi error ]
|
|
70
63
|
\r>> The first dimension of the given 'sample' array must
|
|
71
64
|
\r>> have have a length of at least 2.
|
|
72
|
-
"""
|
|
73
|
-
)
|
|
65
|
+
""")
|
|
74
66
|
|
|
75
67
|
# check that we have enough samples to estimate the HDI for the chosen fraction
|
|
76
68
|
if n_samples <= L:
|
|
77
|
-
warn(
|
|
78
|
-
f"""\n
|
|
69
|
+
warn(f"""\n
|
|
79
70
|
\r[ sample_hdi warning ]
|
|
80
71
|
\r>> The given number of samples is insufficient to estimate the interval
|
|
81
72
|
\r>> for the given fraction.
|
|
82
|
-
"""
|
|
83
|
-
)
|
|
73
|
+
""")
|
|
84
74
|
|
|
85
75
|
elif n_samples - L < 20:
|
|
86
|
-
warn(
|
|
87
|
-
f"""\n
|
|
76
|
+
warn(f"""\n
|
|
88
77
|
\r[ sample_hdi warning ]
|
|
89
78
|
\r>> n_samples * (1 - fraction) is small - calculated interval may be inaccurate.
|
|
90
|
-
"""
|
|
91
|
-
)
|
|
79
|
+
""")
|
|
92
80
|
|
|
93
81
|
# check that we have enough samples to estimate the HDI for the chosen fraction
|
|
94
82
|
s.sort(axis=0)
|
inference/pdf/kde.py
CHANGED
|
@@ -51,13 +51,11 @@ class GaussianKDE(DensityEstimator):
|
|
|
51
51
|
self.max_cvs = max_cv_samples
|
|
52
52
|
|
|
53
53
|
if self.sample.size < 3:
|
|
54
|
-
raise ValueError(
|
|
55
|
-
"""\n
|
|
54
|
+
raise ValueError("""\n
|
|
56
55
|
\r[ GaussianKDE error ]
|
|
57
56
|
\r>> Not enough samples were given to estimate the PDF.
|
|
58
57
|
\r>> At least 3 samples are required.
|
|
59
|
-
"""
|
|
60
|
-
)
|
|
58
|
+
""")
|
|
61
59
|
|
|
62
60
|
if bandwidth is None:
|
|
63
61
|
self.h = self.simple_bandwidth_estimator() # very simple bandwidth estimate
|
inference/plotting.py
CHANGED
|
@@ -89,23 +89,19 @@ def matrix_plot(
|
|
|
89
89
|
labels = [f"param {i}" for i in range(N_par)]
|
|
90
90
|
else:
|
|
91
91
|
if len(labels) != N_par:
|
|
92
|
-
raise ValueError(
|
|
93
|
-
"""\n
|
|
92
|
+
raise ValueError("""\n
|
|
94
93
|
\r[ matrix_plot error ]
|
|
95
94
|
\r>> The number of labels given does not match
|
|
96
95
|
\r>> the number of plotted parameters.
|
|
97
|
-
"""
|
|
98
|
-
)
|
|
96
|
+
""")
|
|
99
97
|
|
|
100
98
|
if reference is not None:
|
|
101
99
|
if len(reference) != N_par:
|
|
102
|
-
raise ValueError(
|
|
103
|
-
"""\n
|
|
100
|
+
raise ValueError("""\n
|
|
104
101
|
\r[ matrix_plot error ]
|
|
105
102
|
\r>> The number of reference values given does not match
|
|
106
103
|
\r>> the number of plotted parameters.
|
|
107
|
-
"""
|
|
108
|
-
)
|
|
104
|
+
""")
|
|
109
105
|
# check that given plot style is valid, else default to a histogram
|
|
110
106
|
if plot_style not in ["contour", "hdi", "histogram", "scatter"]:
|
|
111
107
|
plot_style = "contour"
|
|
@@ -115,13 +111,11 @@ def matrix_plot(
|
|
|
115
111
|
|
|
116
112
|
iterable = hasattr(hdi_fractions, "__iter__")
|
|
117
113
|
if not iterable or not all(0 < f < 1 for f in hdi_fractions):
|
|
118
|
-
raise ValueError(
|
|
119
|
-
"""\n
|
|
114
|
+
raise ValueError("""\n
|
|
120
115
|
\r[ matrix_plot error ]
|
|
121
116
|
\r>> The 'hdi_fractions' argument must be given as an
|
|
122
117
|
\r>> iterable of floats, each in the range [0, 1].
|
|
123
|
-
"""
|
|
124
|
-
)
|
|
118
|
+
""")
|
|
125
119
|
|
|
126
120
|
# by default, we suppress axis ticks if there are 6 parameters or more to keep things tidy
|
|
127
121
|
if show_ticks is None:
|
|
@@ -373,10 +367,11 @@ def hdi_plot(
|
|
|
373
367
|
x: ndarray,
|
|
374
368
|
sample: ndarray,
|
|
375
369
|
intervals: Sequence[float] = (0.65, 0.95),
|
|
376
|
-
|
|
370
|
+
color: str = "C0",
|
|
377
371
|
axis=None,
|
|
378
|
-
|
|
379
|
-
|
|
372
|
+
plot_mean=True,
|
|
373
|
+
labels=True,
|
|
374
|
+
interval_alpha: Sequence[float] = None,
|
|
380
375
|
):
|
|
381
376
|
"""
|
|
382
377
|
Plot highest-density intervals for a given sample of model realisations.
|
|
@@ -389,27 +384,33 @@ def hdi_plot(
|
|
|
389
384
|
where ``n`` is the number of samples.
|
|
390
385
|
|
|
391
386
|
:param intervals: \
|
|
392
|
-
|
|
387
|
+
The fractions of the total probability contained in each interval which is to be
|
|
388
|
+
plotted as a Sequence of floats. All given values must be in the range [0, 1].
|
|
393
389
|
|
|
394
|
-
:param str
|
|
395
|
-
The
|
|
396
|
-
a valid
|
|
390
|
+
:param str color: \
|
|
391
|
+
The color to be used for plotting the intervals. Must be the name of
|
|
392
|
+
a valid ``matplotlib`` color.
|
|
397
393
|
|
|
398
394
|
:param axis: \
|
|
399
395
|
A ``matplotlib.pyplot`` axis object which will be used to plot the intervals.
|
|
400
396
|
|
|
401
|
-
:param bool
|
|
402
|
-
If ``True``,
|
|
403
|
-
|
|
397
|
+
:param bool plot_mean: \
|
|
398
|
+
If ``True``, the mean of the samples is also plotted on top of the
|
|
399
|
+
highest-density intervals.
|
|
404
400
|
|
|
405
|
-
:param
|
|
406
|
-
|
|
407
|
-
|
|
401
|
+
:param bool labels: \
|
|
402
|
+
If ``True``, then labels will be assigned to each plot element such that they
|
|
403
|
+
appear in the legend when using ``matplotlib.pyplot.legend``.
|
|
404
|
+
|
|
405
|
+
:param interval_alpha: \
|
|
406
|
+
A sequence of floats in the range [0, 1] specifying the 'alpha' value (which sets
|
|
407
|
+
the color transparency) which is used when coloring the intervals for each given
|
|
408
|
+
probability fraction.
|
|
408
409
|
"""
|
|
409
410
|
# order the intervals from highest to lowest
|
|
410
411
|
intervals = array(intervals)
|
|
411
|
-
intervals.
|
|
412
|
-
intervals = intervals[
|
|
412
|
+
sorter = intervals.argsort()
|
|
413
|
+
intervals = intervals[sorter]
|
|
413
414
|
|
|
414
415
|
# check that all the intervals are valid:
|
|
415
416
|
if not all((intervals > 0.0) & (intervals < 1.0)):
|
|
@@ -425,32 +426,56 @@ def hdi_plot(
|
|
|
425
426
|
|
|
426
427
|
# sort the sample data
|
|
427
428
|
s.sort(axis=0)
|
|
428
|
-
n = s.shape[0]
|
|
429
|
-
|
|
430
|
-
if colormap in colormaps:
|
|
431
|
-
cmap = colormaps[colormap]
|
|
432
|
-
else:
|
|
433
|
-
cmap = colormaps["Blues"]
|
|
434
|
-
warn(f"'{colormap}' is not a valid colormap from matplotlib.colormaps")
|
|
435
429
|
|
|
436
|
-
if
|
|
430
|
+
if interval_alpha is None:
|
|
437
431
|
# construct the colors for each interval
|
|
438
|
-
lwr = 0.
|
|
439
|
-
upr =
|
|
440
|
-
|
|
432
|
+
lwr = 0.15
|
|
433
|
+
upr = 0.9
|
|
434
|
+
interval_alpha = (upr - lwr) * (1 - intervals) + lwr
|
|
435
|
+
else:
|
|
436
|
+
interval_alpha = array(interval_alpha)
|
|
437
|
+
assert interval_alpha.ndim == 1
|
|
438
|
+
assert interval_alpha.size == interval_alpha.size
|
|
439
|
+
interval_alpha = interval_alpha[sorter]
|
|
441
440
|
|
|
442
|
-
|
|
441
|
+
valid_alpha = (interval_alpha >= 0.0) & (interval_alpha <= 1.0)
|
|
442
|
+
if not valid_alpha:
|
|
443
|
+
raise ValueError("Given 'alpha' values must be in the range [0, 1]")
|
|
443
444
|
|
|
444
445
|
# if not plotting axis is given, then use default pyplot
|
|
445
446
|
if axis is None:
|
|
446
447
|
_, axis = plt.subplots()
|
|
447
448
|
|
|
448
|
-
|
|
449
|
-
|
|
449
|
+
if plot_mean:
|
|
450
|
+
lab = "mean" if labels else None
|
|
451
|
+
axis.plot(x, s.mean(axis=0), color=color, lw=2, label=lab)
|
|
452
|
+
|
|
453
|
+
# Calculate HDI for each interval fraction
|
|
454
|
+
lower = []
|
|
455
|
+
upper = []
|
|
456
|
+
for frac in intervals:
|
|
450
457
|
lwr, upr = sample_hdi(s, fraction=frac)
|
|
451
|
-
|
|
452
|
-
|
|
458
|
+
lower.append(lwr)
|
|
459
|
+
upper.append(upr)
|
|
460
|
+
|
|
461
|
+
lab = f"{int(100 * intervals[0])}% HDI" if labels else None
|
|
462
|
+
axis.fill_between(
|
|
463
|
+
x, lower[0], upper[0], color=color, label=lab, alpha=interval_alpha[0]
|
|
464
|
+
)
|
|
453
465
|
|
|
466
|
+
for i in range(len(intervals) - 1):
|
|
467
|
+
lab = f"{int(100 * intervals[i + 1])}% HDI" if labels else None
|
|
468
|
+
axis.fill_between(
|
|
469
|
+
x,
|
|
470
|
+
lower[i + 1],
|
|
471
|
+
lower[i],
|
|
472
|
+
color=color,
|
|
473
|
+
label=lab,
|
|
474
|
+
alpha=interval_alpha[i + 1],
|
|
475
|
+
)
|
|
476
|
+
axis.fill_between(
|
|
477
|
+
x, upper[i], upper[i + 1], color=color, alpha=interval_alpha[i + 1]
|
|
478
|
+
)
|
|
454
479
|
return axis
|
|
455
480
|
|
|
456
481
|
|
inference/priors.py
CHANGED
|
@@ -20,13 +20,11 @@ class BasePrior(ABC):
|
|
|
20
20
|
n_parameters: int,
|
|
21
21
|
class_name="BasePrior",
|
|
22
22
|
) -> list[int]:
|
|
23
|
-
indices_type_error = TypeError(
|
|
24
|
-
f"""\n
|
|
23
|
+
indices_type_error = TypeError(f"""\n
|
|
25
24
|
\r[ {class_name} error ]
|
|
26
25
|
\r>> 'variable_inds' argument of {class_name} must be
|
|
27
26
|
\r>> given as an integer or list of integers
|
|
28
|
-
"""
|
|
29
|
-
)
|
|
27
|
+
""")
|
|
30
28
|
|
|
31
29
|
if not isinstance(variable_inds, (int, Iterable)):
|
|
32
30
|
raise indices_type_error
|
|
@@ -41,22 +39,18 @@ class BasePrior(ABC):
|
|
|
41
39
|
variable_inds = list(variable_inds)
|
|
42
40
|
|
|
43
41
|
if n_parameters != len(variable_inds):
|
|
44
|
-
raise ValueError(
|
|
45
|
-
f"""\n
|
|
42
|
+
raise ValueError(f"""\n
|
|
46
43
|
\r[ {class_name} error ]
|
|
47
44
|
\r>> The total number of variables specified via the 'variable_indices' argument
|
|
48
45
|
\r>> is inconsistent with the number specified by the other arguments.
|
|
49
|
-
"""
|
|
50
|
-
)
|
|
46
|
+
""")
|
|
51
47
|
|
|
52
48
|
if len(variable_inds) != len(set(variable_inds)):
|
|
53
|
-
raise ValueError(
|
|
54
|
-
f"""\n
|
|
49
|
+
raise ValueError(f"""\n
|
|
55
50
|
\r[ {class_name} error ]
|
|
56
51
|
\r>> All integers given via the 'variable_indices' must be unique.
|
|
57
52
|
\r>> Two or more of the given integers are duplicates.
|
|
58
|
-
"""
|
|
59
|
-
)
|
|
53
|
+
""")
|
|
60
54
|
|
|
61
55
|
return variable_inds
|
|
62
56
|
|
|
@@ -101,13 +95,11 @@ class BasePrior(ABC):
|
|
|
101
95
|
:returns: \
|
|
102
96
|
A single sample from the prior distribution as a 1D ``numpy.ndarray``.
|
|
103
97
|
"""
|
|
104
|
-
raise NotImplementedError(
|
|
105
|
-
f"""\n
|
|
98
|
+
raise NotImplementedError(f"""\n
|
|
106
99
|
\r[ {self.__class__.__name__} error ]
|
|
107
100
|
\r>> 'sample' is an optional method for classes inheriting from
|
|
108
101
|
\r>> 'BasePrior', and has not been implemented for '{self.__class__.__name__}'.
|
|
109
|
-
"""
|
|
110
|
-
)
|
|
102
|
+
""")
|
|
111
103
|
|
|
112
104
|
|
|
113
105
|
class JointPrior(BasePrior):
|
|
@@ -125,13 +117,11 @@ class JointPrior(BasePrior):
|
|
|
125
117
|
|
|
126
118
|
def __init__(self, components: list[BasePrior], n_variables: int):
|
|
127
119
|
if not all(isinstance(c, BasePrior) for c in components):
|
|
128
|
-
raise TypeError(
|
|
129
|
-
"""\n
|
|
120
|
+
raise TypeError("""\n
|
|
130
121
|
\r[ JointPrior error ]
|
|
131
122
|
\r>> The sequence of prior objects passed to the 'components' argument
|
|
132
123
|
\r>> of 'JointPrior' must be instances of a subclass of 'BasePrior'.
|
|
133
|
-
"""
|
|
134
|
-
)
|
|
124
|
+
""")
|
|
135
125
|
|
|
136
126
|
# Combine any prior components which are of the same type
|
|
137
127
|
self.components = []
|
|
@@ -146,34 +136,28 @@ class JointPrior(BasePrior):
|
|
|
146
136
|
self.prior_variables = []
|
|
147
137
|
for var in chain(*[c.variables for c in self.components]):
|
|
148
138
|
if var in self.prior_variables:
|
|
149
|
-
raise ValueError(
|
|
150
|
-
f"""\n
|
|
139
|
+
raise ValueError(f"""\n
|
|
151
140
|
\r[ JointPrior error ]
|
|
152
141
|
\r>> Variable index '{var}' appears more than once in the prior
|
|
153
142
|
\r>> objects passed to the 'components' argument of 'JointPrior'.
|
|
154
|
-
"""
|
|
155
|
-
)
|
|
143
|
+
""")
|
|
156
144
|
self.prior_variables.append(var)
|
|
157
145
|
|
|
158
146
|
if len(self.prior_variables) != n_variables:
|
|
159
|
-
raise ValueError(
|
|
160
|
-
f"""\n
|
|
147
|
+
raise ValueError(f"""\n
|
|
161
148
|
\r[ JointPrior error ]
|
|
162
149
|
\r>> The total number of variables specified across the various prior
|
|
163
150
|
\r>> components ({len(self.prior_variables)}) does not match the number
|
|
164
151
|
\r>> specified in the 'n_variables' argument ({n_variables}).
|
|
165
|
-
"""
|
|
166
|
-
)
|
|
152
|
+
""")
|
|
167
153
|
|
|
168
154
|
if not all(0 <= i < n_variables for i in self.prior_variables):
|
|
169
|
-
raise ValueError(
|
|
170
|
-
"""\n
|
|
155
|
+
raise ValueError("""\n
|
|
171
156
|
\r[ JointPrior error ]
|
|
172
157
|
\r>> All variable indices specified across the various prior
|
|
173
158
|
\r>> objects passed to the 'components' argument of 'JointPrior'
|
|
174
159
|
\r>> must have values in the range [0, n_variables - 1].
|
|
175
|
-
"""
|
|
176
|
-
)
|
|
160
|
+
""")
|
|
177
161
|
|
|
178
162
|
self.n_variables = n_variables
|
|
179
163
|
|
|
@@ -419,12 +403,10 @@ class UniformPrior(BasePrior):
|
|
|
419
403
|
self.grad = zeros(self.n_params)
|
|
420
404
|
|
|
421
405
|
if (self.upper <= self.lower).any():
|
|
422
|
-
raise ValueError(
|
|
423
|
-
"""\n
|
|
406
|
+
raise ValueError("""\n
|
|
424
407
|
\r[ UniformPrior error ]
|
|
425
408
|
\r>> All values in 'lower' must be less than the corresponding values in 'upper'
|
|
426
|
-
"""
|
|
427
|
-
)
|
|
409
|
+
""")
|
|
428
410
|
|
|
429
411
|
self.variables = self.validate_variable_indices(
|
|
430
412
|
variable_inds=variable_indices,
|
|
@@ -498,55 +480,45 @@ def validate_prior_parameters(
|
|
|
498
480
|
param = atleast_1d(param).astype(float)
|
|
499
481
|
|
|
500
482
|
if not isinstance(param, ndarray):
|
|
501
|
-
raise TypeError(
|
|
502
|
-
f"""\n
|
|
483
|
+
raise TypeError(f"""\n
|
|
503
484
|
\r[ {class_name} error ]
|
|
504
485
|
\r>> Argument '{param_name}' should be an instance of a numpy.ndarray,
|
|
505
486
|
\r>> but instead has type:
|
|
506
487
|
\r>> {type(param)}
|
|
507
|
-
"""
|
|
508
|
-
)
|
|
488
|
+
""")
|
|
509
489
|
|
|
510
490
|
if param.ndim != 1:
|
|
511
|
-
raise ValueError(
|
|
512
|
-
f"""\n
|
|
491
|
+
raise ValueError(f"""\n
|
|
513
492
|
\r[ {class_name} error ]
|
|
514
493
|
\r>> Argument '{param_name}' should be a 1D numpy.ndarray,
|
|
515
494
|
\r>> but has {param.ndim} dimensions and shape {param.shape}.
|
|
516
|
-
"""
|
|
517
|
-
)
|
|
495
|
+
""")
|
|
518
496
|
|
|
519
497
|
if not isfinite(param).all():
|
|
520
|
-
raise ValueError(
|
|
521
|
-
f"""\n
|
|
498
|
+
raise ValueError(f"""\n
|
|
522
499
|
\r[ {class_name} error ]
|
|
523
500
|
\r>> Argument '{param_name}' contains non-finite values.
|
|
524
|
-
"""
|
|
525
|
-
)
|
|
501
|
+
""")
|
|
526
502
|
|
|
527
503
|
if param_name in require_positive:
|
|
528
504
|
if not (param > 0.0).all():
|
|
529
|
-
raise ValueError(
|
|
530
|
-
f"""\n
|
|
505
|
+
raise ValueError(f"""\n
|
|
531
506
|
\r[ {class_name} error ]
|
|
532
507
|
\r>> All values given in '{param_name}' must be greater than zero.
|
|
533
|
-
"""
|
|
534
|
-
)
|
|
508
|
+
""")
|
|
535
509
|
|
|
536
510
|
validated_params.append(param)
|
|
537
511
|
|
|
538
512
|
# check all inputs are the same size by collecting their sizes in a set
|
|
539
513
|
if len({param.size for param in validated_params}) != 1:
|
|
540
|
-
raise ValueError(
|
|
541
|
-
f"""\n
|
|
514
|
+
raise ValueError(f"""\n
|
|
542
515
|
\r[ {class_name} error ]
|
|
543
516
|
\r>> Arguments
|
|
544
517
|
\r>> {[param_name for param_name, _ in params]}
|
|
545
518
|
\r>> must all be arrays of equal size, but instead have sizes
|
|
546
519
|
\r>> {[param.size for param in validated_params]}
|
|
547
520
|
\r>> respectively.
|
|
548
|
-
"""
|
|
549
|
-
)
|
|
521
|
+
""")
|
|
550
522
|
|
|
551
523
|
return validated_params
|
|
552
524
|
|