cuthbert 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70):
  1. cuthbert/discrete/__init__.py +2 -0
  2. cuthbert/discrete/filter.py +140 -0
  3. cuthbert/discrete/smoother.py +123 -0
  4. cuthbert/discrete/types.py +53 -0
  5. cuthbert/gaussian/__init__.py +0 -0
  6. cuthbert/gaussian/kalman.py +337 -0
  7. cuthbert/gaussian/moments/__init__.py +11 -0
  8. cuthbert/gaussian/moments/associative_filter.py +180 -0
  9. cuthbert/gaussian/moments/filter.py +95 -0
  10. cuthbert/gaussian/moments/non_associative_filter.py +161 -0
  11. cuthbert/gaussian/moments/smoother.py +118 -0
  12. cuthbert/gaussian/moments/types.py +51 -0
  13. cuthbert/gaussian/taylor/__init__.py +14 -0
  14. cuthbert/gaussian/taylor/associative_filter.py +222 -0
  15. cuthbert/gaussian/taylor/filter.py +129 -0
  16. cuthbert/gaussian/taylor/non_associative_filter.py +246 -0
  17. cuthbert/gaussian/taylor/smoother.py +158 -0
  18. cuthbert/gaussian/taylor/types.py +86 -0
  19. cuthbert/gaussian/types.py +57 -0
  20. cuthbert/gaussian/utils.py +41 -0
  21. cuthbert/smc/__init__.py +0 -0
  22. cuthbert/smc/backward_sampler.py +193 -0
  23. cuthbert/smc/marginal_particle_filter.py +237 -0
  24. cuthbert/smc/particle_filter.py +234 -0
  25. cuthbert/smc/types.py +67 -0
  26. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/METADATA +1 -1
  27. cuthbert-0.0.3.dist-info/RECORD +76 -0
  28. cuthbertlib/discrete/__init__.py +0 -0
  29. cuthbertlib/discrete/filtering.py +49 -0
  30. cuthbertlib/discrete/smoothing.py +35 -0
  31. cuthbertlib/kalman/__init__.py +4 -0
  32. cuthbertlib/kalman/filtering.py +213 -0
  33. cuthbertlib/kalman/generate.py +85 -0
  34. cuthbertlib/kalman/sampling.py +68 -0
  35. cuthbertlib/kalman/smoothing.py +121 -0
  36. cuthbertlib/linalg/__init__.py +7 -0
  37. cuthbertlib/linalg/collect_nans_chol.py +90 -0
  38. cuthbertlib/linalg/marginal_sqrt_cov.py +34 -0
  39. cuthbertlib/linalg/symmetric_inv_sqrt.py +126 -0
  40. cuthbertlib/linalg/tria.py +21 -0
  41. cuthbertlib/linearize/__init__.py +7 -0
  42. cuthbertlib/linearize/log_density.py +175 -0
  43. cuthbertlib/linearize/moments.py +94 -0
  44. cuthbertlib/linearize/taylor.py +83 -0
  45. cuthbertlib/quadrature/__init__.py +4 -0
  46. cuthbertlib/quadrature/common.py +102 -0
  47. cuthbertlib/quadrature/cubature.py +73 -0
  48. cuthbertlib/quadrature/gauss_hermite.py +62 -0
  49. cuthbertlib/quadrature/linearize.py +143 -0
  50. cuthbertlib/quadrature/unscented.py +79 -0
  51. cuthbertlib/quadrature/utils.py +109 -0
  52. cuthbertlib/resampling/__init__.py +3 -0
  53. cuthbertlib/resampling/killing.py +79 -0
  54. cuthbertlib/resampling/multinomial.py +53 -0
  55. cuthbertlib/resampling/protocols.py +92 -0
  56. cuthbertlib/resampling/systematic.py +78 -0
  57. cuthbertlib/resampling/utils.py +82 -0
  58. cuthbertlib/smc/__init__.py +0 -0
  59. cuthbertlib/smc/ess.py +24 -0
  60. cuthbertlib/smc/smoothing/__init__.py +0 -0
  61. cuthbertlib/smc/smoothing/exact_sampling.py +111 -0
  62. cuthbertlib/smc/smoothing/mcmc.py +76 -0
  63. cuthbertlib/smc/smoothing/protocols.py +44 -0
  64. cuthbertlib/smc/smoothing/tracing.py +45 -0
  65. cuthbertlib/stats/__init__.py +0 -0
  66. cuthbertlib/stats/multivariate_normal.py +102 -0
  67. cuthbert-0.0.2.dist-info/RECORD +0 -12
  68. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/WHEEL +0 -0
  69. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/licenses/LICENSE +0 -0
  70. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,62 @@
1
+ """Implements Gauss-Hermite quadrature."""
2
+
3
+ from itertools import product
4
+ from typing import NamedTuple
5
+
6
+ import numpy as np
7
+ from jax.typing import ArrayLike
8
+ from numpy.polynomial.hermite_e import hermegauss
9
+
10
+ from cuthbertlib.quadrature import cubature
11
+ from cuthbertlib.quadrature.common import Quadrature, SigmaPoints
12
+
13
+ __all__ = ["weights", "GaussHermiteQuadrature"]
14
+
15
+
16
class GaussHermiteQuadrature(NamedTuple):
    """Gauss-Hermite quadrature rule for a multivariate Gaussian.

    Holds precomputed tensor-product weights and unit sigma points (built by
    `weights` below) and maps them onto a concrete Gaussian N(m, chol chol^T)
    on demand.

    Attributes:
        wm: The mean weights.
        wc: The covariance weights.
        xi: The sigma points (for the standard normal; rescaled by
            `get_sigma_points`).
    """

    wm: ArrayLike
    wc: ArrayLike
    xi: ArrayLike

    def get_sigma_points(self, m: ArrayLike, chol: ArrayLike) -> SigmaPoints:
        """Get the sigma points for the Gaussian N(m, chol chol^T).

        Args:
            m: The mean.
            chol: The Cholesky factor of the covariance.

        Returns:
            SigmaPoints: The sigma points.
        """
        # The affine rescaling of unit sigma points is shared with the
        # cubature rule, so we delegate to it.
        return cubature.get_sigma_points(m, chol, self.xi, self.wm, self.wc)
40
+
41
+
42
def weights(n_dim: int, order: int = 3) -> Quadrature:
    """Builds the Gauss-Hermite quadrature weights and sigma-points.

    The Hermite polynomial is in the probabilist's version, so the rule
    integrates exactly against a standard normal density.

    Args:
        n_dim: Dimensionality of the problem.
        order: The order of Hermite polynomial. Defaults to 3.

    Returns:
        The quadrature object with the weights and sigma-points.

    References:
        Simo Särkkä. *Bayesian Filtering and Smoothing.*
        In: Cambridge University Press 2013.
    """
    nodes, node_weights = hermegauss(order)
    # Tensor-product construction: the n_dim-fold Cartesian product of the
    # 1-D nodes gives the sigma points, and the product of the matching 1-D
    # weights gives each point's weight.
    grid = np.array(list(product(*([nodes] * n_dim))))
    combined = np.array(list(product(*([node_weights] * n_dim))))
    # Normalize so the weights sum to one (hermegauss weights sum to
    # sqrt(2*pi) per dimension).
    normalized = combined.prod(axis=1) / np.sqrt(2 * np.pi) ** n_dim
    return GaussHermiteQuadrature(wm=normalized, wc=normalized, xi=grid)
@@ -0,0 +1,143 @@
1
+ """Implements quadrature-based linearization of conditional moments and functional."""
2
+
3
+ from typing import Callable
4
+
5
+ import jax.numpy as jnp
6
+ from jax import vmap
7
+ from jax.scipy.linalg import cho_solve
8
+
9
+ from cuthbertlib.linalg import tria
10
+ from cuthbertlib.quadrature.common import Quadrature, SigmaPoints
11
+ from cuthbertlib.quadrature.utils import cholesky_update_many
12
+ from cuthbertlib.types import Array, ArrayLike
13
+
14
+ __all__ = ["conditional_moments", "functional"]
15
+
16
+
17
def conditional_moments(
    mean_fn: Callable[[ArrayLike], Array],
    cov_fn: Callable[[ArrayLike], Array],
    m: ArrayLike,
    cov: ArrayLike,
    quadrature: Quadrature,
    mode: str = "covariance",
) -> tuple[Array, Array, Array]:
    r"""Linearizes the conditional mean and covariance of a Gaussian distribution.

    Computes an affine model $Y \approx A X + b + N(0, Q)$ matching the
    conditional moments `mean_fn` / `cov_fn` under the sigma-point rule given
    by `quadrature`.

    Args:
        mean_fn: The mean function $\mathbb{E}[Y \mid x] =$ `mean_fn(x)`.
        cov_fn: The covariance function $\mathbb{C}[Y \mid x] =$ `cov_fn(x)`.
            In the non-'covariance' mode it must return a Cholesky factor.
        m: The mean of the Gaussian distribution.
        cov: The covariance of the Gaussian distribution.
        quadrature: The quadrature object with the weights and sigma-points.
        mode: The mode of the covariance. Default is 'covariance', which means that cov
            and cov_fn are given as covariance matrices.
            Otherwise, the Cholesky factor of the covariances are given.

    Returns:
        A, b, Q where A, b are the linearized model parameters and Q is either given as
        a full covariance matrix or as a square root factor depending on
        the `mode`.
    """
    # Both code paths need the Cholesky factor of cov to scale sigma points.
    if mode == "covariance":
        chol = jnp.linalg.cholesky(cov)
    else:
        chol = cov
    x_pts: SigmaPoints = quadrature.get_sigma_points(m, chol)

    # Propagate the sigma points through the conditional mean; reuse the
    # input weights for the transformed point set.
    f_pts = SigmaPoints(vmap(mean_fn)(x_pts.points), x_pts.wm, x_pts.wc)
    # Cross-covariance between X and f(X) under the quadrature rule.
    Psi_x = x_pts.covariance(f_pts)

    # A = Psi_x^T P^{-1}: the least-squares gain, via the Cholesky solve.
    A = cho_solve((chol, True), Psi_x).T
    b = f_pts.mean - A @ m
    if mode != "covariance":
        # Square-root path: assemble a wide block [sqrt_Phi | weighted chols]
        # whose Gram matrix is Phi + E[V], triangularize, then downdate by
        # A P A^T (as rank-1 removals of the columns of (A @ chol).T).
        # This can probably be abstracted better.
        sqrt_Phi = f_pts.sqrt

        chol_pts = vmap(cov_fn)(x_pts.points)
        temp = jnp.sqrt(x_pts.wc[:, None, None]) * chol_pts

        # concatenate the blocks properly, it's a bit urk, but what can you do...
        temp = jnp.transpose(temp, [1, 0, 2]).reshape(temp.shape[1], -1)
        chol_Q = tria(jnp.concatenate([sqrt_Phi, temp], axis=1))
        # NOTE(review): the downdate assumes Phi + E[V] - A P A^T stays
        # positive definite; `cholesky_update_many` cannot check this.
        chol_Q = cholesky_update_many(chol_Q, (A @ chol).T, -1.0)
        return A, b, chol_Q

    # Full-covariance path: law of total covariance,
    # Q = Cov[f(X)] + E[Cov[Y|X]] - A cov A^T.
    V_pts = vmap(cov_fn)(x_pts.points)
    v_f = jnp.sum(x_pts.wc[:, None, None] * V_pts, 0)

    Phi = f_pts.covariance()
    Q = Phi + v_f - A @ cov @ A.T

    # Symmetrize to remove round-off asymmetry.
    return A, b, 0.5 * (Q + Q.T)
73
+
74
+
75
def functional(
    fn: Callable[[ArrayLike], Array],
    S: ArrayLike,
    m: ArrayLike,
    cov: ArrayLike,
    quadrature: Quadrature,
    mode: str = "covariance",
) -> tuple[Array, Array, Array]:
    r"""Linearizes a nonlinear function of a Gaussian distribution.

    Given $p(x) = N(x \mid m, P)$ and $Y = f(X) + \epsilon$, where
    $\epsilon$ is zero-mean Gaussian noise with covariance S, this computes
    the affine approximation $Y = A X + b + \epsilon$ using the sigma points
    provided by `quadrature`.

    Args:
        fn: The function $Y = f(X) + N(0, S)$. It is applied pointwise to the
            sigma points (via `vmap`), so it should act on a single state.
        S: The covariance of the noise.
        m: The mean of the Gaussian distribution.
        cov: The covariance of the Gaussian distribution.
        quadrature: The quadrature object with the weights and sigma-points.
        mode: The mode of the covariance. Default is 'covariance', which means
            that `cov` (and the returned Q) are full covariance matrices.
            Otherwise, Cholesky factors are used throughout.

    Returns:
        A, b, Q: The linearized model parameters $Y = A X + b + N(0, Q)$.
        Q is either a full covariance matrix or a square-root factor
        depending on the `mode`.

    Notes:
        Non-additive noise is not supported by this method. If your noise
        enters non-additively, use `conditional_moments` or the Taylor
        linearization instead; alternatively, fold the noise variable into
        explicit mean/covariance functions (by linearizing over the noise
        with a second quadrature pass) and call `conditional_moments`
        directly. That workaround is somewhat wasteful under the current
        split between mean and covariance functions; a more elegant solution
        is planned as the library develops.
    """

    # `functional` is exactly `conditional_moments` with a state-independent
    # noise covariance.
    def _mean_fn(x):
        return fn(x)

    def _cov_fn(_):
        return jnp.asarray(S)

    return conditional_moments(_mean_fn, _cov_fn, m, cov, quadrature, mode)
@@ -0,0 +1,79 @@
1
+ """Implements unscented quadrature."""
2
+
3
+ from typing import NamedTuple
4
+
5
+ import jax.numpy as jnp
6
+
7
+ from cuthbertlib.quadrature.common import SigmaPoints
8
+ from cuthbertlib.types import Array, ArrayLike
9
+
10
+ __all__ = ["weights", "UnscentedQuadrature"]
11
+
12
+
13
class UnscentedQuadrature(NamedTuple):
    """Unscented quadrature.

    Attributes:
        wm: The mean weights.
        wc: The covariance weights.
        lamda: The lambda scaling parameter of the unscented transform.
    """

    wm: Array
    wc: Array
    lamda: float

    def get_sigma_points(self, m: ArrayLike, chol: ArrayLike) -> SigmaPoints:
        """Builds the sigma points for the Gaussian N(m, chol chol^T).

        Args:
            m: The mean.
            chol: The Cholesky factor of the covariance.

        Returns:
            SigmaPoints: The sigma points.
        """
        mean = jnp.asarray(m)
        cholesky = jnp.asarray(chol)
        dim = mean.shape[0]

        # The mean itself is the central point; the other 2 * dim points are
        # symmetric offsets along the scaled Cholesky columns.
        spread = jnp.sqrt(dim + self.lamda) * cholesky
        offsets = jnp.concatenate(
            [jnp.zeros((1, dim)), spread.T, -spread.T], axis=0
        )
        return SigmaPoints(mean[None, :] + offsets, self.wm, self.wc)
47
+
48
+
49
def weights(
    n_dim: int, alpha: float = 0.5, beta: float = 2.0, kappa: float | None = None
) -> UnscentedQuadrature:
    """Computes the weights associated with the unscented quadrature method.

    The number of sigma-points is 2 * n_dim + 1: the central (mean) point
    plus a symmetric pair along each of the n_dim Cholesky columns.
    This method is also known as the Unscented Transform, and generalizes the
    `cubature.py` weights: the cubature method is a special case of the unscented
    for the parameters `alpha=1.0`, `beta=0.0`, `kappa=0.0` (the central
    point then carries zero weight).

    Args:
        n_dim: Dimension of the space.
        alpha: Parameter of the unscented transform, default is 0.5.
        beta: Parameter of the unscented transform, default is 2.0.
        kappa: Parameter of the unscented transform, default is 3 + n_dim.

    Returns:
        UnscentedQuadrature: The quadrature object with the weights and sigma-points.

    References:
        - https://groups.seas.harvard.edu/courses/cs281/papers/unscented.pdf
    """
    if kappa is None:
        kappa = 3.0 + n_dim

    lamda = alpha**2 * (n_dim + kappa) - n_dim
    # All non-central points share the weight 1 / (2 (n + lambda)).
    wm = jnp.full(2 * n_dim + 1, 1 / (2 * (n_dim + lamda)))

    # The central point's mean weight is lambda / (n + lambda); its
    # covariance weight carries the extra (1 - alpha^2 + beta) correction.
    wm = wm.at[0].set(lamda / (n_dim + lamda))
    wc = wm.at[0].set(lamda / (n_dim + lamda) + (1 - alpha**2 + beta))
    return UnscentedQuadrature(wm=wm, wc=wc, lamda=lamda)
@@ -0,0 +1,109 @@
1
+ """Utility functions (Cholesky updating) for quadrature."""
2
+
3
+ import jax
4
+ import jax.numpy as jnp
5
+
6
+ from cuthbertlib.types import Array, ArrayLike
7
+
8
+ __all__ = ["cholesky_update_many"]
9
+
10
+
11
+ def cholesky_update_many(
12
+ chol_init: ArrayLike, update_vectors: ArrayLike, multiplier: float
13
+ ) -> Array:
14
+ r"""Update the Cholesky decomposition of a matrix with multiple update vectors.
15
+
16
+ In mathematical terms, we compute :math:`A + \sum_{i=1}^{n} \alpha v_i v_i^T`
17
+ where :math:`A` is the original matrix, :math:`v_i` are the update vectors and
18
+ :math:`\alpha` is the multiplier.
19
+
20
+ Args:
21
+ chol_init: Initial Cholesky decomposition of the matrix, :math:`A`.
22
+ update_vectors: Update vectors, :math:`v_i`.
23
+ multiplier: The multiplier, :math:`\alpha`.
24
+
25
+ Returns:
26
+ The updated Cholesky decomposition of the matrix.
27
+
28
+ Notes:
29
+ If the updated matrix does not correspond to a positive definite matrix, the
30
+ function has undefined behaviour. It is the responsibility of the caller to
31
+ ensure that the updated matrix is positive definite as we cannot check this at runtime.
32
+ """
33
+
34
+ def body(chol, update_vector):
35
+ res = _cholesky_update(chol, update_vector, multiplier=multiplier)
36
+ return res, None
37
+
38
+ final_chol, _ = jax.lax.scan(body, jnp.asarray(chol_init), update_vectors)
39
+ return final_chol
40
+
41
+
42
+ def _set_diagonal(x: Array, y: Array) -> Array:
43
+ N, _ = x.shape
44
+ i, j = jnp.diag_indices(N)
45
+ return x.at[i, j].set(y)
46
+
47
+
48
+ def _set_triu(x: Array, val: ArrayLike) -> Array:
49
+ N, _ = x.shape
50
+ i = jnp.triu_indices(N, 1)
51
+ return x.at[i].set(val)
52
+
53
+
54
def _cholesky_update(
    chol: ArrayLike, update_vector: ArrayLike, multiplier: float = 1.0
) -> Array:
    """Rank-1 update of a Cholesky factor.

    Returns the lower-triangular Cholesky factor of
    ``chol @ chol.T + multiplier * update_vector @ update_vector.T``.

    NOTE(review): "[1]" / "Algorithm 3.1" below is not resolvable from this
    file; the structure matches the vectorized rank-one Cholesky update used
    in TensorFlow Probability — TODO confirm and cite the actual reference.

    Args:
        chol: Lower-triangular Cholesky factor to update.
        update_vector: The rank-1 update vector.
        multiplier: Scalar multiplier of the rank-1 term (negative values
            perform a downdate).

    Returns:
        The updated lower-triangular Cholesky factor.
    """
    chol = jnp.asarray(chol)
    chol_diag = jnp.diag(chol)

    # The algorithm in [1] is implemented as a double for loop. We can treat
    # the inner loop in Algorithm 3.1 as a vector operation, and thus the
    # whole algorithm as a single for loop, and hence can use a `tf.scan`
    # on it.

    # We use for accumulation omega and b as defined in Algorithm 3.1, since
    # these are updated per iteration.

    def scan_body(carry, inp):
        # carry: (previous diagonal, previous column, omega residual, b scalar);
        # only omega and b feed forward into the next column's update.
        _, _, omega, b = carry
        index, diagonal_member, col = inp
        omega_at_index = omega[..., index]

        # Line 4
        new_diagonal_member = jnp.sqrt(
            jnp.square(diagonal_member) + multiplier / b * jnp.square(omega_at_index)
        )
        # `scaling_factor` is the same as `gamma` on Line 5.
        scaling_factor = jnp.square(diagonal_member) * b + multiplier * jnp.square(
            omega_at_index
        )

        # The following updates are the same as the for loop in lines 6-8.
        omega = omega - (omega_at_index / diagonal_member)[..., None] * col
        new_col = new_diagonal_member[..., None] * (
            col / diagonal_member[..., None]
            + (multiplier * omega_at_index / scaling_factor)[..., None] * omega
        )
        b = b + multiplier * jnp.square(omega_at_index / diagonal_member)
        return (new_diagonal_member, new_col, omega, b), (
            new_diagonal_member,
            new_col,
            omega,
            b,
        )

    # We will scan over the columns.
    chol = chol.T

    _, (new_diag, new_chol, _, _) = jax.lax.scan(
        scan_body,
        (0.0, jnp.zeros_like(chol[0]), update_vector, 1.0),
        (jnp.arange(0, chol.shape[0]), chol_diag, chol),
    )

    # Transpose back, then restore exact triangular structure: the scan
    # outputs full rows, so we re-insert the computed diagonal and zero the
    # strict upper triangle.
    new_chol = new_chol.T
    new_chol = _set_diagonal(new_chol, new_diag)
    new_chol = _set_triu(new_chol, 0.0)
    # NOTE(review): zeroing non-finite entries guards against divisions by a
    # zero diagonal above, but it also silently masks a failed (non-PSD)
    # downdate — confirm this is the intended behaviour.
    new_chol = jnp.where(jnp.isfinite(new_chol), new_chol, 0.0)
    return new_chol
@@ -0,0 +1,3 @@
1
+ from cuthbertlib.resampling import killing, multinomial, systematic
2
+ from cuthbertlib.resampling.protocols import ConditionalResampling, Resampling
3
+ from cuthbertlib.resampling.utils import inverse_cdf
@@ -0,0 +1,79 @@
1
+ """Implements killing resampling."""
2
+
3
+ from functools import partial
4
+
5
+ from jax import numpy as jnp
6
+ from jax import random
7
+ from jax.scipy.special import logsumexp
8
+
9
+ from cuthbertlib.resampling import multinomial
10
+ from cuthbertlib.resampling.protocols import (
11
+ conditional_resampling_decorator,
12
+ resampling_decorator,
13
+ )
14
+ from cuthbertlib.types import Array, ArrayLike
15
+
16
# Human-readable description shared by both public functions below; it is
# spliced into their docstrings by the resampling decorators.
_DESCRIPTION = """
The Killing resampling is a simple resampling mechanism that checks if
particles should be replaced or not, based on their weights.
If they should be replaced, they are replaced by another particle using
multinomial resampling on residual weights. It presents the benefit of not
"breaking" trajectories as much as multinomial resampling, and therefore is
stable in contexts where the trajectories are important (typically when dealing
with continuous-time models).

By construction, it requires the same number of sampled indices `n` as the
number of particles `logits.shape[0]`.
"""
28
+
29
+
30
@partial(resampling_decorator, name="Killing", desc=_DESCRIPTION)
def resampling(key: Array, logits: ArrayLike, n: int) -> Array:
    logits = jnp.asarray(logits)
    key_survival, key_replace = random.split(key)
    num_particles = logits.shape[0]
    if n != num_particles:
        raise AssertionError(
            "The number of sampled indices must be equal to the number of "
            f"particles for `Killing` resampling. Got {n} instead of {num_particles}."
        )

    # Each particle survives with probability proportional to its weight
    # relative to the largest weight (so the best particle always survives).
    shifted_logits = logits - jnp.max(logits)
    log_uniforms = jnp.log(random.uniform(key_survival, (num_particles,)))
    survived = log_uniforms <= shifted_logits

    # Survivors keep their own index; killed particles draw a replacement
    # ancestor via multinomial resampling.
    own_index = jnp.arange(num_particles)
    replacement = multinomial.resampling(key_replace, logits, num_particles)
    return jnp.where(survived, own_index, replacement)
51
+
52
+
53
@partial(conditional_resampling_decorator, name="Killing", desc=_DESCRIPTION)
def conditional_resampling(
    key: Array, logits: ArrayLike, n: int, pivot_in: int, pivot_out: int
) -> Array:
    # Unconditional resampling; the pivot constraint is enforced below by
    # rolling the result and pinning the pivot slot.
    key_resample, key_shuffle = random.split(key)
    idx = resampling(key_resample, logits, n)

    # Conditional rolling pivot
    max_logit = jnp.max(logits)

    # _log1mexp(logits - max_logit) is the log-probability that each particle
    # is killed in the step above (the complement of its survival
    # probability).
    pivot_logits = _log1mexp(logits - max_logit)
    pivot_logits -= jnp.log(n)
    # NOTE(review): the pivot_out entry is replaced by the complement of the
    # remaining mass — presumably the probability that the retained particle
    # occupies its own slot; confirm against the conditional-SMC derivation.
    pivot_logits = pivot_logits.at[pivot_out].set(-jnp.inf)
    pivot_logits_i = _log1mexp(logsumexp(pivot_logits))
    pivot_logits = pivot_logits.at[pivot_out].set(pivot_logits_i)

    # Sample the slot where the retained particle would have landed, roll it
    # to pivot_in, and pin the retained index there.
    pivot_weights = jnp.exp(pivot_logits - logsumexp(pivot_logits))
    pivot = random.choice(key_shuffle, n, p=pivot_weights)
    idx = jnp.roll(idx, pivot_in - pivot)
    idx = idx.at[pivot_in].set(pivot_out)
    return idx
75
+
76
+
77
+ def _log1mexp(x: ArrayLike) -> Array:
78
+ # There is probably a better way to do this
79
+ return jnp.log(1 - jnp.exp(x))
@@ -0,0 +1,53 @@
1
+ """Implements multinomial resampling."""
2
+
3
+ from functools import partial
4
+
5
+ import jax
6
+ from jax import numpy as jnp
7
+ from jax import random
8
+
9
+ from cuthbertlib.resampling.protocols import (
10
+ conditional_resampling_decorator,
11
+ resampling_decorator,
12
+ )
13
+ from cuthbertlib.resampling.utils import inverse_cdf
14
+ from cuthbertlib.types import Array, ArrayLike
15
+
16
# Human-readable description shared by both public functions below; it is
# spliced into their docstrings by the resampling decorators.
_DESCRIPTION = """
This has higher variance than other resampling schemes as it samples from
the ancestors independently. It should only be used for illustration purposes,
or if your algorithm *REALLY REALLY* needs independent samples.
As a rule of thumb, you often don't."""
21
+
22
+
23
@partial(resampling_decorator, name="Multinomial", desc=_DESCRIPTION)
def resampling(key: Array, logits: ArrayLike, n: int) -> Array:
    # searchsorted-based CDF inversion is faster and more stable when both of
    # its inputs are sorted, so we draw already-sorted uniforms (N. Chopin's
    # construction) instead of sorting after the fact. We keep searchsorted
    # rather than his O(N) loop since on GPU it is O(log N) anyway. The
    # resulting indices are then permuted to restore exchangeability.
    key_sorted, key_perm = random.split(key)
    ordered_uniforms = _sorted_uniforms(key_sorted, n)
    ancestors = inverse_cdf(ordered_uniforms, logits)
    return random.permutation(key_perm, ancestors)
36
+
37
+
38
@partial(conditional_resampling_decorator, name="Multinomial", desc=_DESCRIPTION)
def conditional_resampling(
    key: Array, logits: ArrayLike, n: int, pivot_in: int, pivot_out: int
) -> Array:
    # Multinomial ancestors are drawn independently, so conditioning reduces
    # to pinning the pivot slot after an unconditional draw.
    return resampling(key, logits, n).at[pivot_in].set(pivot_out)
45
+
46
+
47
+ @partial(jax.jit, static_argnames=("n",))
48
+ def _sorted_uniforms(key: Array, n: int) -> Array:
49
+ # This is a small modification of the code from N. Chopin to output sorted
50
+ # log-uniforms *directly*. N. Chopin's code outputs sorted uniforms.
51
+ us = random.uniform(key, (n + 1,))
52
+ z = jnp.cumsum(-jnp.log(us))
53
+ return z[:-1] / z[-1]
@@ -0,0 +1,92 @@
1
+ """Shared protocols for resampling algorithms."""
2
+
3
+ from typing import Protocol, runtime_checkable
4
+
5
+ import jax
6
+
7
+ from cuthbertlib.types import Array, ArrayLike, KeyArray
8
+
9
+
10
@runtime_checkable
class Resampling(Protocol):
    """Protocol for resampling operations.

    `runtime_checkable` so implementations can be validated with
    `isinstance` (signature-level structural check only).
    """

    def __call__(self, key: KeyArray, logits: ArrayLike, n: int) -> Array:
        """Computes resampling indices according to given logits.

        Args:
            key: JAX PRNG key.
            logits: Log-weights, possibly unnormalized.
            n: Number of indices to sample.

        Returns:
            Array of resampling indices.
        """
        ...
26
+
27
+
28
@runtime_checkable
class ConditionalResampling(Protocol):
    """Protocol for conditional resampling operations.

    Implementations must guarantee that the returned index array satisfies
    `out[pivot_in] == pivot_out` (the retained-particle constraint used in
    conditional SMC).
    """

    def __call__(
        self,
        key: KeyArray,
        logits: ArrayLike,
        n: int,
        pivot_in: int,
        pivot_out: int,
    ) -> Array:
        """Conditional resampling.

        Args:
            key: JAX PRNG key.
            logits: Log-weights, possibly unnormalized.
            n: Number of indices to sample.
            pivot_in: Index of the particle to keep.
            pivot_out: Value of the output at index `pivot_in`.

        Returns:
            Array of size n with indices to use for resampling.
        """
        ...
53
+
54
+
55
def resampling_decorator(func: Resampling, name: str, desc: str = "") -> Resampling:
    """Decorate Resampling function with unified docstring.

    Also jit-compiles the function; `n` is static because it determines the
    output shape under jit.
    """
    doc = f"""
    {name} resampling. {desc}

    Args:
        key: PRNGKey to use in resampling
        logits: Log-weights, possibly unnormalized.
        n: Number of indices to sample.

    Returns:
        Array of size n with indices to use for resampling.
    """

    # Set the docstring before wrapping; jax.jit propagates it to the wrapper.
    func.__doc__ = doc
    return jax.jit(func, static_argnames=("n",))
71
+
72
+
73
def conditional_resampling_decorator(
    func: ConditionalResampling, name: str, desc: str = ""
) -> ConditionalResampling:
    """Decorate ConditionalResampling function with unified docstring.

    Also jit-compiles the function; `n` is static because it determines the
    output shape under jit.
    """
    doc = f"""
    {name} conditional resampling. {desc}

    Args:
        key: PRNGKey to use in resampling
        logits: Log-weights, possibly unnormalized.
        n: Number of indices to sample
        pivot_in: Index of the particle to keep
        pivot_out: Value of the output at index `pivot_in`

    Returns:
        Array of size n with indices to use for resampling.
    """

    # Set the docstring before wrapping; jax.jit propagates it to the wrapper.
    func.__doc__ = doc
    return jax.jit(func, static_argnames=("n",))