cuthbert 0.0.2-py3-none-any.whl → 0.0.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuthbert/discrete/__init__.py +2 -0
- cuthbert/discrete/filter.py +140 -0
- cuthbert/discrete/smoother.py +123 -0
- cuthbert/discrete/types.py +53 -0
- cuthbert/gaussian/__init__.py +0 -0
- cuthbert/gaussian/kalman.py +337 -0
- cuthbert/gaussian/moments/__init__.py +11 -0
- cuthbert/gaussian/moments/associative_filter.py +180 -0
- cuthbert/gaussian/moments/filter.py +95 -0
- cuthbert/gaussian/moments/non_associative_filter.py +161 -0
- cuthbert/gaussian/moments/smoother.py +118 -0
- cuthbert/gaussian/moments/types.py +51 -0
- cuthbert/gaussian/taylor/__init__.py +14 -0
- cuthbert/gaussian/taylor/associative_filter.py +222 -0
- cuthbert/gaussian/taylor/filter.py +129 -0
- cuthbert/gaussian/taylor/non_associative_filter.py +246 -0
- cuthbert/gaussian/taylor/smoother.py +158 -0
- cuthbert/gaussian/taylor/types.py +86 -0
- cuthbert/gaussian/types.py +57 -0
- cuthbert/gaussian/utils.py +41 -0
- cuthbert/smc/__init__.py +0 -0
- cuthbert/smc/backward_sampler.py +193 -0
- cuthbert/smc/marginal_particle_filter.py +237 -0
- cuthbert/smc/particle_filter.py +234 -0
- cuthbert/smc/types.py +67 -0
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/METADATA +1 -1
- cuthbert-0.0.3.dist-info/RECORD +76 -0
- cuthbertlib/discrete/__init__.py +0 -0
- cuthbertlib/discrete/filtering.py +49 -0
- cuthbertlib/discrete/smoothing.py +35 -0
- cuthbertlib/kalman/__init__.py +4 -0
- cuthbertlib/kalman/filtering.py +213 -0
- cuthbertlib/kalman/generate.py +85 -0
- cuthbertlib/kalman/sampling.py +68 -0
- cuthbertlib/kalman/smoothing.py +121 -0
- cuthbertlib/linalg/__init__.py +7 -0
- cuthbertlib/linalg/collect_nans_chol.py +90 -0
- cuthbertlib/linalg/marginal_sqrt_cov.py +34 -0
- cuthbertlib/linalg/symmetric_inv_sqrt.py +126 -0
- cuthbertlib/linalg/tria.py +21 -0
- cuthbertlib/linearize/__init__.py +7 -0
- cuthbertlib/linearize/log_density.py +175 -0
- cuthbertlib/linearize/moments.py +94 -0
- cuthbertlib/linearize/taylor.py +83 -0
- cuthbertlib/quadrature/__init__.py +4 -0
- cuthbertlib/quadrature/common.py +102 -0
- cuthbertlib/quadrature/cubature.py +73 -0
- cuthbertlib/quadrature/gauss_hermite.py +62 -0
- cuthbertlib/quadrature/linearize.py +143 -0
- cuthbertlib/quadrature/unscented.py +79 -0
- cuthbertlib/quadrature/utils.py +109 -0
- cuthbertlib/resampling/__init__.py +3 -0
- cuthbertlib/resampling/killing.py +79 -0
- cuthbertlib/resampling/multinomial.py +53 -0
- cuthbertlib/resampling/protocols.py +92 -0
- cuthbertlib/resampling/systematic.py +78 -0
- cuthbertlib/resampling/utils.py +82 -0
- cuthbertlib/smc/__init__.py +0 -0
- cuthbertlib/smc/ess.py +24 -0
- cuthbertlib/smc/smoothing/__init__.py +0 -0
- cuthbertlib/smc/smoothing/exact_sampling.py +111 -0
- cuthbertlib/smc/smoothing/mcmc.py +76 -0
- cuthbertlib/smc/smoothing/protocols.py +44 -0
- cuthbertlib/smc/smoothing/tracing.py +45 -0
- cuthbertlib/stats/__init__.py +0 -0
- cuthbertlib/stats/multivariate_normal.py +102 -0
- cuthbert-0.0.2.dist-info/RECORD +0 -12
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/WHEEL +0 -0
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/licenses/LICENSE +0 -0
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/top_level.txt +0 -0
cuthbertlib/quadrature/gauss_hermite.py

@@ -0,0 +1,62 @@
+"""Implements Gauss-Hermite quadrature."""
+
+from itertools import product
+from typing import NamedTuple
+
+import numpy as np
+from jax.typing import ArrayLike
+from numpy.polynomial.hermite_e import hermegauss
+
+from cuthbertlib.quadrature import cubature
+from cuthbertlib.quadrature.common import Quadrature, SigmaPoints
+
+__all__ = ["weights", "GaussHermiteQuadrature"]
+
+
+class GaussHermiteQuadrature(NamedTuple):
+    """Gauss-Hermite quadrature.
+
+    Attributes:
+        wm: The mean weights.
+        wc: The covariance weights.
+        xi: The sigma points.
+    """
+
+    wm: ArrayLike
+    wc: ArrayLike
+    xi: ArrayLike
+
+    def get_sigma_points(self, m: ArrayLike, chol: ArrayLike) -> SigmaPoints:
+        """Get the sigma points.
+
+        Args:
+            m: The mean.
+            chol: The Cholesky factor of the covariance.
+
+        Returns:
+            SigmaPoints: The sigma points.
+        """
+        return cubature.get_sigma_points(m, chol, self.xi, self.wm, self.wc)
+
+
+def weights(n_dim: int, order: int = 3) -> Quadrature:
+    """Computes the weights associated with the Gauss-Hermite quadrature method.
+
+    The Hermite polynomial is in the probabilist's version.
+
+    Args:
+        n_dim: Dimensionality of the problem.
+        order: The order of the Hermite polynomial. Defaults to 3.
+
+    Returns:
+        The quadrature object with the weights and sigma-points.
+
+    References:
+        Simo Särkkä. *Bayesian Filtering and Smoothing.*
+        Cambridge University Press, 2013.
+    """
+    x, w = hermegauss(order)
+    xn = np.array(list(product(*(x,) * n_dim)))
+    wn = np.prod(np.array(list(product(*(w,) * n_dim))), 1)
+    wn /= np.sqrt(2 * np.pi) ** n_dim
+    return GaussHermiteQuadrature(wm=wn, wc=wn, xi=xn)
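A quick way to sanity-check the new Gauss-Hermite module is to integrate a polynomial against a 1-D Gaussian, where the rule is exact. A minimal sketch, assuming the module path from the file list above and the `SigmaPoints` fields (`points`, `wm`) used elsewhere in this diff:

```python
import numpy as np

from cuthbertlib.quadrature import gauss_hermite

quad = gauss_hermite.weights(n_dim=1, order=5)
m, chol = np.array([0.3]), np.array([[0.7]])

# E[X^2] under N(m, chol @ chol.T); order-5 Gauss-Hermite is exact here.
pts = quad.get_sigma_points(m, chol)
approx = np.sum(np.asarray(pts.wm) * np.asarray(pts.points)[:, 0] ** 2)
print(approx)  # ~0.58 == 0.3**2 + 0.7**2
```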
cuthbertlib/quadrature/linearize.py

@@ -0,0 +1,143 @@
+"""Implements quadrature-based linearization of conditional moments and functionals."""
+
+from typing import Callable
+
+import jax.numpy as jnp
+from jax import vmap
+from jax.scipy.linalg import cho_solve
+
+from cuthbertlib.linalg import tria
+from cuthbertlib.quadrature.common import Quadrature, SigmaPoints
+from cuthbertlib.quadrature.utils import cholesky_update_many
+from cuthbertlib.types import Array, ArrayLike
+
+__all__ = ["conditional_moments", "functional"]
+
+
+def conditional_moments(
+    mean_fn: Callable[[ArrayLike], Array],
+    cov_fn: Callable[[ArrayLike], Array],
+    m: ArrayLike,
+    cov: ArrayLike,
+    quadrature: Quadrature,
+    mode: str = "covariance",
+) -> tuple[Array, Array, Array]:
+    r"""Linearizes the conditional mean and covariance of a Gaussian distribution.
+
+    Args:
+        mean_fn: The mean function $\mathbb{E}[Y \mid x] =$ `mean_fn(x)`.
+        cov_fn: The covariance function $\mathbb{C}[Y \mid x] =$ `cov_fn(x)`.
+        m: The mean of the Gaussian distribution.
+        cov: The covariance of the Gaussian distribution.
+        quadrature: The quadrature object with the weights and sigma-points.
+        mode: The mode of the covariance. Default is 'covariance', which means that
+            `cov` and `cov_fn` are given as covariance matrices.
+            Otherwise, the Cholesky factors of the covariances are given.
+
+    Returns:
+        A, b, Q where A, b are the linearized model parameters and Q is either given as
+            a full covariance matrix or as a square root factor depending on
+            the `mode`.
+    """
+    if mode == "covariance":
+        chol = jnp.linalg.cholesky(cov)
+    else:
+        chol = cov
+    x_pts: SigmaPoints = quadrature.get_sigma_points(m, chol)
+
+    f_pts = SigmaPoints(vmap(mean_fn)(x_pts.points), x_pts.wm, x_pts.wc)
+    Psi_x = x_pts.covariance(f_pts)
+
+    A = cho_solve((chol, True), Psi_x).T
+    b = f_pts.mean - A @ m
+    if mode != "covariance":
+        # This can probably be abstracted better.
+        sqrt_Phi = f_pts.sqrt
+
+        chol_pts = vmap(cov_fn)(x_pts.points)
+        temp = jnp.sqrt(x_pts.wc[:, None, None]) * chol_pts
+
+        # Concatenate the blocks properly; it's a bit awkward, but what can you do...
+        temp = jnp.transpose(temp, [1, 0, 2]).reshape(temp.shape[1], -1)
+        chol_Q = tria(jnp.concatenate([sqrt_Phi, temp], axis=1))
+        chol_Q = cholesky_update_many(chol_Q, (A @ chol).T, -1.0)
+        return A, b, chol_Q
+
+    V_pts = vmap(cov_fn)(x_pts.points)
+    v_f = jnp.sum(x_pts.wc[:, None, None] * V_pts, 0)
+
+    Phi = f_pts.covariance()
+    Q = Phi + v_f - A @ cov @ A.T
+
+    return A, b, 0.5 * (Q + Q.T)
+
+
+def functional(
+    fn: Callable[[ArrayLike], Array],
+    S: ArrayLike,
+    m: ArrayLike,
+    cov: ArrayLike,
+    quadrature: Quadrature,
+    mode: str = "covariance",
+) -> tuple[Array, Array, Array]:
+    r"""Linearizes a nonlinear function of a Gaussian distribution.
+
+    For a given Gaussian distribution $p(x) = N(x \mid m, P)$
+    and $Y = f(X) + \epsilon$, where $\epsilon$ is a zero-mean Gaussian noise
+    with covariance S, this function computes an approximation $Y = A X + b + \epsilon$
+    using the sigma points method given by `get_sigma_points`.
+
+    Args:
+        fn: The function $Y = f(X) + N(0, S)$.
+            Because the function is linearized, the function should be vectorized.
+        S: The covariance of the noise.
+        m: The mean of the Gaussian distribution.
+        cov: The covariance of the Gaussian distribution.
+        quadrature: The quadrature object with the weights and sigma-points.
+        mode: The mode of the covariance. Default is 'covariance', which means that
+            `cov` and `S` are given as covariance matrices. Otherwise, the Cholesky
+            factors of the covariances are given.
+
+    Returns:
+        A, b, Q: The linearized model parameters $Y = A X + b + N(0, Q)$.
+            Q is either given as a full covariance matrix or as a square root factor depending on the `mode`.
+
+    Notes:
+        We do not support non-additive noise in this method.
+        If you have non-additive noise, you should use `conditional_moments` or
+        the Taylor linearization method.
+        Another solution is to form the covariance function using the quadrature method
+        itself. For example, if you have a function $f(x, q)$, where $q$ is a zero-mean
+        random variable with covariance `S`,
+        you can form the mean and covariance functions as follows:
+
+        ```python
+        def linearize_q_part(x):
+            n_dim = S.shape[0]
+            m_q = jnp.zeros(n_dim)
+            A, b, Q = functional(lambda q: f(x, q), 0.0 * S, m_q, S, quadrature, mode)
+            return A, b, Q
+
+        def cov_fn(x):
+            A, b, Q = linearize_q_part(x)
+            return Q + A @ S @ A.T
+
+        def mean_fn(x):
+            A, b, Q = linearize_q_part(x)
+            # m_q is zero, so the linearized mean reduces to b.
+            return b
+        ```
+
+        This technique is a bit wasteful due to our current separation of duties between
+        the mean and covariance functions, but as we develop the library further, we
+        will provide a more elegant solution.
+    """
+
+    # Make the equivalent conditional_moments model.
+    def mean_fn(x):
+        return fn(x)
+
+    def cov_fn(x):
+        return jnp.asarray(S)
+
+    return conditional_moments(mean_fn, cov_fn, m, cov, quadrature, mode)
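`functional` is the entry point for statistically linearizing an additive-noise model. A hedged usage sketch (module paths as in the file list; `jnp.sin` stands in for a real measurement function):

```python
import jax.numpy as jnp

from cuthbertlib.quadrature import gauss_hermite, linearize

quadrature = gauss_hermite.weights(n_dim=1, order=5)

m = jnp.array([0.2])      # prior mean
cov = jnp.array([[0.3]])  # prior covariance
S = jnp.array([[1e-2]])   # additive noise covariance

# Y = sin(X) + N(0, S) is approximated by Y = A X + b + N(0, Q),
# with moments matched under N(m, cov).
A, b, Q = linearize.functional(jnp.sin, S, m, cov, quadrature)
```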
cuthbertlib/quadrature/unscented.py

@@ -0,0 +1,79 @@
+"""Implements unscented quadrature."""
+
+from typing import NamedTuple
+
+import jax.numpy as jnp
+
+from cuthbertlib.quadrature.common import SigmaPoints
+from cuthbertlib.types import Array, ArrayLike
+
+__all__ = ["weights", "UnscentedQuadrature"]
+
+
+class UnscentedQuadrature(NamedTuple):
+    """Unscented quadrature.
+
+    Attributes:
+        wm: The mean weights.
+        wc: The covariance weights.
+        lamda: The lambda parameter.
+    """
+
+    wm: Array
+    wc: Array
+    lamda: float
+
+    def get_sigma_points(self, m: ArrayLike, chol: ArrayLike) -> SigmaPoints:
+        """Get the sigma points.
+
+        Args:
+            m: The mean.
+            chol: The Cholesky factor of the covariance.
+
+        Returns:
+            SigmaPoints: The sigma points.
+        """
+        m = jnp.asarray(m)
+        chol = jnp.asarray(chol)
+
+        n_dim = m.shape[0]
+        scaled_chol = jnp.sqrt(n_dim + self.lamda) * chol
+
+        zeros = jnp.zeros((1, n_dim))
+        sigma_points = m[None, :] + jnp.concatenate(
+            [zeros, scaled_chol.T, -scaled_chol.T], axis=0
+        )
+        return SigmaPoints(sigma_points, self.wm, self.wc)
+
+
+def weights(
+    n_dim: int, alpha: float = 0.5, beta: float = 2.0, kappa: float | None = None
+) -> UnscentedQuadrature:
+    """Computes the weights associated with the unscented cubature method.
+
+    The number of sigma-points is 2 * n_dim + 1.
+    This method is also known as the Unscented Transform, and generalizes the
+    `cubature.py` weights: the cubature method is a special case of the unscented
+    transform with parameters `alpha=1.0`, `beta=0.0`, `kappa=0.0`.
+
+    Args:
+        n_dim: Dimension of the space.
+        alpha: Parameter of the unscented transform, default is 0.5.
+        beta: Parameter of the unscented transform, default is 2.0.
+        kappa: Parameter of the unscented transform, default is 3 + n_dim.
+
+    Returns:
+        UnscentedQuadrature: The quadrature object with the weights and sigma-points.
+
+    References:
+        - https://groups.seas.harvard.edu/courses/cs281/papers/unscented.pdf
+    """
+    if kappa is None:
+        kappa = 3.0 + n_dim
+
+    lamda = alpha**2 * (n_dim + kappa) - n_dim
+    wm = jnp.full(2 * n_dim + 1, 1 / (2 * (n_dim + lamda)))
+
+    wm = wm.at[0].set(lamda / (n_dim + lamda))
+    wc = wm.at[0].set(lamda / (n_dim + lamda) + (1 - alpha**2 + beta))
+    return UnscentedQuadrature(wm=wm, wc=wc, lamda=lamda)
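The cubature special case noted in the docstring is easy to verify numerically: with `alpha=1.0`, `beta=0.0`, `kappa=0.0` we get `lamda = 0`, so the central weight vanishes and the remaining `2 * n_dim` points carry equal weight. A sketch, module path as listed above:

```python
from cuthbertlib.quadrature import unscented

quad = unscented.weights(n_dim=2, alpha=1.0, beta=0.0, kappa=0.0)
print(quad.lamda)  # 0.0
print(quad.wm)     # [0.   0.25 0.25 0.25 0.25]
print(quad.wc)     # identical here, since 1 - alpha**2 + beta == 0
```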
cuthbertlib/quadrature/utils.py

@@ -0,0 +1,109 @@
+"""Utility functions (Cholesky updating) for quadrature."""
+
+import jax
+import jax.numpy as jnp
+
+from cuthbertlib.types import Array, ArrayLike
+
+__all__ = ["cholesky_update_many"]
+
+
+def cholesky_update_many(
+    chol_init: ArrayLike, update_vectors: ArrayLike, multiplier: float
+) -> Array:
+    r"""Update the Cholesky decomposition of a matrix with multiple update vectors.
+
+    In mathematical terms, we compute the Cholesky factor of
+    :math:`A + \sum_{i=1}^{n} \alpha v_i v_i^T`, where :math:`A` is the original
+    matrix, :math:`v_i` are the update vectors, and :math:`\alpha` is the multiplier.
+
+    Args:
+        chol_init: Initial Cholesky decomposition of the matrix, :math:`A`.
+        update_vectors: Update vectors, :math:`v_i`.
+        multiplier: The multiplier, :math:`\alpha`.
+
+    Returns:
+        The updated Cholesky decomposition of the matrix.
+
+    Notes:
+        If the updated matrix does not correspond to a positive definite matrix, the
+        function has undefined behaviour. It is the responsibility of the caller to
+        ensure that the updated matrix is positive definite, as we cannot check this at runtime.
+    """
+
+    def body(chol, update_vector):
+        res = _cholesky_update(chol, update_vector, multiplier=multiplier)
+        return res, None
+
+    final_chol, _ = jax.lax.scan(body, jnp.asarray(chol_init), update_vectors)
+    return final_chol
+
+
+def _set_diagonal(x: Array, y: Array) -> Array:
+    N, _ = x.shape
+    i, j = jnp.diag_indices(N)
+    return x.at[i, j].set(y)
+
+
+def _set_triu(x: Array, val: ArrayLike) -> Array:
+    N, _ = x.shape
+    i = jnp.triu_indices(N, 1)
+    return x.at[i].set(val)
+
+
+def _cholesky_update(
+    chol: ArrayLike, update_vector: ArrayLike, multiplier: float = 1.0
+) -> Array:
+    chol = jnp.asarray(chol)
+    chol_diag = jnp.diag(chol)
+
+    # The algorithm in [1] is implemented as a double for loop. We can treat
+    # the inner loop in Algorithm 3.1 as a vector operation, and thus the
+    # whole algorithm as a single for loop, and hence can use a `jax.lax.scan`
+    # on it.
+
+    # We use omega and b as accumulators, as defined in Algorithm 3.1, since
+    # these are updated per iteration.
+
+    def scan_body(carry, inp):
+        _, _, omega, b = carry
+        index, diagonal_member, col = inp
+        omega_at_index = omega[..., index]
+
+        # Line 4
+        new_diagonal_member = jnp.sqrt(
+            jnp.square(diagonal_member) + multiplier / b * jnp.square(omega_at_index)
+        )
+        # `scaling_factor` is the same as `gamma` on Line 5.
+        scaling_factor = jnp.square(diagonal_member) * b + multiplier * jnp.square(
+            omega_at_index
+        )
+
+        # The following updates are the same as the for loop in lines 6-8.
+        omega = omega - (omega_at_index / diagonal_member)[..., None] * col
+        new_col = new_diagonal_member[..., None] * (
+            col / diagonal_member[..., None]
+            + (multiplier * omega_at_index / scaling_factor)[..., None] * omega
+        )
+        b = b + multiplier * jnp.square(omega_at_index / diagonal_member)
+        return (new_diagonal_member, new_col, omega, b), (
+            new_diagonal_member,
+            new_col,
+            omega,
+            b,
+        )
+
+    # We will scan over the columns.
+    chol = chol.T
+
+    _, (new_diag, new_chol, _, _) = jax.lax.scan(
+        scan_body,
+        (0.0, jnp.zeros_like(chol[0]), update_vector, 1.0),
+        (jnp.arange(0, chol.shape[0]), chol_diag, chol),
+    )
+
+    new_chol = new_chol.T
+    new_chol = _set_diagonal(new_chol, new_diag)
+    new_chol = _set_triu(new_chol, 0.0)
+    new_chol = jnp.where(jnp.isfinite(new_chol), new_chol, 0.0)
+    return new_chol
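Since `cholesky_update_many` cannot validate positive definiteness itself, a caller can check a small case against a dense re-factorization. A self-contained sketch of the downdate semantics described in the docstring:

```python
import jax.numpy as jnp

from cuthbertlib.quadrature.utils import cholesky_update_many

A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
vs = jnp.array([[0.5, 0.1], [0.2, 0.3]])  # rows are the update vectors v_i

# Downdate: Cholesky factor of A - sum_i v_i v_i^T (== A - vs.T @ vs).
updated = cholesky_update_many(jnp.linalg.cholesky(A), vs, multiplier=-1.0)
dense = jnp.linalg.cholesky(A - vs.T @ vs)
print(jnp.allclose(updated, dense, atol=1e-5))  # True while the result stays PD
```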
cuthbertlib/resampling/killing.py

@@ -0,0 +1,79 @@
+"""Implements killing resampling."""
+
+from functools import partial
+
+from jax import numpy as jnp
+from jax import random
+from jax.scipy.special import logsumexp
+
+from cuthbertlib.resampling import multinomial
+from cuthbertlib.resampling.protocols import (
+    conditional_resampling_decorator,
+    resampling_decorator,
+)
+from cuthbertlib.types import Array, ArrayLike
+
+_DESCRIPTION = """
+Killing resampling is a simple resampling mechanism that checks whether
+particles should be replaced or not, based on their weights.
+If they should be replaced, they are replaced by another particle using
+multinomial resampling on the residual weights. It has the benefit of not
+"breaking" trajectories as much as multinomial resampling, and therefore is
+stable in contexts where the trajectories are important (typically when dealing
+with continuous-time models).
+
+By construction, it requires the same number of sampled indices `n` as the
+number of particles `logits.shape[0]`.
+"""
+
+
+@partial(resampling_decorator, name="Killing", desc=_DESCRIPTION)
+def resampling(key: Array, logits: ArrayLike, n: int) -> Array:
+    logits = jnp.asarray(logits)
+    key_1, key_2 = random.split(key)
+    N = logits.shape[0]
+    if n != N:
+        raise AssertionError(
+            "The number of sampled indices must be equal to the number of "
+            f"particles for `Killing` resampling. Got {n} instead of {N}."
+        )
+
+    max_logit = jnp.max(logits)
+    log_uniforms = jnp.log(random.uniform(key_1, (N,)))
+
+    survived = log_uniforms <= logits - max_logit
+    if_survived = jnp.arange(N)  # If the particle survives, it keeps its index.
+    otherwise = multinomial.resampling(
+        key_2, logits, N
+    )  # Otherwise, it is replaced by another particle.
+    idx = jnp.where(survived, if_survived, otherwise)
+    return idx
+
+
+@partial(conditional_resampling_decorator, name="Killing", desc=_DESCRIPTION)
+def conditional_resampling(
+    key: Array, logits: ArrayLike, n: int, pivot_in: int, pivot_out: int
+) -> Array:
+    # Unconditional resampling
+    key_resample, key_shuffle = random.split(key)
+    idx = resampling(key_resample, logits, n)
+
+    # Conditional rolling pivot
+    max_logit = jnp.max(logits)
+
+    pivot_logits = _log1mexp(logits - max_logit)
+    pivot_logits -= jnp.log(n)
+    pivot_logits = pivot_logits.at[pivot_out].set(-jnp.inf)
+    pivot_logits_i = _log1mexp(logsumexp(pivot_logits))
+    pivot_logits = pivot_logits.at[pivot_out].set(pivot_logits_i)
+
+    pivot_weights = jnp.exp(pivot_logits - logsumexp(pivot_logits))
+    pivot = random.choice(key_shuffle, n, p=pivot_weights)
+    idx = jnp.roll(idx, pivot_in - pivot)
+    idx = idx.at[pivot_in].set(pivot_out)
+    return idx
+
+
+def _log1mexp(x: ArrayLike) -> Array:
+    # There is probably a better way to do this.
+    return jnp.log(1 - jnp.exp(x))
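As the description says, the killing resampler requires `n == logits.shape[0]`. A minimal usage sketch (module path from the file list above):

```python
import jax.numpy as jnp
from jax import random

from cuthbertlib.resampling import killing

key = random.PRNGKey(0)
logits = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))  # unnormalized log-weights

idx = killing.resampling(key, logits, n=4)  # n must equal the particle count
# Heavily weighted particles tend to keep their own index; lighter ones are
# replaced by a multinomial draw on the residual weights.
```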
cuthbertlib/resampling/multinomial.py

@@ -0,0 +1,53 @@
+"""Implements multinomial resampling."""
+
+from functools import partial
+
+import jax
+from jax import numpy as jnp
+from jax import random
+
+from cuthbertlib.resampling.protocols import (
+    conditional_resampling_decorator,
+    resampling_decorator,
+)
+from cuthbertlib.resampling.utils import inverse_cdf
+from cuthbertlib.types import Array, ArrayLike
+
+_DESCRIPTION = """
+This has higher variance than other resampling schemes as it samples
+the ancestors independently. It should only be used for illustration purposes,
+or if your algorithm *REALLY REALLY* needs independent samples.
+As a rule of thumb, you often don't."""
+
+
+@partial(resampling_decorator, name="Multinomial", desc=_DESCRIPTION)
+def resampling(key: Array, logits: ArrayLike, n: int) -> Array:
+    # In practice we don't have to sort the generated uniforms, but searchsorted
+    # works faster and is more stable if both inputs are sorted, so we use
+    # _sorted_uniforms from N. Chopin, but still use searchsorted instead of his
+    # O(N) loop as our code is meant to work on GPU, where searchsorted is
+    # O(log(N)) anyway.
+    # We then permute the indices to enforce exchangeability.
+
+    key_uniforms, key_shuffle = random.split(key)
+    sorted_uniforms = _sorted_uniforms(key_uniforms, n)
+    idx = inverse_cdf(sorted_uniforms, logits)
+    return random.permutation(key_shuffle, idx)
+
+
+@partial(conditional_resampling_decorator, name="Multinomial", desc=_DESCRIPTION)
+def conditional_resampling(
+    key: Array, logits: ArrayLike, n: int, pivot_in: int, pivot_out: int
+) -> Array:
+    idx = resampling(key, logits, n)
+    idx = idx.at[pivot_in].set(pivot_out)
+    return idx
+
+
+@partial(jax.jit, static_argnames=("n",))
+def _sorted_uniforms(key: Array, n: int) -> Array:
+    # This is a small modification of the code from N. Chopin: it outputs sorted
+    # uniforms *directly*, without an explicit sort.
+    us = random.uniform(key, (n + 1,))
+    z = jnp.cumsum(-jnp.log(us))
+    return z[:-1] / z[-1]
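The conditional variant pins one ancestor, which is the invariant that conditional SMC-style algorithms rely on. A small sketch (paths as listed above):

```python
import jax.numpy as jnp
from jax import random

from cuthbertlib.resampling import multinomial

key = random.PRNGKey(42)
logits = jnp.zeros(8)  # uniform weights

idx = multinomial.conditional_resampling(key, logits, n=8, pivot_in=0, pivot_out=3)
assert idx[0] == 3  # the pinned ancestor survives by construction
```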
cuthbertlib/resampling/protocols.py

@@ -0,0 +1,92 @@
+"""Shared protocols for resampling algorithms."""
+
+from typing import Protocol, runtime_checkable
+
+import jax
+
+from cuthbertlib.types import Array, ArrayLike, KeyArray
+
+
+@runtime_checkable
+class Resampling(Protocol):
+    """Protocol for resampling operations."""
+
+    def __call__(self, key: KeyArray, logits: ArrayLike, n: int) -> Array:
+        """Computes resampling indices according to given logits.
+
+        Args:
+            key: JAX PRNG key.
+            logits: Log-weights, possibly unnormalized.
+            n: Number of indices to sample.
+
+        Returns:
+            Array of resampling indices.
+        """
+        ...
+
+
+@runtime_checkable
+class ConditionalResampling(Protocol):
+    """Protocol for conditional resampling operations."""
+
+    def __call__(
+        self,
+        key: KeyArray,
+        logits: ArrayLike,
+        n: int,
+        pivot_in: int,
+        pivot_out: int,
+    ) -> Array:
+        """Conditional resampling.
+
+        Args:
+            key: JAX PRNG key.
+            logits: Log-weights, possibly unnormalized.
+            n: Number of indices to sample.
+            pivot_in: Index of the particle to keep.
+            pivot_out: Value of the output at index `pivot_in`.
+
+        Returns:
+            Array of size n with indices to use for resampling.
+        """
+        ...
+
+
+def resampling_decorator(func: Resampling, name: str, desc: str = "") -> Resampling:
+    """Decorate a Resampling function with a unified docstring."""
+    doc = f"""
+    {name} resampling. {desc}
+
+    Args:
+        key: PRNG key to use in resampling.
+        logits: Log-weights, possibly unnormalized.
+        n: Number of indices to sample.
+
+    Returns:
+        Array of size n with indices to use for resampling.
+    """
+
+    func.__doc__ = doc
+    return jax.jit(func, static_argnames=("n",))
+
+
+def conditional_resampling_decorator(
+    func: ConditionalResampling, name: str, desc: str = ""
+) -> ConditionalResampling:
+    """Decorate a ConditionalResampling function with a unified docstring."""
+    doc = f"""
+    {name} conditional resampling. {desc}
+
+    Args:
+        key: PRNG key to use in resampling.
+        logits: Log-weights, possibly unnormalized.
+        n: Number of indices to sample.
+        pivot_in: Index of the particle to keep.
+        pivot_out: Value of the output at index `pivot_in`.
+
+    Returns:
+        Array of size n with indices to use for resampling.
+    """
+
+    func.__doc__ = doc
+    return jax.jit(func, static_argnames=("n",))
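Because both protocols are `@runtime_checkable`, resamplers can be checked structurally at runtime; note the check only verifies callability, not the signature. A sketch, with module paths assumed as above:

```python
from cuthbertlib.resampling import killing, multinomial
from cuthbertlib.resampling.protocols import Resampling

# Both jitted resamplers satisfy the structural protocol.
print(isinstance(multinomial.resampling, Resampling))  # True
print(isinstance(killing.resampling, Resampling))      # True
```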