cuthbert 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuthbert/discrete/__init__.py +2 -0
- cuthbert/discrete/filter.py +140 -0
- cuthbert/discrete/smoother.py +123 -0
- cuthbert/discrete/types.py +53 -0
- cuthbert/gaussian/__init__.py +0 -0
- cuthbert/gaussian/kalman.py +337 -0
- cuthbert/gaussian/moments/__init__.py +11 -0
- cuthbert/gaussian/moments/associative_filter.py +180 -0
- cuthbert/gaussian/moments/filter.py +95 -0
- cuthbert/gaussian/moments/non_associative_filter.py +161 -0
- cuthbert/gaussian/moments/smoother.py +118 -0
- cuthbert/gaussian/moments/types.py +51 -0
- cuthbert/gaussian/taylor/__init__.py +14 -0
- cuthbert/gaussian/taylor/associative_filter.py +222 -0
- cuthbert/gaussian/taylor/filter.py +129 -0
- cuthbert/gaussian/taylor/non_associative_filter.py +246 -0
- cuthbert/gaussian/taylor/smoother.py +158 -0
- cuthbert/gaussian/taylor/types.py +86 -0
- cuthbert/gaussian/types.py +57 -0
- cuthbert/gaussian/utils.py +41 -0
- cuthbert/smc/__init__.py +0 -0
- cuthbert/smc/backward_sampler.py +193 -0
- cuthbert/smc/marginal_particle_filter.py +237 -0
- cuthbert/smc/particle_filter.py +234 -0
- cuthbert/smc/types.py +67 -0
- {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/METADATA +2 -2
- cuthbert-0.0.3.dist-info/RECORD +76 -0
- {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/licenses/LICENSE +1 -1
- cuthbertlib/discrete/__init__.py +0 -0
- cuthbertlib/discrete/filtering.py +49 -0
- cuthbertlib/discrete/smoothing.py +35 -0
- cuthbertlib/kalman/__init__.py +4 -0
- cuthbertlib/kalman/filtering.py +213 -0
- cuthbertlib/kalman/generate.py +85 -0
- cuthbertlib/kalman/sampling.py +68 -0
- cuthbertlib/kalman/smoothing.py +121 -0
- cuthbertlib/linalg/__init__.py +7 -0
- cuthbertlib/linalg/collect_nans_chol.py +90 -0
- cuthbertlib/linalg/marginal_sqrt_cov.py +34 -0
- cuthbertlib/linalg/symmetric_inv_sqrt.py +126 -0
- cuthbertlib/linalg/tria.py +21 -0
- cuthbertlib/linearize/__init__.py +7 -0
- cuthbertlib/linearize/log_density.py +175 -0
- cuthbertlib/linearize/moments.py +94 -0
- cuthbertlib/linearize/taylor.py +83 -0
- cuthbertlib/quadrature/__init__.py +4 -0
- cuthbertlib/quadrature/common.py +102 -0
- cuthbertlib/quadrature/cubature.py +73 -0
- cuthbertlib/quadrature/gauss_hermite.py +62 -0
- cuthbertlib/quadrature/linearize.py +143 -0
- cuthbertlib/quadrature/unscented.py +79 -0
- cuthbertlib/quadrature/utils.py +109 -0
- cuthbertlib/resampling/__init__.py +3 -0
- cuthbertlib/resampling/killing.py +79 -0
- cuthbertlib/resampling/multinomial.py +53 -0
- cuthbertlib/resampling/protocols.py +92 -0
- cuthbertlib/resampling/systematic.py +78 -0
- cuthbertlib/resampling/utils.py +82 -0
- cuthbertlib/smc/__init__.py +0 -0
- cuthbertlib/smc/ess.py +24 -0
- cuthbertlib/smc/smoothing/__init__.py +0 -0
- cuthbertlib/smc/smoothing/exact_sampling.py +111 -0
- cuthbertlib/smc/smoothing/mcmc.py +76 -0
- cuthbertlib/smc/smoothing/protocols.py +44 -0
- cuthbertlib/smc/smoothing/tracing.py +45 -0
- cuthbertlib/stats/__init__.py +0 -0
- cuthbertlib/stats/multivariate_normal.py +102 -0
- cuthbert-0.0.1.dist-info/RECORD +0 -12
- {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/WHEEL +0 -0
- {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Extract marginal square root covariance from a joint square root covariance."""
|
|
2
|
+
|
|
3
|
+
from typing import Sequence
|
|
4
|
+
|
|
5
|
+
from jax import numpy as jnp
|
|
6
|
+
|
|
7
|
+
from cuthbertlib.linalg.tria import tria
|
|
8
|
+
from cuthbertlib.types import Array, ArrayLike
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def marginal_sqrt_cov(chol_cov: ArrayLike, start: int, end: int) -> Array:
    """Return a square root of one diagonal block of a joint covariance.

    Given a square root ``chol_cov`` of a joint covariance matrix, this
    produces a matrix ``B`` satisfying

        B @ B.T = (chol_cov @ chol_cov.T)[start:end, start:end]

    Args:
        chol_cov: Generalized Cholesky factor of the joint covariance matrix.
            Must be a square 2D array.
        start: First index (inclusive) of the marginal block.
        end: Last index (exclusive) of the marginal block.

    Returns:
        The output of ``tria`` applied to rows ``start:end`` of ``chol_cov``,
        i.e. a lower triangular square root of the marginal covariance.
    """
    joint_factor = jnp.asarray(chol_cov)
    assert joint_factor.ndim == 2, "chol_cov must be a 2D array"

    n_rows, n_cols = joint_factor.shape
    assert n_rows == n_cols, "chol_cov must be square"
    assert not (start < 0 or end > n_rows), (
        "start and end must be within the bounds of chol_cov"
    )
    assert start < end, "start must be less than end"

    # Only the selected rows contribute to the marginal covariance block;
    # triangularizing them yields a square factor of that block.
    return tria(joint_factor[start:end, :])
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""Implements inverse square root of a symmetric matrix."""
|
|
2
|
+
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
|
|
5
|
+
from cuthbertlib.linalg.tria import tria
|
|
6
|
+
from cuthbertlib.types import Array, ArrayLike
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def symmetric_inv_sqrt(
    A: ArrayLike,
    rtol: float | ArrayLike | None = None,
    ignore_nan_dims: bool = False,
) -> Array:
    r"""Compute an inverse square root of a symmetric matrix.

    The result is a lower triangular matrix $L$ with $L L^{\top} = A^{-1}$ when
    $A$ is positive definite. Such a factor is not unique and will generally
    differ from the Cholesky factor of $A^{-1}$.

    Small singular values are truncated, in the spirit of the Moore-Penrose
    pseudoinverse
    (https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html).
    Consequently, for singular or indefinite $A$ the output is only an
    approximation and $L L^{\top} = A^{-1}$ does not hold in general.

    Args:
        A: A symmetric matrix.
        rtol: Relative tolerance for the singular values. Singular values
            smaller than `rtol * largest_singular_value` are treated as zero.
            See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: Whether dimensions with NaN on the diagonal are treated
            as missing, in which case every row and column associated with them
            is ignored (the result has NaN on the diagonal and zero off the
            diagonal in those dimensions).

    Returns:
        A lower triangular matrix $L$ such that $L L^{\top} = A^{-1}$ (for valid
        dimensions).
    """
    mat = jnp.asarray(A)

    # Dimensions flagged as missing through a NaN diagonal entry; the product
    # with ignore_nan_dims switches this detection off when it is False.
    missing_dims = jnp.isnan(jnp.diag(mat)) * ignore_nan_dims

    # Dimensions whose entire row and entire column are exactly zero.
    is_zero = mat == 0.0
    empty_dims = jnp.all(is_zero, axis=0) & jnp.all(is_zero, axis=1)

    skip = missing_dims | empty_dims

    # Stable permutation that puts the kept dimensions ahead of the skipped
    # ones (needed for the SVD to work correctly on the leading block).
    perm = jnp.argsort(skip, stable=True)
    skip_perm = skip[perm]
    mat_perm = mat[perm[:, None], perm]

    # Entries lying in a skipped row or column; only active with ignore_nan_dims.
    touches_skipped = (skip_perm[:, None] | skip_perm[None, :]) & ignore_nan_dims
    mat_perm = jnp.where(touches_skipped, 0.0, mat_perm)

    # Inverse square root of the permuted, masked matrix.
    factor_perm = _symmetric_inv_sqrt(mat_perm, rtol)

    # Clear skipped rows/columns again and flag skipped dimensions with NaN
    # on the diagonal.
    factor_perm = jnp.where(touches_skipped, 0.0, factor_perm)
    flagged_diag = jnp.where(skip_perm, jnp.nan, jnp.diag(factor_perm))
    factor_perm = factor_perm.at[jnp.diag_indices_from(factor_perm)].set(
        flagged_diag
    )

    # Undo the permutation so the output follows the ordering of the input.
    inverse_perm = jnp.argsort(perm)
    return factor_perm[inverse_perm[:, None], inverse_perm]
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _symmetric_inv_sqrt(A: ArrayLike, rtol: float | ArrayLike | None = None) -> Array:
    """Implementation of symmetric inverse square root without NaN handling.

    Args:
        A: A symmetric matrix.
        rtol: Relative cutoff for singular values; values not exceeding
            ``rtol * largest_singular_value`` are treated as zero. When ``None``,
            a default of ``10 * max(rows, cols) * eps`` for the dtype of ``A``
            is used.

    Returns:
        A lower triangular matrix ``L`` built from the retained singular
        directions. Dimensions whose row and column of ``L`` are entirely zero
        are filled with NaN.
    """
    arr = jnp.asarray(A)

    # From https://github.com/jax-ml/jax/blob/75d8702023fca6fe4a223bf1e08545c1c80581c0/jax/_src/numpy/linalg.py#L972
    if rtol is None:
        max_rows_cols = max(arr.shape[-2:])
        rtol = jnp.asarray(10.0 * max_rows_cols * jnp.finfo(arr.dtype).eps)
    u, s, _ = jnp.linalg.svd(arr, full_matrices=False, hermitian=True)
    cutoff = rtol * s[0]
    valid_mask = s > cutoff

    # "Double where": replace discarded singular values by 1 *before* taking
    # 1 / sqrt, so the unselected branch never evaluates 1 / sqrt(0). The
    # forward result is unchanged, but gradients through jnp.where no longer
    # pick up inf/NaN contributions from the truncated singular values.
    safe_s = jnp.where(valid_mask, s, 1.0)
    # Use 0 for invalid singular values to avoid inf/NaN propagation in tria
    inv_sqrt_s = jnp.where(valid_mask, 1.0 / jnp.sqrt(safe_s), 0.0).astype(u.dtype)

    B = u * inv_sqrt_s  # Square root but not lower triangular
    L = tria(B)  # Make lower triangular

    # Mark dimensions with all 0 rows and columns as NaN
    zero_dims_mask = jnp.all(L == 0.0, axis=0) & jnp.all(L == 0.0, axis=1)
    L = jnp.where(zero_dims_mask[:, None] | zero_dims_mask[None, :], jnp.nan, L)
    return L
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def chol_cov_with_nans_to_cov(chol_cov: ArrayLike) -> Array:
    """Build a covariance matrix from a Cholesky factor that may flag dimensions.

    A NaN on the diagonal of ``chol_cov`` marks that dimension as ignored.

    Args:
        chol_cov: Cholesky factor of a covariance matrix, with NaNs on the
            diagonal marking the dimensions to ignore.

    Returns:
        A matrix equal to ``chol_cov @ chol_cov.T`` on the valid dimensions.
        For every ignored dimension the result has NaN on the diagonal and
        zeros off the diagonal.
    """
    factor = jnp.asarray(chol_cov)

    ignored = jnp.isnan(jnp.diag(factor))

    # Wipe every row and column belonging to an ignored dimension so that they
    # contribute nothing to the product below.
    touches_ignored = ignored[:, None] | ignored[None, :]
    cleaned = jnp.where(touches_ignored, 0, factor)

    cov = cleaned @ cleaned.T

    # Re-flag the ignored dimensions with NaN on the diagonal.
    flagged_diag = jnp.where(ignored, jnp.nan, jnp.diag(cov))
    return cov.at[jnp.diag_indices_from(cov)].set(flagged_diag)
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Implements triangularization operator a matrix via QR decomposition."""
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
|
|
5
|
+
from cuthbertlib.types import Array
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def tria(A: Array) -> Array:
    r"""Triangularize a matrix through a QR decomposition.

    Args:
        A: The matrix to triangularize.

    Returns:
        A lower triangular matrix $R$ such that $R R^\top = A A^\top$.

    Reference:
        [Arasaratnam and Haykin (2008)](https://ieeexplore.ieee.org/document/4524036): Square-Root Quadrature Kalman Filtering
    """
    # With A.T = Q U (economic QR), A A^T = U^T Q^T Q U = U^T U, so the
    # transpose of the upper triangular factor is the matrix we want.
    upper = jax.scipy.linalg.qr(A.T, mode="economic")[1]
    return upper.T
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
from cuthbertlib.linalg import symmetric_inv_sqrt
|
|
2
|
+
from cuthbertlib.linearize.log_density import (
|
|
3
|
+
linearize_log_density,
|
|
4
|
+
linearize_log_density_given_chol_cov,
|
|
5
|
+
)
|
|
6
|
+
from cuthbertlib.linearize.moments import linearize_moments
|
|
7
|
+
from cuthbertlib.linearize.taylor import linearize_taylor
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""Implements linearization of conditional log densities."""
|
|
2
|
+
|
|
3
|
+
from typing import overload
|
|
4
|
+
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
from jax import grad, hessian, jacobian
|
|
7
|
+
|
|
8
|
+
from cuthbertlib.linalg import chol_cov_with_nans_to_cov, symmetric_inv_sqrt
|
|
9
|
+
from cuthbertlib.types import (
|
|
10
|
+
Array,
|
|
11
|
+
ArrayLike,
|
|
12
|
+
ArrayTree,
|
|
13
|
+
LogConditionalDensity,
|
|
14
|
+
LogConditionalDensityAux,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@overload
def linearize_log_density(
    log_density: LogConditionalDensity,
    x: ArrayLike,
    y: ArrayLike,
    has_aux: bool = False,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
) -> tuple[Array, Array, Array]: ...
@overload
def linearize_log_density(
    log_density: LogConditionalDensityAux,
    x: ArrayLike,
    y: ArrayLike,
    has_aux: bool = True,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
) -> tuple[Array, Array, Array, ArrayTree]: ...


def linearize_log_density(
    log_density: LogConditionalDensity | LogConditionalDensityAux,
    x: ArrayLike,
    y: ArrayLike,
    has_aux: bool = False,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
) -> tuple[Array, Array, Array] | tuple[Array, Array, Array, ArrayTree]:
    r"""Linearize a conditional log density around the points `x` and `y`.

    For a linear-Gaussian `log_density` the linearization is exact: the
    function returns $(H, d, L)$ whenever `log_density` has the form

    $$
    \log p(y \mid x) = -\frac{1}{2}(y - H x - d)^\top (LL^\top)^{-1} (y - H x - d) + \textrm{const}.
    $$

    The negative Hessian of `log_density` with respect to `y` is used as the
    precision matrix, and the Cholesky factor of the covariance is obtained from
    it with `symmetric_inv_sqrt`, which discards singular values sufficiently
    close to zero (a projection when the Hessian is not positive definite).

    Use `linearize_log_density_given_chol_cov` to supply the Cholesky factor
    directly instead.

    Args:
        log_density: A conditional log density of y given x. Returns a scalar.
        x: The input points.
        y: The output points.
        has_aux: Whether `log_density` returns an auxiliary value.
        rtol: Relative tolerance for the singular values of the precision
            matrix, forwarded to `symmetric_inv_sqrt`. Singular values smaller
            than `rtol * largest_singular_value` are treated as zero. The
            default is determined from the floating point precision of the
            dtype. See
            https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: Whether dimensions with NaN on the diagonal of the
            precision matrix (or NaN in `y`) are treated as missing, ignoring
            all rows and columns associated with them.

    Returns:
        Linearized matrix, shift, and Cholesky factor of the covariance matrix.
        The auxiliary value is also returned if `has_aux` is `True`.
    """
    hess_out = hessian(log_density, 1, has_aux=has_aux)(x, y)
    hess_y = hess_out[0] if has_aux else hess_out
    prec = -hess_y

    if ignore_nan_dims:
        # Propagate missing observations (NaN in y) onto the diagonal of the
        # precision so that symmetric_inv_sqrt sees them as missing dimensions.
        diag_prec = jnp.diag(prec)
        missing = jnp.isnan(y) | jnp.isnan(diag_prec)
        flagged_diag = jnp.where(missing, jnp.nan, diag_prec)
        prec = prec.at[jnp.diag_indices_from(prec)].set(flagged_diag)

    chol_cov = symmetric_inv_sqrt(prec, rtol=rtol, ignore_nan_dims=ignore_nan_dims)

    linearized = linearize_log_density_given_chol_cov(
        log_density, x, y, chol_cov, has_aux=has_aux, ignore_nan_dims=ignore_nan_dims
    )
    mat, shift = linearized[0], linearized[1]
    # Anything past (mat, shift) is the auxiliary output, if any.
    return mat, shift, chol_cov, *linearized[2:]
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@overload
def linearize_log_density_given_chol_cov(
    log_density: LogConditionalDensity,
    x: ArrayLike,
    y: ArrayLike,
    chol_cov: ArrayLike,
    has_aux: bool = False,
    ignore_nan_dims: bool = False,
) -> tuple[Array, Array]: ...
@overload
def linearize_log_density_given_chol_cov(
    log_density: LogConditionalDensityAux,
    x: ArrayLike,
    y: ArrayLike,
    chol_cov: ArrayLike,
    has_aux: bool = True,
    ignore_nan_dims: bool = False,
) -> tuple[Array, Array, ArrayTree]: ...


def linearize_log_density_given_chol_cov(
    log_density: LogConditionalDensity | LogConditionalDensityAux,
    x: ArrayLike,
    y: ArrayLike,
    chol_cov: ArrayLike,
    has_aux: bool = False,
    ignore_nan_dims: bool = False,
) -> tuple[Array, Array] | tuple[Array, Array, ArrayTree]:
    r"""Linearize a conditional log density given the covariance Cholesky factor.

    For a linear-Gaussian `log_density` the linearization is exact: the
    function returns $(H, d)$ whenever `log_density` has the form

    $$
    \log p(y \mid x) = -\frac{1}{2}(y - H x - d)^\top (LL^\top)^{-1} (y - H x - d) + \textrm{const},
    $$

    where $L$ is the argument `chol_cov`.

    Args:
        log_density: A conditional log density of y given x. Returns a scalar.
        x: The input points.
        y: The output points.
        chol_cov: The Cholesky factor of the covariance matrix of the Gaussian.
        has_aux: Whether `log_density` returns an auxiliary value.
        ignore_nan_dims: Whether to ignore dimensions with NaN on the diagonal
            of the precision matrix or in y. When set, the covariance is built
            with `chol_cov_with_nans_to_cov` instead of a plain product.

    Returns:
        Linearized matrix and shift. The auxiliary value is also returned if
        `has_aux` is `True`.
    """
    chol_cov = jnp.asarray(chol_cov)

    if ignore_nan_dims:
        cov = chol_cov_with_nans_to_cov(chol_cov)
    else:
        cov = chol_cov @ chol_cov.T

    def _score_with_payload(x_in, y_in):
        # Gradient of the log density w.r.t. y. The gradient is also returned
        # as auxiliary output so that it is available next to its Jacobian
        # without a second evaluation.
        if has_aux:
            score, aux = grad(log_density, 1, has_aux=True)(x_in, y_in)
            return score, (score, aux)
        score = grad(log_density, 1)(x_in, y_in)
        return score, (score,)

    # Jacobian (w.r.t. x) of the y-gradient, plus the gradient itself and any
    # auxiliary output of log_density.
    jac, (g, *extra) = jacobian(_score_with_payload, 0, has_aux=True)(x, y)

    mat = cov @ jac
    shift = y - mat @ x + cov @ g
    return mat, shift, *extra
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
"""Implements moment-based linearization."""
|
|
2
|
+
|
|
3
|
+
from typing import Callable, cast, overload
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
from jax.typing import ArrayLike
|
|
7
|
+
|
|
8
|
+
from cuthbertlib.types import Array, ArrayTree
|
|
9
|
+
|
|
10
|
+
# Signature of a conditional-moments function: x -> (mean, chol_cov).
MeanAndCholCovFunc = Callable[[ArrayLike], tuple[Array, Array]]
# Same, with a trailing auxiliary output: x -> (mean, chol_cov, aux).
MeanAndCholCovFuncAux = Callable[[ArrayLike], tuple[Array, Array, ArrayTree]]


@overload
def linearize_moments(
    mean_and_chol_cov_function: MeanAndCholCovFunc,
    x: ArrayLike,
    has_aux: bool = False,
) -> tuple[Array, Array, Array]: ...
@overload
def linearize_moments(
    mean_and_chol_cov_function: MeanAndCholCovFuncAux,
    x: ArrayLike,
    has_aux: bool = True,
) -> tuple[Array, Array, Array, ArrayTree]: ...


def linearize_moments(
    mean_and_chol_cov_function: MeanAndCholCovFunc | MeanAndCholCovFuncAux,
    x: ArrayLike,
    has_aux: bool = False,
) -> tuple[Array, Array, Array] | tuple[Array, Array, Array, ArrayTree]:
    r"""Turn conditional mean and chol_cov functions into a linear-Gaussian model.

    `mean_and_chol_cov_function(x)` provides the conditional mean and the
    Cholesky factor of the covariance matrix of the distribution
    $p(y \mid x)$ for a given input `x`.

    The output $(H, d, L)$ defines the linear-Gaussian approximation
    $p(y \mid x) \approx N(y \mid H x + d, L L^\top)$.

    With `has_aux` = False, `mean_and_chol_cov_function` is called as
    ```
    m, chol = mean_and_chol_cov_function(x)
    ```
    and with `has_aux` = True as
    ```
    m, chol, aux = mean_and_chol_cov_function(x)
    ```

    Args:
        mean_and_chol_cov_function: A callable that returns the conditional mean
            and Cholesky factor of the covariance matrix of the distribution
            for a given input.
        x: The point to linearize around.
        has_aux: Whether `mean_and_chol_cov_function` returns an auxiliary value.

    Returns:
        Linearized matrix, shift, and Cholesky factor of the covariance matrix.
        The auxiliary value is also returned if `has_aux` is `True`.

    References:
        - [sqrt-parallel-smoothers](https://github.com/EEA-sensors/sqrt-parallel-smoothers/blob/main/parsmooth/linearization/_extended.py)
    """

    def _mean_with_payload(point: ArrayLike):
        # Differentiate the mean only; the mean, chol_cov (and aux) are carried
        # along as auxiliary outputs so the function is evaluated just once.
        if has_aux:
            fn_aux = cast(MeanAndCholCovFuncAux, mean_and_chol_cov_function)
            mean, chol_cov, aux = fn_aux(point)
            return mean, (mean, chol_cov, aux)
        fn = cast(MeanAndCholCovFunc, mean_and_chol_cov_function)
        mean, chol_cov = fn(point)
        return mean, (mean, chol_cov)

    F, payload = jax.jacfwd(_mean_with_payload, has_aux=True)(x)
    m, *extra = payload

    # Shift chosen so that the linear model reproduces the mean at x.
    b = m - F @ x
    return F, b, *extra
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""Implements Taylor-like linearization."""
|
|
2
|
+
|
|
3
|
+
from typing import Callable, overload
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
from jax import numpy as jnp
|
|
7
|
+
from jax.typing import ArrayLike
|
|
8
|
+
|
|
9
|
+
from cuthbertlib.linalg import symmetric_inv_sqrt
|
|
10
|
+
from cuthbertlib.types import Array, ArrayTree
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@overload
def linearize_taylor(
    log_potential: Callable[[ArrayLike], Array],
    x: ArrayLike,
    has_aux: bool = False,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
) -> tuple[Array, Array]: ...
@overload
def linearize_taylor(
    log_potential: Callable[[ArrayLike], tuple[Array, ArrayTree]],
    x: ArrayLike,
    has_aux: bool = True,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
) -> tuple[Array, Array, ArrayTree]: ...


def linearize_taylor(
    log_potential: Callable[[ArrayLike], Array]
    | Callable[[ArrayLike], tuple[Array, ArrayTree]],
    x: ArrayLike,
    has_aux: bool = False,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
) -> tuple[Array, Array] | tuple[Array, Array, ArrayTree]:
    r"""Linearize a log potential around a point with a Taylor expansion.

    In contrast to the other linearization methods, this one works on a
    potential function that needs no notion of an observation $y$ or of
    conditional dependence.

    The resulting linearization is

    $$
    \log G(x) = -\frac{1}{2} (x - m)^\top (L L^\top)^{-1} (x - m).
    $$

    Args:
        log_potential: A callable that returns a scalar (the log of a
            non-negative potential). Does not need to be a normalized
            probability density in its input.
        x: The point to linearize around.
        has_aux: Whether `log_potential` returns an auxiliary value.
        rtol: Relative tolerance for the singular values of the precision
            matrix, forwarded to `symmetric_inv_sqrt`. Singular values smaller
            than `rtol * largest_singular_value` are treated as zero. The
            default is determined from the floating point precision of the
            dtype. See
            https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: Whether dimensions with NaN on the diagonal of the
            precision matrix are treated as missing, ignoring all rows and
            columns associated with them.

    Returns:
        Linearized mean and Cholesky factor of the covariance matrix.
        The auxiliary value is also returned if `has_aux` is `True`.
    """
    grad_out = jax.grad(log_potential, has_aux=has_aux)(x)
    hess_out = jax.hessian(log_potential, has_aux=has_aux)(x)

    if has_aux:
        g, aux = grad_out
        prec = -hess_out[0]
    else:
        g, aux = grad_out, None
        prec = -hess_out

    L = symmetric_inv_sqrt(prec, rtol=rtol, ignore_nan_dims=ignore_nan_dims)

    # L keeps its NaN diagonal entries for bookkeeping. For the mean update,
    # and only when ignore_nan_dims is set, every row and column with a NaN
    # diagonal entry is replaced by zeros in a temporary copy.
    skipped = jnp.isnan(jnp.diag(L)) * ignore_nan_dims
    touches_skipped = skipped[:, None] | skipped[None, :]
    L_clean = jnp.where(touches_skipped, 0.0, L)

    m = x + L_clean @ L_clean.T @ g

    if has_aux:
        return m, L, aux
    return m, L
|
|
@@ -0,0 +1,4 @@
|
|
|
1
|
+
from cuthbertlib.quadrature import cubature, gauss_hermite, unscented
|
|
2
|
+
from cuthbertlib.quadrature.common import Quadrature, SigmaPoints
|
|
3
|
+
from cuthbertlib.quadrature.linearize import conditional_moments, functional
|
|
4
|
+
from cuthbertlib.quadrature.utils import cholesky_update_many
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""Common types and protocols for quadrature."""
|
|
2
|
+
|
|
3
|
+
from typing import NamedTuple, Protocol, Self, runtime_checkable
|
|
4
|
+
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
|
|
7
|
+
from cuthbertlib.linalg import tria
|
|
8
|
+
from cuthbertlib.types import Array, ArrayLike
|
|
9
|
+
|
|
10
|
+
__all__ = ["SigmaPoints", "Quadrature"]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SigmaPoints(NamedTuple):
|
|
14
|
+
"""Represents integration (quadrature) sigma points as a collection of points.
|
|
15
|
+
|
|
16
|
+
Weights correspond to mean and covariance calculations.
|
|
17
|
+
|
|
18
|
+
Attributes:
|
|
19
|
+
points: The sigma points.
|
|
20
|
+
wm: The mean weights.
|
|
21
|
+
wc: The covariance weights.
|
|
22
|
+
|
|
23
|
+
Methods:
|
|
24
|
+
mean: Computes the mean of the sigma points.
|
|
25
|
+
covariance: Computes the covariance between the sigma points and the other
|
|
26
|
+
sigma points (or itself).
|
|
27
|
+
sqrt: Computes a square root of the covariance matrix of the sigma points.
|
|
28
|
+
|
|
29
|
+
References:
|
|
30
|
+
Simo Särkkä, Lennard Svensson. *Bayesian Filtering and Smoothing.*
|
|
31
|
+
In: Cambridge University Press 2023.
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
points: Array
|
|
35
|
+
wm: Array
|
|
36
|
+
wc: Array
|
|
37
|
+
|
|
38
|
+
@property
|
|
39
|
+
def mean(self) -> Array:
|
|
40
|
+
"""Computes the mean of the sigma points.
|
|
41
|
+
|
|
42
|
+
Returns:
|
|
43
|
+
The mean of the sigma points.
|
|
44
|
+
"""
|
|
45
|
+
return jnp.dot(self.wm, self.points)
|
|
46
|
+
|
|
47
|
+
# Should this be property too?
|
|
48
|
+
def covariance(self, other: Self | None = None) -> Array:
|
|
49
|
+
"""Computes the covariance between the sigma points and the other sigma points.
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
other: The optional other sigma points.
|
|
53
|
+
|
|
54
|
+
Returns:
|
|
55
|
+
The covariance matrix.
|
|
56
|
+
"""
|
|
57
|
+
mean = self.mean
|
|
58
|
+
if other is None:
|
|
59
|
+
return _cov(self.wc, self.points, mean, self.points, mean)
|
|
60
|
+
|
|
61
|
+
other_mean = other.mean
|
|
62
|
+
return _cov(self.wc, self.points, mean, other.points, other_mean)
|
|
63
|
+
|
|
64
|
+
@property
|
|
65
|
+
def sqrt(self) -> Array:
|
|
66
|
+
"""Computes the square root of the covariance matrix of the sigma points.
|
|
67
|
+
|
|
68
|
+
Returns:
|
|
69
|
+
The square root of the covariance matrix.
|
|
70
|
+
"""
|
|
71
|
+
sqrt = jnp.sqrt(self.wc[:, None]) * (self.points - self.mean[None, :])
|
|
72
|
+
sqrt = tria(sqrt.T)
|
|
73
|
+
return sqrt
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@runtime_checkable
class Quadrature(Protocol):
    """Structural interface implemented by quadrature methods."""

    def get_sigma_points(self, m: ArrayLike, chol: ArrayLike) -> SigmaPoints:
        """Build the sigma points for a given mean and covariance factor.

        Args:
            m: The mean.
            chol: The Cholesky factor of the covariance.

        Returns:
            SigmaPoints: The sigma points.
        """
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _cov(
|
|
94
|
+
wc: Array,
|
|
95
|
+
x_pts: Array,
|
|
96
|
+
x_mean: Array,
|
|
97
|
+
y_points: Array,
|
|
98
|
+
y_mean: Array,
|
|
99
|
+
) -> Array:
|
|
100
|
+
one = (x_pts - x_mean[None, :]).T * wc[None, :]
|
|
101
|
+
two = y_points - y_mean[None, :]
|
|
102
|
+
return jnp.dot(one, two)
|