cuthbert-0.0.2-py3-none-any.whl → cuthbert-0.0.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. cuthbert/discrete/__init__.py +2 -0
  2. cuthbert/discrete/filter.py +140 -0
  3. cuthbert/discrete/smoother.py +123 -0
  4. cuthbert/discrete/types.py +53 -0
  5. cuthbert/gaussian/__init__.py +0 -0
  6. cuthbert/gaussian/kalman.py +337 -0
  7. cuthbert/gaussian/moments/__init__.py +11 -0
  8. cuthbert/gaussian/moments/associative_filter.py +180 -0
  9. cuthbert/gaussian/moments/filter.py +95 -0
  10. cuthbert/gaussian/moments/non_associative_filter.py +161 -0
  11. cuthbert/gaussian/moments/smoother.py +118 -0
  12. cuthbert/gaussian/moments/types.py +51 -0
  13. cuthbert/gaussian/taylor/__init__.py +14 -0
  14. cuthbert/gaussian/taylor/associative_filter.py +222 -0
  15. cuthbert/gaussian/taylor/filter.py +129 -0
  16. cuthbert/gaussian/taylor/non_associative_filter.py +246 -0
  17. cuthbert/gaussian/taylor/smoother.py +158 -0
  18. cuthbert/gaussian/taylor/types.py +86 -0
  19. cuthbert/gaussian/types.py +57 -0
  20. cuthbert/gaussian/utils.py +41 -0
  21. cuthbert/smc/__init__.py +0 -0
  22. cuthbert/smc/backward_sampler.py +193 -0
  23. cuthbert/smc/marginal_particle_filter.py +237 -0
  24. cuthbert/smc/particle_filter.py +234 -0
  25. cuthbert/smc/types.py +67 -0
  26. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/METADATA +1 -1
  27. cuthbert-0.0.3.dist-info/RECORD +76 -0
  28. cuthbertlib/discrete/__init__.py +0 -0
  29. cuthbertlib/discrete/filtering.py +49 -0
  30. cuthbertlib/discrete/smoothing.py +35 -0
  31. cuthbertlib/kalman/__init__.py +4 -0
  32. cuthbertlib/kalman/filtering.py +213 -0
  33. cuthbertlib/kalman/generate.py +85 -0
  34. cuthbertlib/kalman/sampling.py +68 -0
  35. cuthbertlib/kalman/smoothing.py +121 -0
  36. cuthbertlib/linalg/__init__.py +7 -0
  37. cuthbertlib/linalg/collect_nans_chol.py +90 -0
  38. cuthbertlib/linalg/marginal_sqrt_cov.py +34 -0
  39. cuthbertlib/linalg/symmetric_inv_sqrt.py +126 -0
  40. cuthbertlib/linalg/tria.py +21 -0
  41. cuthbertlib/linearize/__init__.py +7 -0
  42. cuthbertlib/linearize/log_density.py +175 -0
  43. cuthbertlib/linearize/moments.py +94 -0
  44. cuthbertlib/linearize/taylor.py +83 -0
  45. cuthbertlib/quadrature/__init__.py +4 -0
  46. cuthbertlib/quadrature/common.py +102 -0
  47. cuthbertlib/quadrature/cubature.py +73 -0
  48. cuthbertlib/quadrature/gauss_hermite.py +62 -0
  49. cuthbertlib/quadrature/linearize.py +143 -0
  50. cuthbertlib/quadrature/unscented.py +79 -0
  51. cuthbertlib/quadrature/utils.py +109 -0
  52. cuthbertlib/resampling/__init__.py +3 -0
  53. cuthbertlib/resampling/killing.py +79 -0
  54. cuthbertlib/resampling/multinomial.py +53 -0
  55. cuthbertlib/resampling/protocols.py +92 -0
  56. cuthbertlib/resampling/systematic.py +78 -0
  57. cuthbertlib/resampling/utils.py +82 -0
  58. cuthbertlib/smc/__init__.py +0 -0
  59. cuthbertlib/smc/ess.py +24 -0
  60. cuthbertlib/smc/smoothing/__init__.py +0 -0
  61. cuthbertlib/smc/smoothing/exact_sampling.py +111 -0
  62. cuthbertlib/smc/smoothing/mcmc.py +76 -0
  63. cuthbertlib/smc/smoothing/protocols.py +44 -0
  64. cuthbertlib/smc/smoothing/tracing.py +45 -0
  65. cuthbertlib/stats/__init__.py +0 -0
  66. cuthbertlib/stats/multivariate_normal.py +102 -0
  67. cuthbert-0.0.2.dist-info/RECORD +0 -12
  68. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/WHEEL +0 -0
  69. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/licenses/LICENSE +0 -0
  70. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,126 @@
+"""Implements the inverse square root of a symmetric matrix."""
+
+import jax.numpy as jnp
+
+from cuthbertlib.linalg.tria import tria
+from cuthbertlib.types import Array, ArrayLike
+
+
+def symmetric_inv_sqrt(
+    A: ArrayLike,
+    rtol: float | ArrayLike | None = None,
+    ignore_nan_dims: bool = False,
+) -> Array:
+    r"""Computes the inverse square root of a symmetric matrix.
+
+    I.e., a lower triangular matrix $L$ such that $L L^{\top} = A^{-1}$ (for positive
+    definite $A$). Note that this is not unique and will generally not match the
+    Cholesky factor of $A^{-1}$.
+
+    For singular matrices, small singular values are cut off, reminiscent of the
+    Moore-Penrose pseudoinverse - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
+
+    In the case of singular or indefinite $A$, the output will be an approximation
+    and $L L^{\top} = A^{-1}$ will not hold in general.
+
+    Args:
+        A: A symmetric matrix.
+        rtol: The relative tolerance for the singular values.
+            Cutoff for small singular values; singular values smaller than
+            `rtol * largest_singular_value` are treated as zero.
+            See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
+        ignore_nan_dims: Whether to treat dimensions with NaN on the diagonal as missing
+            and ignore all rows and columns associated with them (with the result in
+            those dimensions being NaN on the diagonal and zero off-diagonal).
+
+    Returns:
+        A lower triangular matrix $L$ such that $L L^{\top} = A^{-1}$ (for valid dimensions).
+    """
+    arr = jnp.asarray(A)
+
+    # Check for NaNs on the diagonal (missing dimensions)
+    diag_vals = jnp.diag(arr)
+    nan_diag_mask = jnp.isnan(diag_vals) * ignore_nan_dims
+
+    # Check for dimensions whose row and column are all 0
+    zero_mask = jnp.all(arr == 0.0, axis=0) & jnp.all(arr == 0.0, axis=1)
+
+    nan_mask = nan_diag_mask | zero_mask
+
+    # Sort to group valid dimensions first (needed for SVD to work correctly)
+    argsort = jnp.argsort(nan_mask, stable=True)
+    arr_sorted = arr[argsort[:, None], argsort]
+    nan_mask_sorted = nan_mask[argsort]
+
+    # Zero out invalid dimensions before computation
+    invalid_mask_2d = ((nan_mask_sorted[:, None]) | (nan_mask_sorted[None, :])) & (
+        ignore_nan_dims
+    )
+    arr_sorted = jnp.where(invalid_mask_2d, 0.0, arr_sorted)
+
+    # Compute inverse square root on sorted, masked matrix
+    L_sorted = _symmetric_inv_sqrt(arr_sorted, rtol)
+
+    # Post-process: zero out invalid rows/cols, set NaN on invalid diagonal
+    L_sorted = jnp.where(invalid_mask_2d, 0.0, L_sorted)
+    diag_L = jnp.where(nan_mask_sorted, jnp.nan, jnp.diag(L_sorted))
+    L_sorted = L_sorted.at[jnp.diag_indices_from(L_sorted)].set(diag_L)
+
+    # Un-sort to restore original order
+    inv_argsort = jnp.argsort(argsort)
+    L = L_sorted[inv_argsort[:, None], inv_argsort]
+
+    return L
+
+
+def _symmetric_inv_sqrt(A: ArrayLike, rtol: float | ArrayLike | None = None) -> Array:
+    """Implementation of the symmetric inverse square root without NaN handling."""
+    arr = jnp.asarray(A)
+
+    # From https://github.com/jax-ml/jax/blob/75d8702023fca6fe4a223bf1e08545c1c80581c0/jax/_src/numpy/linalg.py#L972
+    if rtol is None:
+        max_rows_cols = max(arr.shape[-2:])
+        rtol = jnp.asarray(10.0 * max_rows_cols * jnp.finfo(arr.dtype).eps)
+    u, s, _ = jnp.linalg.svd(arr, full_matrices=False, hermitian=True)
+    cutoff = rtol * s[0]
+    # Use 0 for invalid singular values to avoid inf/NaN propagation in tria
+    valid_mask = s > cutoff
+    inv_sqrt_s = jnp.where(valid_mask, 1.0 / jnp.sqrt(s), 0.0).astype(u.dtype)
+    B = u * inv_sqrt_s  # Square root but not lower triangular
+    L = tria(B)  # Make lower triangular
+    # Mark dimensions with all 0 rows and columns as NaN
+    zero_dims_mask = jnp.all(L == 0.0, axis=0) & jnp.all(L == 0.0, axis=1)
+    L = jnp.where(zero_dims_mask[:, None] | zero_dims_mask[None, :], jnp.nan, L)
+    return L
+
+
+def chol_cov_with_nans_to_cov(chol_cov: ArrayLike) -> Array:
+    """Converts a Cholesky factor to a covariance matrix.
+
+    NaNs on the diagonal specify dimensions to be ignored.
+
+    Args:
+        chol_cov: A Cholesky factor of a covariance matrix with NaNs on the diagonal
+            specifying dimensions to be ignored.
+
+    Returns:
+        A covariance matrix equal to chol_cov @ chol_cov.T in dimensions where the
+        Cholesky factor is valid, and, for invalid dimensions (ones with NaN on the
+        diagonal in chol_cov), NaN on the diagonal and zero off-diagonal.
+    """
+    chol_cov = jnp.asarray(chol_cov)
+
+    nan_mask = jnp.isnan(jnp.diag(chol_cov))
+
+    # Set all rows and columns with invalid diagonal to zero
+    chol_cov = jnp.where(nan_mask[:, None] | nan_mask[None, :], 0, chol_cov)
+
+    # Calculate the covariance matrix
+    cov = chol_cov @ chol_cov.T
+
+    # Set the diagonal to NaN
+    cov = cov.at[jnp.diag_indices_from(cov)].set(
+        jnp.where(nan_mask, jnp.nan, jnp.diag(cov))
+    )
+
+    return cov
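
For context, a minimal usage sketch of `symmetric_inv_sqrt` as added above (the import path follows `cuthbertlib/linalg/__init__.py` in this diff; the example matrix and tolerance are illustrative, not from the package's tests):

```python
import jax.numpy as jnp

from cuthbertlib.linalg import symmetric_inv_sqrt

# A small symmetric positive definite matrix.
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])

# L is lower triangular with L @ L.T approximately equal to inv(A).
L = symmetric_inv_sqrt(A)
print(jnp.allclose(L @ L.T, jnp.linalg.inv(A), atol=1e-5))  # True
```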
@@ -0,0 +1,21 @@
+"""Implements a triangularization operator for a matrix via QR decomposition."""
+
+import jax
+
+from cuthbertlib.types import Array
+
+
+def tria(A: Array) -> Array:
+    r"""A triangularization operator using QR decomposition.
+
+    Args:
+        A: The matrix to triangularize.
+
+    Returns:
+        A lower triangular matrix $R$ such that $R R^\top = A A^\top$.
+
+    Reference:
+        [Arasaratnam and Haykin (2008)](https://ieeexplore.ieee.org/document/4524036): Square-Root Quadrature Kalman Filtering
+    """
+    _, R = jax.scipy.linalg.qr(A.T, mode="economic")
+    return R.T
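
A short sketch of the documented property, again assuming the import path from this diff and an illustrative input:

```python
import jax.numpy as jnp

from cuthbertlib.linalg import tria

# Rectangular input: tria returns a lower triangular R with R @ R.T == A @ A.T.
A = jnp.array([[1.0, 2.0, 0.5], [0.0, 3.0, 1.0]])
R = tria(A)
print(jnp.allclose(R @ R.T, A @ A.T, atol=1e-5))  # True
```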
@@ -0,0 +1,7 @@
+from cuthbertlib.linalg import symmetric_inv_sqrt
+from cuthbertlib.linearize.log_density import (
+    linearize_log_density,
+    linearize_log_density_given_chol_cov,
+)
+from cuthbertlib.linearize.moments import linearize_moments
+from cuthbertlib.linearize.taylor import linearize_taylor
@@ -0,0 +1,175 @@
+"""Implements linearization of conditional log densities."""
+
+from typing import overload
+
+import jax.numpy as jnp
+from jax import grad, hessian, jacobian
+
+from cuthbertlib.linalg import chol_cov_with_nans_to_cov, symmetric_inv_sqrt
+from cuthbertlib.types import (
+    Array,
+    ArrayLike,
+    ArrayTree,
+    LogConditionalDensity,
+    LogConditionalDensityAux,
+)
+
+
+@overload
+def linearize_log_density(
+    log_density: LogConditionalDensity,
+    x: ArrayLike,
+    y: ArrayLike,
+    has_aux: bool = False,
+    rtol: float | None = None,
+    ignore_nan_dims: bool = False,
+) -> tuple[Array, Array, Array]: ...
+@overload
+def linearize_log_density(
+    log_density: LogConditionalDensityAux,
+    x: ArrayLike,
+    y: ArrayLike,
+    has_aux: bool = True,
+    rtol: float | None = None,
+    ignore_nan_dims: bool = False,
+) -> tuple[Array, Array, Array, ArrayTree]: ...
+
+
+def linearize_log_density(
+    log_density: LogConditionalDensity | LogConditionalDensityAux,
+    x: ArrayLike,
+    y: ArrayLike,
+    has_aux: bool = False,
+    rtol: float | None = None,
+    ignore_nan_dims: bool = False,
+) -> tuple[Array, Array, Array] | tuple[Array, Array, Array, ArrayTree]:
+    r"""Linearizes a conditional log density around given points.
+
+    The linearization is exact in the case of a linear-Gaussian `log_density`, i.e.,
+    it returns $(H, d, L)$ if `log_density` is of the form
+
+    $$
+    \log p(y \mid x) = -\frac{1}{2}(y - H x - d)^\top (LL^\top)^{-1} (y - H x - d) + \textrm{const}.
+    $$
+
+    The Cholesky factor of the covariance is calculated using the negative Hessian
+    of `log_density` with respect to `y` as the precision matrix.
+    `symmetric_inv_sqrt` is used to calculate the inverse square root by
+    ignoring any singular values that are sufficiently close to zero
+    (this is a projection in the case the Hessian is not positive definite).
+
+    Alternatively, the Cholesky factor can be provided directly
+    in `linearize_log_density_given_chol_cov`.
+
+    Args:
+        log_density: A conditional log density of y given x. Returns a scalar.
+        x: The input points.
+        y: The output points.
+        has_aux: Whether `log_density` returns an auxiliary value.
+        rtol: The relative tolerance for the singular values of the precision matrix
+            when passed to `symmetric_inv_sqrt`.
+            Cutoff for small singular values; singular values smaller than
+            `rtol * largest_singular_value` are treated as zero.
+            The default is determined based on the floating point precision of the dtype.
+            See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
+        ignore_nan_dims: Whether to treat dimensions with NaN on the diagonal of the
+            precision matrix as missing and ignore all rows and columns associated with
+            them.
+
+    Returns:
+        Linearized matrix, shift, and Cholesky factor of the covariance matrix.
+        The auxiliary value is also returned if `has_aux` is `True`.
+    """
+    prec_and_maybe_aux = hessian(log_density, 1, has_aux=has_aux)(x, y)
+    prec = -prec_and_maybe_aux[0] if has_aux else -prec_and_maybe_aux
+    if ignore_nan_dims:
+        prec_diag = jnp.diag(prec)
+        nan_mask = jnp.isnan(y) | jnp.isnan(prec_diag)
+        prec = prec.at[jnp.diag_indices_from(prec)].set(
+            jnp.where(nan_mask, jnp.nan, prec_diag)
+        )
+
+    chol_cov = symmetric_inv_sqrt(prec, rtol=rtol, ignore_nan_dims=ignore_nan_dims)
+    mat, shift, *extra = linearize_log_density_given_chol_cov(
+        log_density, x, y, chol_cov, has_aux=has_aux, ignore_nan_dims=ignore_nan_dims
+    )
+    return mat, shift, chol_cov, *extra
+
+
+@overload
+def linearize_log_density_given_chol_cov(
+    log_density: LogConditionalDensity,
+    x: ArrayLike,
+    y: ArrayLike,
+    chol_cov: ArrayLike,
+    has_aux: bool = False,
+    ignore_nan_dims: bool = False,
+) -> tuple[Array, Array]: ...
+@overload
+def linearize_log_density_given_chol_cov(
+    log_density: LogConditionalDensityAux,
+    x: ArrayLike,
+    y: ArrayLike,
+    chol_cov: ArrayLike,
+    has_aux: bool = True,
+    ignore_nan_dims: bool = False,
+) -> tuple[Array, Array, ArrayTree]: ...
+
+
+def linearize_log_density_given_chol_cov(
+    log_density: LogConditionalDensity | LogConditionalDensityAux,
+    x: ArrayLike,
+    y: ArrayLike,
+    chol_cov: ArrayLike,
+    has_aux: bool = False,
+    ignore_nan_dims: bool = False,
+) -> tuple[Array, Array] | tuple[Array, Array, ArrayTree]:
+    r"""Linearizes a conditional log density around given points.
+
+    The linearization is exact in the case of a linear-Gaussian `log_density`, i.e.,
+    it returns $(H, d)$ if `log_density` is of the form
+
+    $$
+    \log p(y \mid x) = -\frac{1}{2}(y - H x - d)^\top (LL^\top)^{-1} (y - H x - d) + \textrm{const},
+    $$
+
+    where $L$ is the argument `chol_cov`.
+
+    Args:
+        log_density: A conditional log density of y given x. Returns a scalar.
+        x: The input points.
+        y: The output points.
+        chol_cov: The Cholesky factor of the covariance matrix of the Gaussian.
+        has_aux: Whether `log_density` returns an auxiliary value.
+        ignore_nan_dims: Whether to ignore dimensions with NaN on the diagonal of the
+            precision matrix or in y.
+
+    Returns:
+        Linearized matrix and shift. The auxiliary value is also returned if `has_aux` is `True`.
+    """
+    chol_cov = jnp.asarray(chol_cov)
+
+    cov = (
+        chol_cov_with_nans_to_cov(chol_cov)
+        if ignore_nan_dims
+        else chol_cov @ chol_cov.T
+    )
+
+    if has_aux:
+
+        def grad_log_density_wrapper_aux(x, y):
+            g, aux = grad(log_density, 1, has_aux=True)(x, y)
+            return g, (g, aux)
+
+        jac, (g, *extra) = jacobian(grad_log_density_wrapper_aux, 0, has_aux=True)(x, y)
+    else:
+
+        def grad_log_density_wrapper(x, y):
+            g = grad(log_density, 1)(x, y)
+            return g, (g,)
+
+        jac, (g, *extra) = jacobian(grad_log_density_wrapper, 0, has_aux=True)(x, y)
+
+    mat = cov @ jac
+    shift = y - mat @ x + cov @ g
+    return mat, shift, *extra
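
As a sanity check of the exactness claim in the docstring, a hedged sketch using `jax.scipy.stats.multivariate_normal` (the model matrices and tolerances are illustrative, not taken from the package):

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

from cuthbertlib.linearize import linearize_log_density

# A linear-Gaussian observation model y | x ~ N(H x + d, L L^T).
H = jnp.array([[1.0, 0.5], [0.0, 2.0]])
d = jnp.array([0.1, -0.3])
L_true = jnp.array([[0.5, 0.0], [0.1, 0.4]])
cov = L_true @ L_true.T


def log_density(x, y):
    return multivariate_normal.logpdf(y, H @ x + d, cov)


x0 = jnp.array([0.2, -0.1])
y0 = jnp.array([1.0, 0.5])
H_hat, d_hat, L_hat = linearize_log_density(log_density, x0, y0)

# For a linear-Gaussian model the linearization should be exact (up to the
# non-uniqueness of the square root), so these hold approximately.
print(jnp.allclose(H_hat, H, atol=1e-4))
print(jnp.allclose(d_hat, d, atol=1e-4))
print(jnp.allclose(L_hat @ L_hat.T, cov, atol=1e-4))
```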
@@ -0,0 +1,94 @@
+"""Implements moment-based linearization."""
+
+from typing import Callable, cast, overload
+
+import jax
+from jax.typing import ArrayLike
+
+from cuthbertlib.types import Array, ArrayTree
+
+MeanAndCholCovFunc = Callable[[ArrayLike], tuple[Array, Array]]
+MeanAndCholCovFuncAux = Callable[[ArrayLike], tuple[Array, Array, ArrayTree]]
+
+
+@overload
+def linearize_moments(
+    mean_and_chol_cov_function: MeanAndCholCovFunc,
+    x: ArrayLike,
+    has_aux: bool = False,
+) -> tuple[Array, Array, Array]: ...
+@overload
+def linearize_moments(
+    mean_and_chol_cov_function: MeanAndCholCovFuncAux,
+    x: ArrayLike,
+    has_aux: bool = True,
+) -> tuple[Array, Array, Array, ArrayTree]: ...
+
+
+def linearize_moments(
+    mean_and_chol_cov_function: MeanAndCholCovFunc | MeanAndCholCovFuncAux,
+    x: ArrayLike,
+    has_aux: bool = False,
+) -> tuple[Array, Array, Array] | tuple[Array, Array, Array, ArrayTree]:
+    r"""Linearizes conditional mean and chol_cov functions into a linear-Gaussian form.
+
+    Takes a function `mean_and_chol_cov_function(x)` that returns the
+    conditional mean and Cholesky factor of the covariance matrix of the distribution
+    $p(y \mid x)$ for a given input `x`.
+
+    Returns $(H, d, L)$ defining a linear-Gaussian approximation to the conditional
+    distribution $p(y \mid x) \approx N(y \mid H x + d, L L^\top)$.
+
+    `mean_and_chol_cov_function` has the following signature with `has_aux` = False:
+    ```
+    m, chol = mean_and_chol_cov_function(x)
+    ```
+    or with `has_aux` = True:
+    ```
+    m, chol, aux = mean_and_chol_cov_function(x)
+    ```
+
+    Args:
+        mean_and_chol_cov_function: A callable that returns the conditional mean and
+            Cholesky factor of the covariance matrix of the distribution for a given
+            input.
+        x: The point to linearize around.
+        has_aux: Whether `mean_and_chol_cov_function` returns an auxiliary value.
+
+    Returns:
+        Linearized matrix, shift, and Cholesky factor of the covariance matrix.
+        The auxiliary value is also returned if `has_aux` is `True`.
+
+    References:
+        - [sqrt-parallel-smoothers](https://github.com/EEA-sensors/sqrt-parallel-smoothers/blob/main/parsmooth/linearization/_extended.py)
+    """
+    if has_aux:
+        mean_and_chol_cov_function = cast(
+            MeanAndCholCovFuncAux, mean_and_chol_cov_function
+        )
+
+        def mean_and_chol_cov_function_wrapper_aux(
+            x: ArrayLike,
+        ) -> tuple[Array, tuple[Array, Array, ArrayTree]]:
+            mean, chol_cov, aux = mean_and_chol_cov_function(x)
+            return mean, (mean, chol_cov, aux)
+
+        F, (m, *extra) = jax.jacfwd(
+            mean_and_chol_cov_function_wrapper_aux, has_aux=True
+        )(x)
+
+    else:
+        mean_and_chol_cov_function = cast(
+            MeanAndCholCovFunc, mean_and_chol_cov_function
+        )
+
+        def mean_and_chol_cov_function_wrapper(
+            x: ArrayLike,
+        ) -> tuple[Array, tuple[Array, Array]]:
+            mean, chol_cov = mean_and_chol_cov_function(x)
+            return mean, (mean, chol_cov)
+
+        F, (m, *extra) = jax.jacfwd(mean_and_chol_cov_function_wrapper, has_aux=True)(x)
+
+    b = m - F @ x
+    return F, b, *extra
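
A brief sketch of `linearize_moments` with a toy conditional mean (the model below is an assumed example, not part of the package):

```python
import jax.numpy as jnp

from cuthbertlib.linearize import linearize_moments


def mean_and_chol_cov(x):
    # Nonlinear conditional mean with a fixed noise Cholesky factor.
    return jnp.sin(x), 0.1 * jnp.eye(2)


x0 = jnp.array([0.3, -0.7])
F, b, chol = linearize_moments(mean_and_chol_cov, x0)

# F is the Jacobian of the conditional mean at x0 and b = m(x0) - F @ x0,
# so F @ x0 + b recovers the conditional mean at the linearization point.
print(jnp.allclose(F, jnp.diag(jnp.cos(x0))))
print(jnp.allclose(F @ x0 + b, jnp.sin(x0)))
```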
@@ -0,0 +1,83 @@
+"""Implements Taylor-like linearization."""
+
+from typing import Callable, overload
+
+import jax
+from jax import numpy as jnp
+from jax.typing import ArrayLike
+
+from cuthbertlib.linalg import symmetric_inv_sqrt
+from cuthbertlib.types import Array, ArrayTree
+
+
+@overload
+def linearize_taylor(
+    log_potential: Callable[[ArrayLike], Array],
+    x: ArrayLike,
+    has_aux: bool = False,
+    rtol: float | None = None,
+    ignore_nan_dims: bool = False,
+) -> tuple[Array, Array]: ...
+@overload
+def linearize_taylor(
+    log_potential: Callable[[ArrayLike], tuple[Array, ArrayTree]],
+    x: ArrayLike,
+    has_aux: bool = True,
+    rtol: float | None = None,
+    ignore_nan_dims: bool = False,
+) -> tuple[Array, Array, ArrayTree]: ...
+
+
+def linearize_taylor(
+    log_potential: Callable[[ArrayLike], Array]
+    | Callable[[ArrayLike], tuple[Array, ArrayTree]],
+    x: ArrayLike,
+    has_aux: bool = False,
+    rtol: float | None = None,
+    ignore_nan_dims: bool = False,
+) -> tuple[Array, Array] | tuple[Array, Array, ArrayTree]:
+    r"""Linearizes a log potential function around a given point using Taylor expansion.
+
+    Unlike the other linearization methods, this applies to a potential function
+    with no required notion of observation $y$ or conditional dependence.
+
+    Instead we have the linearization
+
+    $$
+    \log G(x) = -\frac{1}{2} (x - m)^\top (L L^\top)^{-1} (x - m).
+    $$
+
+    Args:
+        log_potential: A callable that returns the scalar log of a non-negative
+            potential. Does not need to correspond to a normalized probability
+            density in its input.
+        x: The point to linearize around.
+        has_aux: Whether `log_potential` returns an auxiliary value.
+        rtol: The relative tolerance for the singular values of the precision matrix
+            when passed to `symmetric_inv_sqrt`.
+            Cutoff for small singular values; singular values smaller than
+            `rtol * largest_singular_value` are treated as zero.
+            The default is determined based on the floating point precision of the dtype.
+            See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
+        ignore_nan_dims: Whether to treat dimensions with NaN on the diagonal of the
+            precision matrix as missing and ignore all rows and columns associated with
+            them.
+
+    Returns:
+        Linearized mean and Cholesky factor of the covariance matrix.
+        The auxiliary value is also returned if `has_aux` is `True`.
+    """
+    g_and_maybe_aux = jax.grad(log_potential, has_aux=has_aux)(x)
+    prec_and_maybe_aux = jax.hessian(log_potential, has_aux=has_aux)(x)
+
+    g, aux = g_and_maybe_aux if has_aux else (g_and_maybe_aux, None)
+    prec = -prec_and_maybe_aux[0] if has_aux else -prec_and_maybe_aux
+
+    L = symmetric_inv_sqrt(prec, rtol=rtol, ignore_nan_dims=ignore_nan_dims)
+
+    # Change NaNs on the diagonal to zeros for L @ L.T @ g, but retain NaNs on the
+    # diagonal of L for bookkeeping.
+    # If ignore_nan_dims, change all rows and columns with NaNs on the diagonal to 0.
+    L_diag = jnp.diag(L)
+    nan_mask = jnp.isnan(L_diag) * ignore_nan_dims
+    L_temp = jnp.where(nan_mask[:, None] | nan_mask[None, :], 0.0, L)
+    m = x + L_temp @ L_temp.T @ g
+    return (m, L, aux) if has_aux else (m, L)
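
A sketch checking the Gaussian case of `linearize_taylor` (assumed toy numbers; for an exactly Gaussian potential the Taylor linearization should recover its mean and covariance):

```python
import jax.numpy as jnp

from cuthbertlib.linearize import linearize_taylor

mu = jnp.array([1.0, -2.0])
P = jnp.array([[2.0, 0.3], [0.3, 1.0]])
P_inv = jnp.linalg.inv(P)


def log_potential(x):
    # A Gaussian log potential centred at mu with covariance P.
    return -0.5 * (x - mu) @ P_inv @ (x - mu)


x0 = jnp.array([0.0, 0.0])
m, L = linearize_taylor(log_potential, x0)

print(jnp.allclose(m, mu, atol=1e-4))
print(jnp.allclose(L @ L.T, P, atol=1e-4))
```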
@@ -0,0 +1,4 @@
+from cuthbertlib.quadrature import cubature, gauss_hermite, unscented
+from cuthbertlib.quadrature.common import Quadrature, SigmaPoints
+from cuthbertlib.quadrature.linearize import conditional_moments, functional
+from cuthbertlib.quadrature.utils import cholesky_update_many
@@ -0,0 +1,102 @@
+"""Common types and protocols for quadrature."""
+
+from typing import NamedTuple, Protocol, Self, runtime_checkable
+
+import jax.numpy as jnp
+
+from cuthbertlib.linalg import tria
+from cuthbertlib.types import Array, ArrayLike
+
+__all__ = ["SigmaPoints", "Quadrature"]
+
+
+class SigmaPoints(NamedTuple):
+    """Represents integration (quadrature) sigma points as a collection of points.
+
+    Weights correspond to mean and covariance calculations.
+
+    Attributes:
+        points: The sigma points.
+        wm: The mean weights.
+        wc: The covariance weights.
+
+    Methods:
+        mean: Computes the mean of the sigma points.
+        covariance: Computes the covariance between the sigma points and the other
+            sigma points (or itself).
+        sqrt: Computes a square root of the covariance matrix of the sigma points.
+
+    References:
+        Simo Särkkä, Lennart Svensson. *Bayesian Filtering and Smoothing.*
+        Cambridge University Press, 2023.
+    """
+
+    points: Array
+    wm: Array
+    wc: Array
+
+    @property
+    def mean(self) -> Array:
+        """Computes the mean of the sigma points.
+
+        Returns:
+            The mean of the sigma points.
+        """
+        return jnp.dot(self.wm, self.points)
+
+    # Should this be a property too?
+    def covariance(self, other: Self | None = None) -> Array:
+        """Computes the covariance between the sigma points and the other sigma points.
+
+        Args:
+            other: The optional other sigma points.
+
+        Returns:
+            The covariance matrix.
+        """
+        mean = self.mean
+        if other is None:
+            return _cov(self.wc, self.points, mean, self.points, mean)
+
+        other_mean = other.mean
+        return _cov(self.wc, self.points, mean, other.points, other_mean)
+
+    @property
+    def sqrt(self) -> Array:
+        """Computes a square root of the covariance matrix of the sigma points.
+
+        Returns:
+            The square root of the covariance matrix.
+        """
+        sqrt = jnp.sqrt(self.wc[:, None]) * (self.points - self.mean[None, :])
+        sqrt = tria(sqrt.T)
+        return sqrt
+
+
+@runtime_checkable
+class Quadrature(Protocol):
+    """Protocol for quadrature methods."""
+
+    def get_sigma_points(self, m: ArrayLike, chol: ArrayLike) -> SigmaPoints:
+        """Get the sigma points.
+
+        Args:
+            m: The mean.
+            chol: The Cholesky factor of the covariance.
+
+        Returns:
+            SigmaPoints: The sigma points.
+        """
+        ...
+
+
+def _cov(
+    wc: Array,
+    x_pts: Array,
+    x_mean: Array,
+    y_points: Array,
+    y_mean: Array,
+) -> Array:
+    one = (x_pts - x_mean[None, :]).T * wc[None, :]
+    two = y_points - y_mean[None, :]
+    return jnp.dot(one, two)
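
A small sketch of the `SigmaPoints` container on its own, using hand-picked points and weights (illustrative values only):

```python
import jax.numpy as jnp

from cuthbertlib.quadrature import SigmaPoints

# Three 2-D points with equal mean/covariance weights.
points = jnp.array([[0.0, 0.0], [1.0, 2.0], [-1.0, -2.0]])
w = jnp.full((3,), 1.0 / 3.0)
sp = SigmaPoints(points=points, wm=w, wc=w)

print(sp.mean)              # weighted mean of the points
print(sp.covariance())      # weighted covariance around that mean
print(sp.sqrt @ sp.sqrt.T)  # approximately equals sp.covariance()
```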
@@ -0,0 +1,73 @@
+"""Implements cubature quadrature."""
+
+from typing import NamedTuple
+
+import jax.numpy as jnp
+import numpy as np
+from jax.typing import ArrayLike
+
+from cuthbertlib.quadrature.common import Quadrature, SigmaPoints
+
+__all__ = ["weights", "CubatureQuadrature"]
+
+
+class CubatureQuadrature(NamedTuple):
+    """Cubature quadrature.
+
+    Attributes:
+        wm: The mean weights.
+        wc: The covariance weights.
+        xi: The sigma points.
+    """
+
+    wm: ArrayLike
+    wc: ArrayLike
+    xi: ArrayLike
+
+    def get_sigma_points(self, m: ArrayLike, chol: ArrayLike) -> SigmaPoints:
+        """Get the sigma points.
+
+        Args:
+            m: The mean.
+            chol: The Cholesky factor of the covariance.
+
+        Returns:
+            SigmaPoints: The sigma points.
+        """
+        return get_sigma_points(m, chol, self.xi, self.wm, self.wc)
+
+
+def get_sigma_points(
+    m: ArrayLike, chol: ArrayLike, xi: ArrayLike, wm: ArrayLike, wc: ArrayLike
+) -> SigmaPoints:
+    """Shifts and scales the unit sigma points `xi` by `m` and `chol`, keeping the weights."""
+    m = jnp.asarray(m)
+    chol = jnp.asarray(chol)
+    xi = jnp.asarray(xi)
+    wm = jnp.asarray(wm)
+    wc = jnp.asarray(wc)
+    sigma_points = m[None, :] + jnp.dot(chol, xi.T).T
+
+    return SigmaPoints(sigma_points, wm, wc)
+
+
+def weights(n_dim: int) -> Quadrature:
+    """Computes the weights associated with the spherical cubature method.
+
+    The number of sigma points is 2 * n_dim.
+
+    Args:
+        n_dim: Dimensionality of the problem.
+
+    Returns:
+        The quadrature object with the weights and sigma points.
+
+    References:
+        Simo Särkkä, Lennart Svensson. *Bayesian Filtering and Smoothing.*
+        Cambridge University Press, 2023.
+    """
+    wm = np.ones(shape=(2 * n_dim,)) / (2 * n_dim)
+    wc = wm
+    xi = np.concatenate([np.eye(n_dim), -np.eye(n_dim)], axis=0) * np.sqrt(n_dim)
+
+    return CubatureQuadrature(wm=wm, wc=wc, xi=xi)
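
Tying these pieces together, a sketch of using the spherical cubature rule to approximate a Gaussian expectation (the mean, Cholesky factor, and nonlinearity below are toy values, not from the package's tests):

```python
import jax
import jax.numpy as jnp

from cuthbertlib.quadrature import cubature

m = jnp.array([0.5, -1.0])
chol = jnp.array([[1.0, 0.0], [0.2, 0.8]])

# Build the 2 * n_dim spherical cubature sigma points for N(m, chol @ chol.T).
quad = cubature.weights(n_dim=2)
sp = quad.get_sigma_points(m, chol)

# Push the points through a nonlinearity and reuse the mean weights to
# approximate E[f(x)] under the Gaussian.
f = lambda x: jnp.tanh(x)
fx = jax.vmap(f)(sp.points)
approx_mean = jnp.dot(sp.wm, fx)
print(approx_mean)
```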