cuthbert 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. cuthbert/discrete/__init__.py +2 -0
  2. cuthbert/discrete/filter.py +140 -0
  3. cuthbert/discrete/smoother.py +123 -0
  4. cuthbert/discrete/types.py +53 -0
  5. cuthbert/gaussian/__init__.py +0 -0
  6. cuthbert/gaussian/kalman.py +337 -0
  7. cuthbert/gaussian/moments/__init__.py +11 -0
  8. cuthbert/gaussian/moments/associative_filter.py +180 -0
  9. cuthbert/gaussian/moments/filter.py +95 -0
  10. cuthbert/gaussian/moments/non_associative_filter.py +161 -0
  11. cuthbert/gaussian/moments/smoother.py +118 -0
  12. cuthbert/gaussian/moments/types.py +51 -0
  13. cuthbert/gaussian/taylor/__init__.py +14 -0
  14. cuthbert/gaussian/taylor/associative_filter.py +222 -0
  15. cuthbert/gaussian/taylor/filter.py +129 -0
  16. cuthbert/gaussian/taylor/non_associative_filter.py +246 -0
  17. cuthbert/gaussian/taylor/smoother.py +158 -0
  18. cuthbert/gaussian/taylor/types.py +86 -0
  19. cuthbert/gaussian/types.py +57 -0
  20. cuthbert/gaussian/utils.py +41 -0
  21. cuthbert/smc/__init__.py +0 -0
  22. cuthbert/smc/backward_sampler.py +193 -0
  23. cuthbert/smc/marginal_particle_filter.py +237 -0
  24. cuthbert/smc/particle_filter.py +234 -0
  25. cuthbert/smc/types.py +67 -0
  26. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/METADATA +1 -1
  27. cuthbert-0.0.3.dist-info/RECORD +76 -0
  28. cuthbertlib/discrete/__init__.py +0 -0
  29. cuthbertlib/discrete/filtering.py +49 -0
  30. cuthbertlib/discrete/smoothing.py +35 -0
  31. cuthbertlib/kalman/__init__.py +4 -0
  32. cuthbertlib/kalman/filtering.py +213 -0
  33. cuthbertlib/kalman/generate.py +85 -0
  34. cuthbertlib/kalman/sampling.py +68 -0
  35. cuthbertlib/kalman/smoothing.py +121 -0
  36. cuthbertlib/linalg/__init__.py +7 -0
  37. cuthbertlib/linalg/collect_nans_chol.py +90 -0
  38. cuthbertlib/linalg/marginal_sqrt_cov.py +34 -0
  39. cuthbertlib/linalg/symmetric_inv_sqrt.py +126 -0
  40. cuthbertlib/linalg/tria.py +21 -0
  41. cuthbertlib/linearize/__init__.py +7 -0
  42. cuthbertlib/linearize/log_density.py +175 -0
  43. cuthbertlib/linearize/moments.py +94 -0
  44. cuthbertlib/linearize/taylor.py +83 -0
  45. cuthbertlib/quadrature/__init__.py +4 -0
  46. cuthbertlib/quadrature/common.py +102 -0
  47. cuthbertlib/quadrature/cubature.py +73 -0
  48. cuthbertlib/quadrature/gauss_hermite.py +62 -0
  49. cuthbertlib/quadrature/linearize.py +143 -0
  50. cuthbertlib/quadrature/unscented.py +79 -0
  51. cuthbertlib/quadrature/utils.py +109 -0
  52. cuthbertlib/resampling/__init__.py +3 -0
  53. cuthbertlib/resampling/killing.py +79 -0
  54. cuthbertlib/resampling/multinomial.py +53 -0
  55. cuthbertlib/resampling/protocols.py +92 -0
  56. cuthbertlib/resampling/systematic.py +78 -0
  57. cuthbertlib/resampling/utils.py +82 -0
  58. cuthbertlib/smc/__init__.py +0 -0
  59. cuthbertlib/smc/ess.py +24 -0
  60. cuthbertlib/smc/smoothing/__init__.py +0 -0
  61. cuthbertlib/smc/smoothing/exact_sampling.py +111 -0
  62. cuthbertlib/smc/smoothing/mcmc.py +76 -0
  63. cuthbertlib/smc/smoothing/protocols.py +44 -0
  64. cuthbertlib/smc/smoothing/tracing.py +45 -0
  65. cuthbertlib/stats/__init__.py +0 -0
  66. cuthbertlib/stats/multivariate_normal.py +102 -0
  67. cuthbert-0.0.2.dist-info/RECORD +0 -12
  68. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/WHEEL +0 -0
  69. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/licenses/LICENSE +0 -0
  70. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,35 @@
+ """Implements the discrete HMM smoothing associative operator."""
+
+ import jax.numpy as jnp
+
+ from cuthbertlib.types import Array, ArrayLike
+
+
+ def get_reverse_kernel(x_t_dist: ArrayLike, trans_matrix: ArrayLike) -> Array:
+     r"""Computes reverse transition probabilities $p(x_{t-1} \mid x_{t}, \dots)$ for a discrete HMM.
+
+     Args:
+         x_t_dist: Array of shape (N,) where `x_t_dist[i]` = $p(x_{t} = i \mid \dots)$.
+         trans_matrix: Array of shape (N, N) where
+             `trans_matrix[i, j]` = $p(x_{t} = j \mid x_{t-1} = i)$.
+
+     Returns:
+         An (N, N) matrix `x_tm1_dist[i, j]` = $p(x_{t-1} = j \mid x_{t} = i, \dots)$.
+     """
+     x_t_dist, trans_matrix = jnp.asarray(x_t_dist), jnp.asarray(trans_matrix)
+     pred = jnp.dot(trans_matrix.T, x_t_dist)
+     x_tm1_dist = trans_matrix.T * x_t_dist[None, :] / pred[:, None]
+     return x_tm1_dist
+
+
+ def smoothing_operator(elem_ij: Array, elem_jk: Array) -> Array:
+     """Binary associative operator for smoothing in discrete HMMs.
+
+     Args:
+         elem_ij: Smoothing scan element.
+         elem_jk: Smoothing scan element.
+
+     Returns:
+         The output of the associative operator applied to the input elements.
+     """
+     return elem_jk @ elem_ij
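For orientation, here is a minimal sequential usage sketch. The import path is inferred from the file listing above, the toy numbers are made up, and passing the earlier-time filtering marginal as the first argument of `get_reverse_kernel` is an assumption of this sketch; the parallel variant would instead compose the resulting kernels with `smoothing_operator` inside an associative scan.

```
import jax
import jax.numpy as jnp

# Hypothetical import path, inferred from the file listing above.
from cuthbertlib.discrete.smoothing import get_reverse_kernel

# Toy 2-state HMM: made-up filtering marginals p(x_t | y_{0:t}) for t = 0, 1, 2.
trans_matrix = jnp.array([[0.9, 0.1],
                          [0.2, 0.8]])
filt = jnp.array([[0.5, 0.5],
                  [0.6, 0.4],
                  [0.3, 0.7]])

# One reverse kernel per transition, built here from the filtering marginal at the
# earlier time step (assumed calling convention for this sketch).
kernels = jax.vmap(get_reverse_kernel, in_axes=(0, None))(filt[:-1], trans_matrix)

# Sequential backward pass: p(x_{t-1} | y_{0:T}) = p(x_t | y_{0:T}) @ kernel_t.
smooth = [filt[-1]]
for kernel in reversed(list(kernels)):
    smooth.append(smooth[-1] @ kernel)
smooth = jnp.stack(smooth[::-1])  # smoothing marginals for t = 0, 1, 2
```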
@@ -0,0 +1,4 @@
+ from cuthbertlib.kalman import generate
+ from cuthbertlib.kalman.filtering import predict
+ from cuthbertlib.kalman.filtering import update as filter_update
+ from cuthbertlib.kalman.smoothing import update as smoother_update
@@ -0,0 +1,213 @@
+ """Implements the square root Kalman filter and its associative (parallel-scan) variant."""
+
+ from typing import NamedTuple
+
+ import jax.numpy as jnp
+ from jax.scipy.linalg import cho_solve, solve_triangular
+
+ from cuthbertlib.linalg import tria
+ from cuthbertlib.stats import multivariate_normal
+ from cuthbertlib.stats.multivariate_normal import collect_nans_chol
+ from cuthbertlib.types import Array, ArrayLike, ScalarArray
+
+
+ class FilterScanElement(NamedTuple):
+     """Arrays carried through the Kalman filter scan."""
+
+     A: Array
+     b: Array
+     U: Array
+     eta: Array
+     Z: Array
+     ell: ScalarArray
+
+
+ def predict(
+     m: ArrayLike,
+     chol_P: ArrayLike,
+     F: ArrayLike,
+     c: ArrayLike,
+     chol_Q: ArrayLike,
+ ) -> tuple[Array, Array]:
+     """Propagate the mean and square root covariance through linear Gaussian dynamics.
+
+     Args:
+         m: Mean of the state.
+         chol_P: Generalized Cholesky factor of the covariance of the state.
+         F: Transition matrix.
+         c: Transition shift.
+         chol_Q: Generalized Cholesky factor of the transition noise covariance.
+
+     Returns:
+         Propagated mean and square root covariance.
+
+     References:
+         Paper: G. J. Bierman, Factorization Methods for Discrete Sequential Estimation.
+         Code: https://github.com/EEA-sensors/sqrt-parallel-smoothers/tree/main/parsmooth/sequential
+     """
+     m, chol_P = jnp.asarray(m), jnp.asarray(chol_P)
+     F, c, chol_Q = jnp.asarray(F), jnp.asarray(c), jnp.asarray(chol_Q)
+     m1 = F @ m + c
+     A = jnp.concatenate([F @ chol_P, chol_Q], axis=1)
+     chol_P1 = tria(A)
+     return m1, chol_P1
+
+
+ # Note that `update` is aliased as `kalman.filter_update` in `kalman.__init__.py`
+ def update(
+     m: ArrayLike,
+     chol_P: ArrayLike,
+     H: ArrayLike,
+     d: ArrayLike,
+     chol_R: ArrayLike,
+     y: ArrayLike,
+     log_normalizing_constant: ArrayLike = 0.0,
+ ) -> tuple[tuple[Array, Array], Array]:
+     """Update the mean and square root covariance with a linear Gaussian observation.
+
+     Args:
+         m: Mean of the state.
+         chol_P: Generalized Cholesky factor of the covariance of the state.
+         H: Observation matrix.
+         d: Observation shift.
+         chol_R: Generalized Cholesky factor of the observation noise covariance.
+         y: Observation.
+         log_normalizing_constant: Optional log normalizing constant added to the
+             log normalizing constant of the Bayesian update.
+
+     Returns:
+         Updated mean and square root covariance as well as the log marginal likelihood.
+
+     References:
+         Paper: G. J. Bierman, Factorization Methods for Discrete Sequential Estimation.
+         Code: https://github.com/EEA-sensors/sqrt-parallel-smoothers/tree/main/parsmooth/sequential
+     """
+     # Handle case where there is no observation
+     flag = jnp.isnan(y)
+     flag, chol_R, H, d, y = collect_nans_chol(flag, chol_R, H, d, y)
+
+     m, chol_P = jnp.asarray(m), jnp.asarray(chol_P)
+     H, d, chol_R = jnp.asarray(H), jnp.asarray(d), jnp.asarray(chol_R)
+     y = jnp.asarray(y)
+
+     n_y, n_x = H.shape
+
+     y_hat = H @ m + d
+     y_diff = y - y_hat
+
+     M = jnp.block(
+         [
+             [H @ chol_P, chol_R],
+             [chol_P, jnp.zeros((n_x, n_y), dtype=chol_P.dtype)],
+         ]
+     )
+     chol_S = tria(M)
+     chol_Py = chol_S[n_y:, n_y:]
+
+     Gmat = chol_S[n_y:, :n_y]
+     Imat = chol_S[:n_y, :n_y]
+
+     my = m + Gmat @ solve_triangular(Imat, y_diff, lower=True)
+
+     ell = multivariate_normal.logpdf(y, y_hat, Imat, nan_support=False)
+     return (my, chol_Py), jnp.asarray(ell + log_normalizing_constant)
+
+
+ def associative_params_single(
+     F: Array, c: Array, chol_Q: Array, H: Array, d: Array, chol_R: Array, y: Array
+ ) -> FilterScanElement:
+     """Computes the scan element for a single time step of the square root parallel Kalman filter.
+
+     Args:
+         F: State transition matrix.
+         c: State transition shift vector.
+         chol_Q: Generalized Cholesky factor of the state transition noise covariance.
+         H: Observation matrix.
+         d: Observation shift.
+         chol_R: Generalized Cholesky factor of the observation noise covariance.
+         y: Observation.
+
+     Returns:
+         Prepared scan element for the square root parallel Kalman filter.
+     """
+     # Handle case where there is no observation
+     flag = jnp.isnan(y)
+     flag, chol_R, H, d, y = collect_nans_chol(flag, chol_R, H, d, y)
+
+     ny, nx = H.shape
+
+     # joint over the predictive and the observation
+     Psi_ = jnp.block([[H @ chol_Q, chol_R], [chol_Q, jnp.zeros((nx, ny))]])
+
+     Tria_Psi_ = tria(Psi_)
+
+     Psi11 = Tria_Psi_[:ny, :ny]
+     Psi21 = Tria_Psi_[ny : ny + nx, :ny]
+     U = Tria_Psi_[ny : ny + nx, ny:]
+
+     # pre-compute inverse of Psi11: we apply it to matrices and vectors alike.
+     Psi11_inv = solve_triangular(Psi11, jnp.eye(ny), lower=True)
+
+     # predictive model given one observation
+     K = Psi21 @ Psi11_inv  # local Kalman gain
+     HF = H @ F  # temporary variable
+     A = F - K @ HF  # corrected transition matrix
+
+     b = c + K @ (y - H @ c - d)  # corrected transition offset
+
+     # information filter
+     Z = HF.T @ Psi11_inv.T
+     eta = Psi11_inv @ (y - H @ c - d)
+     eta = Z @ eta
+
+     if nx > ny:
+         Z = jnp.concatenate([Z, jnp.zeros((nx, nx - ny))], axis=1)
+     else:
+         Z = tria(Z)
+
+     # local log marginal likelihood
+     ell = jnp.asarray(
+         multivariate_normal.logpdf(y, H @ c + d, Psi11, nan_support=False)
+     )
+
+     return FilterScanElement(A, b, U, eta, Z, ell)
+
+
+ def filtering_operator(
+     elem_i: FilterScanElement, elem_j: FilterScanElement
+ ) -> FilterScanElement:
+     """Binary associative operator for the square root Kalman filter.
+
+     Args:
+         elem_i: Filter scan element for the previous time step.
+         elem_j: Filter scan element for the current time step.
+
+     Returns:
+         FilterScanElement: The output of the associative operator applied to the input elements.
+     """
+     A1, b1, U1, eta1, Z1, ell1 = elem_i
+     A2, b2, U2, eta2, Z2, ell2 = elem_j
+
+     nx = Z2.shape[0]
+
+     Xi = jnp.block([[U1.T @ Z2, jnp.eye(nx)], [Z2, jnp.zeros_like(A1)]])
+     tria_xi = tria(Xi)
+     Xi11 = tria_xi[:nx, :nx]
+     Xi21 = tria_xi[nx : nx + nx, :nx]
+     Xi22 = tria_xi[nx : nx + nx, nx:]
+
+     tmp_1 = solve_triangular(Xi11, U1.T, lower=True).T
+     D_inv = jnp.eye(nx) - tmp_1 @ Xi21.T
+     tmp_2 = D_inv @ (b1 + U1 @ (U1.T @ eta2))
+
+     A = A2 @ D_inv @ A1
+     b = A2 @ tmp_2 + b2
+     U = tria(jnp.concatenate([A2 @ tmp_1, U2], axis=1))
+     eta = A1.T @ (D_inv.T @ (eta2 - Z2 @ (Z2.T @ b1))) + eta1
+     Z = tria(jnp.concatenate([A1.T @ Xi22, Z1], axis=1))
+
+     mu = cho_solve((U1, True), b1)
+     t1 = b1 @ mu - (eta2 + mu) @ tmp_2
+     ell = ell1 + ell2 - 0.5 * t1 + 0.5 * jnp.linalg.slogdet(D_inv)[1]
+
+     return FilterScanElement(A, b, U, eta, Z, ell)
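As a quick orientation for the sequential interface, a hand-written one-dimensional sketch with made-up numbers (not taken from the package's tests); note that `update` is re-exported as `kalman.filter_update`.

```
import jax.numpy as jnp

from cuthbertlib.kalman.filtering import predict, update

# One-dimensional random walk observed with noise (illustrative numbers).
m, chol_P = jnp.array([0.0]), jnp.array([[1.0]])
F, c, chol_Q = jnp.array([[1.0]]), jnp.array([0.0]), jnp.array([[0.5]])
H, d, chol_R = jnp.array([[1.0]]), jnp.array([0.0]), jnp.array([[0.2]])
y = jnp.array([0.7])

m_pred, chol_P_pred = predict(m, chol_P, F, c, chol_Q)
(m_filt, chol_P_filt), ell = update(m_pred, chol_P_pred, H, d, chol_R, y)
# `ell` is the log marginal likelihood of the observation; a NaN entry in `y`
# is handled as a missing observation via `collect_nans_chol`.
```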
@@ -0,0 +1,85 @@
+ """Utilities to generate linear-Gaussian state-space models (LGSSMs)."""
+
+ from functools import partial
+
+ import jax
+ import jax.numpy as jnp
+ from jax import Array, random
+
+ from cuthbertlib.types import KeyArray
+
+
+ @partial(jax.jit, static_argnames=("x_dim", "y_dim", "num_time_steps"))
+ def generate_lgssm(seed: int, x_dim: int, y_dim: int, num_time_steps: int):
+     """Generates an LGSSM along with a set of observations."""
+     key = random.key(seed)
+
+     key, init_key, sample_key, obs_model_key, obs_key = random.split(key, 5)
+     m0, chol_P0 = generate_init_model(init_key, x_dim)
+     x0 = m0 + chol_P0 @ random.normal(sample_key, (x_dim,))
+
+     # Generate an observation for time 0
+     H0, d0, chol_R0 = generate_obs_model(obs_model_key, x_dim, y_dim)
+     obs_noise = chol_R0 @ random.normal(obs_key, (y_dim,))
+     y0 = H0 @ x0 + d0 + obs_noise
+
+     def body(_x, _key):
+         trans_model_key, trans_key, obs_model_key, obs_key = random.split(_key, 4)
+
+         F, c, chol_Q = generate_trans_model(trans_model_key, x_dim)
+         state_noise = chol_Q @ random.normal(trans_key, (x_dim,))
+         x = F @ _x + c + state_noise
+
+         H, d, chol_R = generate_obs_model(obs_model_key, x_dim, y_dim)
+         obs_noise = chol_R @ random.normal(obs_key, (y_dim,))
+         y = H @ x + d + obs_noise
+
+         return x, (F, c, chol_Q, H, d, chol_R, y)
+
+     scan_keys = random.split(key, num_time_steps)
+     _, (Fs, cs, chol_Qs, Hs, ds, chol_Rs, ys) = jax.lax.scan(body, x0, scan_keys)
+
+     # Prepend the observation model parameters and observation for t=0
+     Hs, ds, chol_Rs, ys = jax.tree.map(
+         lambda x, xs: jnp.concatenate([x[None], xs]),
+         (H0, d0, chol_R0, y0),
+         (Hs, ds, chol_Rs, ys),
+     )
+
+     return m0, chol_P0, Fs, cs, chol_Qs, Hs, ds, chol_Rs, ys
+
+
+ def generate_cholesky_factor(key: KeyArray, dim: int) -> Array:
+     """Generates a random Cholesky factor (lower-triangular matrix)."""
+     chol_A = random.uniform(key, (dim, dim))
+     chol_A = chol_A.at[jnp.triu_indices(dim, 1)].set(0.0)
+     return chol_A
+
+
+ def generate_init_model(key: KeyArray, x_dim: int) -> tuple[Array, Array]:
+     """Generates a random initial state for an LGSSM."""
+     keys = random.split(key)
+     m0 = random.normal(keys[0], (x_dim,))
+     chol_P0 = generate_cholesky_factor(keys[1], x_dim)
+     return m0, chol_P0
+
+
+ def generate_trans_model(key: KeyArray, x_dim: int) -> tuple[Array, Array, Array]:
+     """Generates a random transition model for an LGSSM."""
+     keys = random.split(key, 3)
+     exp_eig_max = 0.75  # Chosen less than one to stop exploding states (in expectation)
+     F = exp_eig_max * random.normal(keys[0], (x_dim, x_dim)) / jnp.sqrt(x_dim)
+     c = 0.1 * random.normal(keys[1], (x_dim,))
+     chol_Q = generate_cholesky_factor(keys[2], x_dim)
+     return F, c, chol_Q
+
+
+ def generate_obs_model(
+     key: KeyArray, x_dim: int, y_dim: int
+ ) -> tuple[Array, Array, Array]:
+     """Generates a random observation model for an LGSSM."""
+     keys = random.split(key, 3)
+     H = random.normal(keys[0], (y_dim, x_dim))
+     d = random.normal(keys[1], (y_dim,))
+     chol_R = generate_cholesky_factor(keys[2], y_dim)
+     return H, d, chol_R
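Putting the previous two files together, a hedged sketch of a full sequential filtering pass over a generated model; the `jax.lax.scan` wiring below is this sketch's own, not the package's driver code.

```
import jax
import jax.numpy as jnp

from cuthbertlib.kalman.filtering import predict, update
from cuthbertlib.kalman.generate import generate_lgssm

m0, chol_P0, Fs, cs, chol_Qs, Hs, ds, chol_Rs, ys = generate_lgssm(
    seed=0, x_dim=2, y_dim=1, num_time_steps=10
)

def step(carry, inputs):
    m, chol_P, ell = carry
    F, c, chol_Q, H, d, chol_R, y = inputs
    m, chol_P = predict(m, chol_P, F, c, chol_Q)
    (m, chol_P), ell_t = update(m, chol_P, H, d, chol_R, y)
    return (m, chol_P, ell + ell_t), (m, chol_P)

# Time 0 has an observation but no transition: update the prior first, then scan.
(m, chol_P), ell0 = update(m0, chol_P0, Hs[0], ds[0], chol_Rs[0], ys[0])
(_, _, ell), (ms_rest, chol_Ps_rest) = jax.lax.scan(
    step, (m, chol_P, ell0), (Fs, cs, chol_Qs, Hs[1:], ds[1:], chol_Rs[1:], ys[1:])
)
ms = jnp.concatenate([m[None], ms_rest])                 # filtering means, t = 0..10
chol_Ps = jnp.concatenate([chol_P[None], chol_Ps_rest])  # filtering sqrt covariances
```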
@@ -0,0 +1,68 @@
+ """Implements the square root parallel Kalman associative operator for sampling.
+
+ Samples from the smoothing distribution without doing the smoothing scan for means
+ and (chol) covariances.
+ """
+
+ from typing import NamedTuple, Sequence
+
+ import jax
+ import jax.numpy as jnp
+
+ from cuthbertlib.kalman.smoothing import associative_params_single
+ from cuthbertlib.types import Array, ArrayLike
+
+
+ class SamplerScanElement(NamedTuple):
+     """Kalman sampling scan element."""
+
+     gain: Array
+     sample: Array
+
+
+ def sqrt_associative_params(
+     key: ArrayLike,
+     ms: Array,
+     chol_Ps: Array,
+     Fs: Array,
+     cs: Array,
+     chol_Qs: Array,
+     shape: Sequence[int],
+ ) -> SamplerScanElement:
+     """Compute the sampler scan elements."""
+     shape = tuple(shape)
+     eps = jax.random.normal(key, ms.shape[:1] + shape + ms.shape[1:])
+     interm_elems = jax.vmap(_sqrt_associative_params_interm)(
+         ms[:-1], chol_Ps[:-1], Fs, cs, chol_Qs, eps[:-1]
+     )
+     last_elem = _sqrt_associative_params_final(ms[-1], chol_Ps[-1], eps[-1])
+     return jax.tree.map(
+         lambda x, y: jnp.concatenate([x, y[None]]), interm_elems, last_elem
+     )
+
+
+ def _sqrt_associative_params_interm(
+     m: Array, chol_P: Array, F: Array, c: Array, chol_Q: Array, eps: Array
+ ) -> SamplerScanElement:
+     inc_m, gain, L = associative_params_single(m, chol_P, F, c, chol_Q)
+     inc = inc_m + eps @ L.T
+     return SamplerScanElement(gain, inc)
+
+
+ def _sqrt_associative_params_final(
+     m: Array, chol_P: Array, eps: Array
+ ) -> SamplerScanElement:
+     gain = jnp.zeros_like(chol_P)
+     sample = m + eps @ chol_P.T
+     return SamplerScanElement(gain, sample)
+
+
+ def sampling_operator(
+     elem_i: SamplerScanElement, elem_j: SamplerScanElement
+ ) -> SamplerScanElement:
+     """Binary associative operator for sampling."""
+     G_i, e_i = elem_i
+     G_j, e_j = elem_j
+     G = G_j @ G_i
+     e = e_i @ G_j.T + e_j
+     return SamplerScanElement(G, e)
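A hedged sketch of how these elements might be combined into joint smoothing samples, reusing `ms`, `chol_Ps`, `Fs`, `cs`, `chol_Qs` from the filtering sketch above. Wrapping the operator in `jax.vmap` and scanning in reverse is this sketch's own wiring, not necessarily how the `cuthbert` drivers invoke it.

```
import jax

from cuthbertlib.kalman.sampling import sampling_operator, sqrt_associative_params

key = jax.random.key(42)
# One scan element per time step; shape=(16,) draws 16 independent trajectories.
elems = sqrt_associative_params(key, ms, chol_Ps, Fs, cs, chol_Qs, shape=(16,))

# Composing the elements backwards in time makes the `sample` field at index t a
# draw of x_t from the joint smoothing distribution.
samples = jax.lax.associative_scan(
    jax.vmap(sampling_operator), elems, reverse=True
).sample  # shape (num_time_steps + 1, 16, x_dim)
```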
@@ -0,0 +1,121 @@
+ """Implements the square root Rauch–Tung–Striebel (RTS) smoother and associative variant."""
+
+ from typing import NamedTuple
+
+ import jax
+ import jax.numpy as jnp
+ from jax.scipy.linalg import solve_triangular
+
+ from cuthbertlib.linalg import tria
+ from cuthbertlib.types import Array, ArrayLike
+
+
+ class SmootherScanElement(NamedTuple):
+     """Kalman smoother scan element."""
+
+     g: Array
+     E: Array
+     D: Array
+
+
+ # Note that `update` is aliased as `kalman.smoother_update` in `kalman.__init__.py`
+ def update(
+     filter_m: ArrayLike,
+     filter_chol_P: ArrayLike,
+     smoother_m: ArrayLike,
+     smoother_chol_P: ArrayLike,
+     F: ArrayLike,
+     c: ArrayLike,
+     chol_Q: ArrayLike,
+ ) -> tuple[tuple[Array, Array], Array]:
+     """Single step of the square root Rauch–Tung–Striebel (RTS) smoother.
+
+     Args:
+         filter_m: Mean of the filtered state.
+         filter_chol_P: Generalized Cholesky factor of the filtering covariance.
+         smoother_m: Mean of the smoother state.
+         smoother_chol_P: Generalized Cholesky factor of the smoothing covariance.
+         F: State transition matrix.
+         c: State transition shift vector.
+         chol_Q: Generalized Cholesky factor of the state transition noise covariance.
+
+     Returns:
+         A tuple `(smooth_state, info)`.
+         `smooth_state` contains the smoothed mean and square root covariance.
+         `info` contains the smoothing gain matrix.
+
+     References:
+         Paper: Park and Kailath (1994) - Square-root RTS smoothing algorithms
+         Code: https://github.com/EEA-sensors/sqrt-parallel-smoothers/tree/main/parsmooth/sequential
+     """
+     filter_m, filter_chol_P = jnp.asarray(filter_m), jnp.asarray(filter_chol_P)
+     smoother_m, smoother_chol_P = jnp.asarray(smoother_m), jnp.asarray(smoother_chol_P)
+     F, c, chol_Q = jnp.asarray(F), jnp.asarray(c), jnp.asarray(chol_Q)
+
+     nx = F.shape[0]
+     Phi = jnp.block([[F @ filter_chol_P, chol_Q], [filter_chol_P, jnp.zeros_like(F)]])
+     tria_Phi = tria(Phi)
+     Phi11 = tria_Phi[:nx, :nx]
+     Phi21 = tria_Phi[nx:, :nx]
+     Phi22 = tria_Phi[nx:, nx:]
+     gain = solve_triangular(Phi11, Phi21.T, trans=True, lower=True).T
+
+     mean_diff = smoother_m - (c + F @ filter_m)
+     mean = filter_m + gain @ mean_diff
+     chol = tria(jnp.concatenate([Phi22, gain @ smoother_chol_P], axis=1))
+     return (mean, chol), gain
+
+
+ def associative_params_single(
+     m: Array,
+     chol_P: Array,
+     F: Array,
+     c: Array,
+     chol_Q: Array,
+ ) -> SmootherScanElement:
+     """Computes the scan element for a single time step of the square root parallel Kalman smoother.
+
+     Args:
+         m: Mean of the smoother state.
+         chol_P: Generalized Cholesky factor of the smoothing covariance.
+         F: State transition matrix.
+         c: State transition shift vector.
+         chol_Q: Generalized Cholesky factor of the state transition noise covariance.
+
+     Returns:
+         Prepared scan element for the square root parallel Kalman smoother.
+     """
+     nx = chol_Q.shape[0]
+
+     Phi = jnp.block([[F @ chol_P, chol_Q], [chol_P, jnp.zeros_like(chol_Q)]])
+     Tria_Phi = tria(Phi)
+     Phi11 = Tria_Phi[:nx, :nx]
+     Phi21 = Tria_Phi[nx:, :nx]
+     D = Tria_Phi[nx:, nx:]
+
+     E = jax.scipy.linalg.solve_triangular(Phi11.T, Phi21.T).T
+     g = m - E @ (F @ m + c)
+     return SmootherScanElement(g, E, D)
+
+
+ def smoothing_operator(
+     elem_i: SmootherScanElement, elem_j: SmootherScanElement
+ ) -> SmootherScanElement:
+     """Binary associative operator for the square root Kalman smoother.
+
+     Args:
+         elem_i: Smoother scan element.
+         elem_j: Smoother scan element.
+
+     Returns:
+         SmootherScanElement: The output of the associative operator applied to the input elements.
+     """
+     g_i, E_i, D_i = elem_i
+     g_j, E_j, D_j = elem_j
+
+     g = E_j @ g_i + g_j
+     E = E_j @ E_i
+     D = tria(jnp.concatenate([E_j @ D_i, D_j], axis=1))
+
+     return SmootherScanElement(g, E, D)
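For the sequential interface, a hedged sketch of the backward smoothing pass over filtering results, reusing `ms`, `chol_Ps`, `Fs`, `cs`, `chol_Qs` from the filtering sketch earlier; the reverse `jax.lax.scan` wiring is illustrative only.

```
import jax
import jax.numpy as jnp

from cuthbertlib.kalman.smoothing import update as smoother_update

def backward_step(carry, inputs):
    smooth_m, smooth_chol_P = carry
    filter_m, filter_chol_P, F, c, chol_Q = inputs
    (smooth_m, smooth_chol_P), _gain = smoother_update(
        filter_m, filter_chol_P, smooth_m, smooth_chol_P, F, c, chol_Q
    )
    return (smooth_m, smooth_chol_P), (smooth_m, smooth_chol_P)

# Smoothing equals filtering at the final time step; scan backwards from there.
init = (ms[-1], chol_Ps[-1])
_, (smooth_ms, smooth_chol_Ps) = jax.lax.scan(
    backward_step, init, (ms[:-1], chol_Ps[:-1], Fs, cs, chol_Qs), reverse=True
)
smooth_ms = jnp.concatenate([smooth_ms, ms[-1][None]])
smooth_chol_Ps = jnp.concatenate([smooth_chol_Ps, chol_Ps[-1][None]])
```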
@@ -0,0 +1,7 @@
+ from cuthbertlib.linalg.collect_nans_chol import collect_nans_chol
+ from cuthbertlib.linalg.marginal_sqrt_cov import marginal_sqrt_cov
+ from cuthbertlib.linalg.symmetric_inv_sqrt import (
+     chol_cov_with_nans_to_cov,
+     symmetric_inv_sqrt,
+ )
+ from cuthbertlib.linalg.tria import tria
@@ -0,0 +1,90 @@
+ """Implements collection of NaNs and reordering within a Cholesky factor."""
+
+ from typing import Any
+
+ from jax import numpy as jnp
+ from jax import tree
+
+ from cuthbertlib.linalg.tria import tria
+ from cuthbertlib.types import Array, ArrayLike
+
+
+ def _set_to_zero(flag: ArrayLike, x: ArrayLike) -> Array:
+     x = jnp.asarray(x)
+     broadcast_flag = jnp.expand_dims(flag, list(range(1, x.ndim)))
+     return jnp.where(broadcast_flag, 0.0, x)
+
+
+ def collect_nans_chol(flag: ArrayLike, chol: ArrayLike, *rest: Any) -> Any:
+     """Converts `chol` into an ordered Cholesky factor with NaNs moved to the bottom right.
+
+     Specifically, converts a generalized Cholesky factor of a covariance matrix with
+     NaNs into an ordered generalized Cholesky factor whose NaN rows and columns are
+     moved to the end, with the corresponding diagonal elements set to 1/√(2π).
+
+     Also reorders the rest of the arguments in the same way along the first axis
+     and sets them to 0 for dimensions where `flag` is True.
+
+     Example behavior:
+     ```
+     flag = jnp.array([False, True, False, True])
+     new_flag, new_chol, new_mean = collect_nans_chol(flag, chol, mean)
+     ```
+
+     Args:
+         flag: Boolean array indicating which entries are NaN
+             (True for NaN entries, False for valid entries).
+         chol: Cholesky factor of the covariance matrix.
+         rest: Rest of the arguments to be reordered in the same way
+             along the first axis.
+
+     Returns:
+         `flag`, `chol` and `rest` reordered so that valid entries are first and NaNs are last.
+         Diagonal elements of `chol` are set to 1/√(2π) so that normalization is correct.
+     """
+     flag = jnp.asarray(flag)
+     chol = jnp.asarray(chol)
+
+     # TODO: Can we support batching? I.e. when `chol` is a batch of Cholesky factors,
+     # possibly with multiple leading dimensions
+
+     if flag.ndim > 1 or chol.ndim > 2:
+         raise ValueError("Batched flag or chol not supported yet")
+
+     if not flag.shape:
+         return (
+             flag,
+             _set_to_zero(flag, chol),
+             *tree.map(lambda x: _set_to_zero(flag, x), rest),
+         )
+
+     if chol.size == 1:
+         chol *= jnp.ones_like(flag, dtype=chol.dtype)
+
+     # group the NaN entries together
+     argsort = jnp.argsort(flag, stable=True)
+
+     if chol.ndim == 1:
+         chol = chol[argsort]
+         flag = flag[argsort]
+         chol = jnp.where(flag, 1 / jnp.sqrt(2 * jnp.pi), chol)
+
+     else:
+         chol = jnp.where(flag[:, None], 0.0, chol)
+         chol = chol[argsort]
+         # compute the tria of the covariance matrix with NaNs set to 0
+         chol = tria(chol)
+
+         flag = flag[argsort]
+
+         # set the diagonal of chol_cov to 1/√(2π) where NaNs were present so that normalization is correct
+         diag_chol = jnp.diag(chol)
+         diag_chol = jnp.where(flag, 1 / jnp.sqrt(2 * jnp.pi), diag_chol)
+         diag_indices = jnp.diag_indices_from(chol)
+         chol = chol.at[diag_indices].set(diag_chol)
+
+     # Only reorder non-scalar arrays in rest
+     rest = tree.map(lambda x: x[argsort] if jnp.asarray(x).shape else x, rest)
+     rest = tree.map(lambda x: _set_to_zero(flag, x), rest)
+
+     return flag, chol, *rest
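A small worked example of the behavior described in the docstring, with made-up numbers; the import path follows `cuthbertlib/linalg/__init__.py` above.

```
import jax.numpy as jnp

from cuthbertlib.linalg import collect_nans_chol

# Observation with a missing (NaN) second dimension, plus an observation model.
y = jnp.array([1.2, jnp.nan, -0.3])
chol_R = jnp.diag(jnp.array([0.1, 0.2, 0.3]))
H = jnp.ones((3, 2))
d = jnp.zeros(3)

flag = jnp.isnan(y)
flag, chol_R, H, d, y = collect_nans_chol(flag, chol_R, H, d, y)
# The NaN dimension is moved to the last row/column of chol_R and its diagonal
# entry becomes 1/sqrt(2*pi); the matching rows of H, d and y are reordered and
# set to zero.
```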
@@ -0,0 +1,34 @@
+ """Extract marginal square root covariance from a joint square root covariance."""
+
+ from typing import Sequence
+
+ from jax import numpy as jnp
+
+ from cuthbertlib.linalg.tria import tria
+ from cuthbertlib.types import Array, ArrayLike
+
+
+ def marginal_sqrt_cov(chol_cov: ArrayLike, start: int, end: int) -> Array:
+     """Extracts square root submatrix from a joint square root matrix.
+
+     Specifically, returns B such that
+     B @ B.T = (chol_cov @ chol_cov.T)[start:end, start:end]
+
+     Args:
+         chol_cov: Generalized Cholesky factor of the covariance matrix.
+         start: Start index of the submatrix.
+         end: End index of the submatrix.
+
+     Returns:
+         Lower triangular square root matrix of the marginal covariance matrix.
+     """
+     chol_cov = jnp.asarray(chol_cov)
+     assert chol_cov.ndim == 2, "chol_cov must be a 2D array"
+     assert chol_cov.shape[0] == chol_cov.shape[1], "chol_cov must be square"
+     assert start >= 0 and end <= chol_cov.shape[0], (
+         "start and end must be within the bounds of chol_cov"
+     )
+     assert start < end, "start must be less than end"
+
+     chol_cov_select_rows = chol_cov[start:end, :]
+     return tria(chol_cov_select_rows)
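A quick check of the stated property, using an arbitrary lower-triangular factor.

```
import jax.numpy as jnp

from cuthbertlib.linalg import marginal_sqrt_cov

# Arbitrary 4x4 lower-triangular square root factor.
chol_cov = jnp.tril(jnp.arange(1.0, 17.0).reshape(4, 4))
cov = chol_cov @ chol_cov.T

# Square root factor of the marginal covariance over dimensions 1 and 2.
B = marginal_sqrt_cov(chol_cov, start=1, end=3)
assert jnp.allclose(B @ B.T, cov[1:3, 1:3])
```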