inference-tools 0.14.1__py3-none-any.whl → 0.14.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inference/_version.py +2 -2
- inference/approx/__init__.py +6 -2
- inference/approx/conditional.py +39 -0
- inference/mcmc/base.py +33 -1
- inference/mcmc/gibbs.py +1 -1
- inference/mcmc/hmc/__init__.py +1 -1
- inference/mcmc/hmc/mass.py +41 -4
- {inference_tools-0.14.1.dist-info → inference_tools-0.14.2.dist-info}/METADATA +1 -1
- {inference_tools-0.14.1.dist-info → inference_tools-0.14.2.dist-info}/RECORD +12 -12
- {inference_tools-0.14.1.dist-info → inference_tools-0.14.2.dist-info}/WHEEL +1 -1
- {inference_tools-0.14.1.dist-info → inference_tools-0.14.2.dist-info}/licenses/LICENSE +0 -0
- {inference_tools-0.14.1.dist-info → inference_tools-0.14.2.dist-info}/top_level.txt +0 -0
inference/_version.py
CHANGED
inference/approx/__init__.py
CHANGED
|
@@ -1,3 +1,7 @@
|
|
|
1
|
-
from inference.approx.conditional import
|
|
1
|
+
from inference.approx.conditional import (
|
|
2
|
+
conditional_sample,
|
|
3
|
+
get_conditionals,
|
|
4
|
+
conditional_moments,
|
|
5
|
+
)
|
|
2
6
|
|
|
3
|
-
__all__ = ["conditional_sample", "get_conditionals"]
|
|
7
|
+
__all__ = ["conditional_sample", "get_conditionals", "conditional_moments"]
|
inference/approx/conditional.py
CHANGED
|
@@ -272,3 +272,42 @@ def conditional_sample(
|
|
|
272
272
|
for i in range(n_params):
|
|
273
273
|
samples[:, i] = piecewise_linear_sample(axes[:, i], probs[:, i], n_samples)
|
|
274
274
|
return samples
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def conditional_moments(
    posterior: callable, bounds: list, conditioning_point: ndarray
) -> tuple[ndarray, ndarray]:
    """
    Calculate the mean and variance of the 1D conditional distributions of the posterior
    around a given point in the parameter space.

    The conditional densities obtained from ``get_conditionals`` are divided by their
    own integral before the moments are computed, so the returned moments are valid
    whether or not those densities already integrate to one.

    :param posterior: \
        A function which returns the posterior log-probability when given a
        numpy ``ndarray`` of the model parameters.

    :param bounds: \
        A list of length-2 tuples specifying the lower and upper bounds on
        each parameter, in the form ``(lower, upper)``.

    :param conditioning_point: \
        The point in the parameter space around which the conditional distributions are
        evaluated.

    :return means, variances: \
        The means and variances of the conditional distributions as a pair of
        1D numpy ``ndarray``.
    """
    # column i of ``axes`` / ``probs`` holds the grid / density for parameter i
    axes, probs = get_conditionals(
        posterior=posterior, bounds=bounds, conditioning_point=conditioning_point
    )

    # Integrate every column at once (axis 0 runs along the grid) with Simpson's
    # rule. Dividing by the integral of the density itself guards against
    # unnormalised conditionals; if they are already normalised this is ~1.
    norms = simpson(y=probs, x=axes, axis=0)
    means = simpson(y=axes * probs, x=axes, axis=0) / norms
    variances = (
        simpson(y=(axes - means[None, :]) ** 2 * probs, x=axes, axis=0) / norms
    )
    return means, variances
|
inference/mcmc/base.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
2
|
from copy import copy
|
|
3
3
|
from time import time
|
|
4
|
-
from numpy import ndarray
|
|
4
|
+
from numpy import ndarray, isfinite
|
|
5
5
|
from numpy.random import permutation
|
|
6
6
|
|
|
7
7
|
from inference.pdf.base import DensityEstimator
|
|
@@ -262,3 +262,35 @@ class MarkovChain(ABC):
|
|
|
262
262
|
\r>> 'burn' and 'thin' keyword arguments.
|
|
263
263
|
"""
|
|
264
264
|
)
|
|
265
|
+
|
|
266
|
+
def _validate_posterior(self, posterior: callable, start: ndarray):
    """
    Check that ``posterior`` is usable as the chain's log-posterior at ``start``.

    The checks run in order: callability, then the type of ``posterior(start)``,
    then its finiteness. ``posterior`` is evaluated exactly once, at ``start``.

    :param posterior: \
        The object supplied as the log-posterior function.

    :param start: \
        The parameter values at which ``posterior`` is evaluated.

    :raises ValueError: \
        If ``posterior`` is not callable, if ``posterior(start)`` is not an
        instance of ``float`` (subclasses such as ``numpy.float64`` pass), or
        if the returned value is not finite.
    """

    def fail(*detail_lines: str) -> ValueError:
        # Message layout: a blank leading line, a header tagged with the class
        # name, then one line per detail, each line starting with a carriage return.
        parts = [f"\r[ {self.__class__.__name__} error ]"]
        parts.extend(f"\r>> {line}" for line in detail_lines)
        return ValueError("\n\n" + "\n".join(parts) + "\n")

    if not callable(posterior):
        raise fail("The given 'posterior' is not a callable object.")

    value = posterior(start)

    if not isinstance(value, float):
        raise fail(
            "The given 'posterior' must return a float or a type which",
            "derives from float (e.g. numpy.float64), however the returned",
            "value has type:",
            f"{type(value)}",
        )

    if not isfinite(value):
        raise fail(
            "The given 'posterior' must return a finite value for the given",
            "'start' parameter values, but instead returns a value of:",
            f"{value}",
        )
|
inference/mcmc/gibbs.py
CHANGED
|
@@ -253,7 +253,7 @@ class MetropolisChain(MarkovChain):
|
|
|
253
253
|
|
|
254
254
|
if posterior is not None:
|
|
255
255
|
self.posterior = posterior
|
|
256
|
-
|
|
256
|
+
self._validate_posterior(posterior=posterior, start=start)
|
|
257
257
|
# if widths are not specified, take 5% of the starting values (unless they're zero)
|
|
258
258
|
if widths is None:
|
|
259
259
|
widths = [v * 0.05 if v != 0 else 1.0 for v in start]
|
inference/mcmc/hmc/__init__.py
CHANGED
|
@@ -87,7 +87,7 @@ class HamiltonianChain(MarkovChain):
|
|
|
87
87
|
start = start if isinstance(start, ndarray) else array(start)
|
|
88
88
|
start = start if start.dtype is float64 else start.astype(float64)
|
|
89
89
|
assert start.ndim == 1
|
|
90
|
-
|
|
90
|
+
self._validate_posterior(posterior=posterior, start=start)
|
|
91
91
|
self.theta = [start]
|
|
92
92
|
self.probs = [self.posterior(start) * self.inv_temp]
|
|
93
93
|
self.leapfrog_steps = [0]
|
inference/mcmc/hmc/mass.py
CHANGED
|
@@ -3,7 +3,7 @@ from typing import Union
|
|
|
3
3
|
from numpy import ndarray, sqrt, eye, isscalar
|
|
4
4
|
from numpy.random import Generator
|
|
5
5
|
from numpy.linalg import cholesky
|
|
6
|
-
from scipy.linalg import solve_triangular
|
|
6
|
+
from scipy.linalg import solve_triangular, issymmetric
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
class ParticleMass(ABC):
|
|
@@ -37,12 +37,49 @@ class VectorMass(ScalarMass):
|
|
|
37
37
|
assert inv_mass.ndim == 1
|
|
38
38
|
assert inv_mass.size == n_parameters
|
|
39
39
|
|
|
40
|
+
valid_variances = (
|
|
41
|
+
inv_mass.ndim == 1
|
|
42
|
+
and inv_mass.size == n_parameters
|
|
43
|
+
and (inv_mass > 0.0).all()
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
if not valid_variances:
|
|
47
|
+
raise ValueError(
|
|
48
|
+
f"""\n
|
|
49
|
+
\r[ VectorMass error ]
|
|
50
|
+
\r>> The inverse-mass vector must be a 1D array and have size
|
|
51
|
+
\r>> equal to the given number of model parameters ({n_parameters})
|
|
52
|
+
\r>> and contain only positive values.
|
|
53
|
+
"""
|
|
54
|
+
)
|
|
55
|
+
|
|
40
56
|
|
|
41
57
|
class MatrixMass(ParticleMass):
|
|
42
58
|
def __init__(self, inv_mass: ndarray, n_parameters: int):
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
59
|
+
|
|
60
|
+
valid_covariance = (
|
|
61
|
+
inv_mass.ndim == 2
|
|
62
|
+
and inv_mass.shape[0] == inv_mass.shape[1]
|
|
63
|
+
and issymmetric(inv_mass)
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
if not valid_covariance:
|
|
67
|
+
raise ValueError(
|
|
68
|
+
"""\n
|
|
69
|
+
\r[ MatrixMass error ]
|
|
70
|
+
\r>> The given inverse-mass matrix must be a valid covariance matrix,
|
|
71
|
+
\r>> i.e. 2 dimensional, square and symmetric.
|
|
72
|
+
"""
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
if inv_mass.shape[0] != n_parameters:
|
|
76
|
+
raise ValueError(
|
|
77
|
+
f"""\n
|
|
78
|
+
\r[ MatrixMass error ]
|
|
79
|
+
\r>> The dimensions of the given inverse-mass matrix {inv_mass.shape}
|
|
80
|
+
\r>> do not match the given number of model parameters ({n_parameters}).
|
|
81
|
+
"""
|
|
82
|
+
)
|
|
46
83
|
|
|
47
84
|
self.inv_mass = inv_mass
|
|
48
85
|
self.n_parameters = n_parameters
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
inference/__init__.py,sha256=Wheq9bSUF5Y_jAc_w_Avi4WW2kphDK0qHGM6FsIKSxY,275
|
|
2
|
-
inference/_version.py,sha256=
|
|
2
|
+
inference/_version.py,sha256=9gfM7I3EzKUUvEDxnQgfyPI3Ie4vtLtm20hLu-bUsls,513
|
|
3
3
|
inference/likelihoods.py,sha256=0mRn9S7CaX6hNv1fKVeaAFYk50bALvVbyX7E2aH3Bn8,10021
|
|
4
4
|
inference/plotting.py,sha256=vMpRGiZMMlVgAcVaKC2wtvjzVlBmOkC2BM90A3wSwJ8,19194
|
|
5
5
|
inference/posterior.py,sha256=ptPZgzT--ehbpu57nW9GmFuyovFOSmw56HWfuC-8GGA,3584
|
|
6
6
|
inference/priors.py,sha256=zDuIgJTZrqEqkp8rE-aBRlAuqBacR9aC_QNm8jNIYl8,19368
|
|
7
|
-
inference/approx/__init__.py,sha256=
|
|
8
|
-
inference/approx/conditional.py,sha256=
|
|
7
|
+
inference/approx/__init__.py,sha256=ghGGQNpOp4BBE7Ani-ProAN9z3BlIX62ZR6cbQcivS8,193
|
|
8
|
+
inference/approx/conditional.py,sha256=mGM_djkpvlHfTDUoN84mQmIA2f9FUDhp-hdLiwACAzk,10896
|
|
9
9
|
inference/gp/__init__.py,sha256=R4iPgf8TdunkOv_VLwue7Fz3AjGWDTBop58nCmbmMQ0,801
|
|
10
10
|
inference/gp/acquisition.py,sha256=Yr1dshTYwkMIrKYPSwDZDusXXNsOpobrxaympJc5q3g,8158
|
|
11
11
|
inference/gp/covariance.py,sha256=DVN8lAtDjCWXYSsQwhQZxV6RJ8KZeo72unOCjHhTGg0,25919
|
|
@@ -14,22 +14,22 @@ inference/gp/mean.py,sha256=6EJ_OxBi98netl9Rp2Ij7eXdWndGVS-X_g5VWnWMVkk,4084
|
|
|
14
14
|
inference/gp/optimisation.py,sha256=sPhakklWIgg1yEUhUzA-m5vl0kVPvHdcgnQ0OAGT8qs,11763
|
|
15
15
|
inference/gp/regression.py,sha256=10TzqVeUzUkuw8-Cbe4LbxevByTi5iE5QDdRClN7Nhk,25677
|
|
16
16
|
inference/mcmc/__init__.py,sha256=IsEhVSIpZCDNIqgSq_21M6DH6x8F1jJbYWM0e3S3QG4,445
|
|
17
|
-
inference/mcmc/base.py,sha256=
|
|
17
|
+
inference/mcmc/base.py,sha256=ay8P8ypF8_Dj3hIkjmjniiDxc40xVdmRSqCjVAqnMMc,11637
|
|
18
18
|
inference/mcmc/ensemble.py,sha256=JRXu7SBYXN4Y9RzgA6kGUHpZNw4q4A9wf0KOAQdlz0E,15585
|
|
19
|
-
inference/mcmc/gibbs.py,sha256=
|
|
19
|
+
inference/mcmc/gibbs.py,sha256=roXjTdO63fQTtFN6vYZa0yHvWYInFU0qX2hREH0PTgo,24358
|
|
20
20
|
inference/mcmc/parallel.py,sha256=SKLzMP4aqIj1xsxKuByA1lr1GdgIu5pPzVw7hlfXZEQ,14053
|
|
21
21
|
inference/mcmc/pca.py,sha256=NxC81NghGlBQslFVOk2HzpsnCjlEdDnv_w8es4Qe7PU,10695
|
|
22
22
|
inference/mcmc/utilities.py,sha256=YjpK3FvV0Q98jLusrZrvGck-bjm6uZZ1U7HHH3aly8g,6048
|
|
23
|
-
inference/mcmc/hmc/__init__.py,sha256=
|
|
23
|
+
inference/mcmc/hmc/__init__.py,sha256=QgERQDsEqeMBK2xPaLaaCG95zp0jfRaixDERQlwzejs,17731
|
|
24
24
|
inference/mcmc/hmc/epsilon.py,sha256=t2kNi10MSVFXjmAx5zRUARDuPu_yWbwoK2McMuaaAUs,2467
|
|
25
|
-
inference/mcmc/hmc/mass.py,sha256=
|
|
25
|
+
inference/mcmc/hmc/mass.py,sha256=vYuH67H1BnUm3z0OSQ4gVvL060PGTcG9TwB59nTJHCE,3759
|
|
26
26
|
inference/pdf/__init__.py,sha256=gVmQ1HLTab6_oWMQN26A1r7PkqbApaJmBK-c7TIFxjY,270
|
|
27
27
|
inference/pdf/base.py,sha256=Zj5mfFmDqTe5cFz0biBxcvEaxdOUC-SsOUjebUEX7HM,5442
|
|
28
28
|
inference/pdf/hdi.py,sha256=soFw3fKQdzxbGNhU9BvFHdt0uGKfhus3E3vM6L47yhY,4638
|
|
29
29
|
inference/pdf/kde.py,sha256=KSl8y---602MlxoSVH8VknNQYZ2KAOTky50QU3jRw28,12999
|
|
30
30
|
inference/pdf/unimodal.py,sha256=9S05c0hq_rF-MLoDJgUmaJKRdcP8F9_Idj7Ncb6m9q0,6218
|
|
31
|
-
inference_tools-0.14.
|
|
32
|
-
inference_tools-0.14.
|
|
33
|
-
inference_tools-0.14.
|
|
34
|
-
inference_tools-0.14.
|
|
35
|
-
inference_tools-0.14.
|
|
31
|
+
inference_tools-0.14.2.dist-info/licenses/LICENSE,sha256=Y0-EfO5pdxf6d0J6Er13ZSWiPZ2o6kHvM37eRgnJdww,1069
|
|
32
|
+
inference_tools-0.14.2.dist-info/METADATA,sha256=yKkLGfLX7yADQX43ASxGXJsutGTrIVuHllHl5rlO2hg,5400
|
|
33
|
+
inference_tools-0.14.2.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
|
|
34
|
+
inference_tools-0.14.2.dist-info/top_level.txt,sha256=I7bsb71rLtH3yvVH_HSLXUosY2AwCxEG3vctNsEhbEM,10
|
|
35
|
+
inference_tools-0.14.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|