cuthbert 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuthbert/discrete/__init__.py +2 -0
- cuthbert/discrete/filter.py +140 -0
- cuthbert/discrete/smoother.py +123 -0
- cuthbert/discrete/types.py +53 -0
- cuthbert/gaussian/__init__.py +0 -0
- cuthbert/gaussian/kalman.py +337 -0
- cuthbert/gaussian/moments/__init__.py +11 -0
- cuthbert/gaussian/moments/associative_filter.py +180 -0
- cuthbert/gaussian/moments/filter.py +95 -0
- cuthbert/gaussian/moments/non_associative_filter.py +161 -0
- cuthbert/gaussian/moments/smoother.py +118 -0
- cuthbert/gaussian/moments/types.py +51 -0
- cuthbert/gaussian/taylor/__init__.py +14 -0
- cuthbert/gaussian/taylor/associative_filter.py +222 -0
- cuthbert/gaussian/taylor/filter.py +129 -0
- cuthbert/gaussian/taylor/non_associative_filter.py +246 -0
- cuthbert/gaussian/taylor/smoother.py +158 -0
- cuthbert/gaussian/taylor/types.py +86 -0
- cuthbert/gaussian/types.py +57 -0
- cuthbert/gaussian/utils.py +41 -0
- cuthbert/smc/__init__.py +0 -0
- cuthbert/smc/backward_sampler.py +193 -0
- cuthbert/smc/marginal_particle_filter.py +237 -0
- cuthbert/smc/particle_filter.py +234 -0
- cuthbert/smc/types.py +67 -0
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/METADATA +1 -1
- cuthbert-0.0.3.dist-info/RECORD +76 -0
- cuthbertlib/discrete/__init__.py +0 -0
- cuthbertlib/discrete/filtering.py +49 -0
- cuthbertlib/discrete/smoothing.py +35 -0
- cuthbertlib/kalman/__init__.py +4 -0
- cuthbertlib/kalman/filtering.py +213 -0
- cuthbertlib/kalman/generate.py +85 -0
- cuthbertlib/kalman/sampling.py +68 -0
- cuthbertlib/kalman/smoothing.py +121 -0
- cuthbertlib/linalg/__init__.py +7 -0
- cuthbertlib/linalg/collect_nans_chol.py +90 -0
- cuthbertlib/linalg/marginal_sqrt_cov.py +34 -0
- cuthbertlib/linalg/symmetric_inv_sqrt.py +126 -0
- cuthbertlib/linalg/tria.py +21 -0
- cuthbertlib/linearize/__init__.py +7 -0
- cuthbertlib/linearize/log_density.py +175 -0
- cuthbertlib/linearize/moments.py +94 -0
- cuthbertlib/linearize/taylor.py +83 -0
- cuthbertlib/quadrature/__init__.py +4 -0
- cuthbertlib/quadrature/common.py +102 -0
- cuthbertlib/quadrature/cubature.py +73 -0
- cuthbertlib/quadrature/gauss_hermite.py +62 -0
- cuthbertlib/quadrature/linearize.py +143 -0
- cuthbertlib/quadrature/unscented.py +79 -0
- cuthbertlib/quadrature/utils.py +109 -0
- cuthbertlib/resampling/__init__.py +3 -0
- cuthbertlib/resampling/killing.py +79 -0
- cuthbertlib/resampling/multinomial.py +53 -0
- cuthbertlib/resampling/protocols.py +92 -0
- cuthbertlib/resampling/systematic.py +78 -0
- cuthbertlib/resampling/utils.py +82 -0
- cuthbertlib/smc/__init__.py +0 -0
- cuthbertlib/smc/ess.py +24 -0
- cuthbertlib/smc/smoothing/__init__.py +0 -0
- cuthbertlib/smc/smoothing/exact_sampling.py +111 -0
- cuthbertlib/smc/smoothing/mcmc.py +76 -0
- cuthbertlib/smc/smoothing/protocols.py +44 -0
- cuthbertlib/smc/smoothing/tracing.py +45 -0
- cuthbertlib/stats/__init__.py +0 -0
- cuthbertlib/stats/multivariate_normal.py +102 -0
- cuthbert-0.0.2.dist-info/RECORD +0 -12
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/WHEEL +0 -0
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/licenses/LICENSE +0 -0
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""Implements the discrete HMM smoothing associative operator."""
|
|
2
|
+
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
|
|
5
|
+
from cuthbertlib.types import Array, ArrayLike
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def get_reverse_kernel(x_t_dist: ArrayLike, trans_matrix: ArrayLike) -> Array:
    r"""Builds the backward kernel $p(x_{t-1} \mid x_t, \dots)$ of a discrete HMM.

    Args:
        x_t_dist: Shape (N,) array with `x_t_dist[i]` = $p(x_t = i \mid \dots)$.
        trans_matrix: Shape (N, N) array with
            `trans_matrix[i, j]` = $p(x_t = j \mid x_{t-1} = i)$.

    Returns:
        Shape (N, N) array whose entry `[i, j]` equals
        $p(x_{t-1} = j \mid x_t = i, \dots)$.
    """
    dist = jnp.asarray(x_t_dist)
    kernel_T = jnp.asarray(trans_matrix).T
    # Normalizer: kernel_T @ dist gives, for each row i, the total mass reaching i.
    normalizer = jnp.dot(kernel_T, dist)
    # Bayes' rule applied row-wise: numerator trans_matrix[j, i] * dist[j],
    # normalized over j.
    return kernel_T * dist[None, :] / normalizer[:, None]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def smoothing_operator(elem_ij: Array, elem_jk: Array) -> Array:
    """Binary associative operator for smoothing in discrete HMMs.

    Combines two smoothing scan elements by matrix composition.

    Args:
        elem_ij: One smoothing scan element.
        elem_jk: The other smoothing scan element.

    Returns:
        The combined scan element.
    """
    return jnp.matmul(elem_jk, elem_ij)
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
"""Implements the square root parallel Kalman filter and associative variant."""
|
|
2
|
+
|
|
3
|
+
from typing import NamedTuple
|
|
4
|
+
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
from jax.scipy.linalg import cho_solve, solve_triangular
|
|
7
|
+
|
|
8
|
+
from cuthbertlib.linalg import tria
|
|
9
|
+
from cuthbertlib.stats import multivariate_normal
|
|
10
|
+
from cuthbertlib.stats.multivariate_normal import collect_nans_chol
|
|
11
|
+
from cuthbertlib.types import Array, ArrayLike, ScalarArray
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class FilterScanElement(NamedTuple):
    """Arrays carried through the Kalman filter scan."""

    # Corrected transition matrix (built in `associative_params_single`).
    A: Array
    # Corrected transition offset.
    b: Array
    # Square-root covariance factor (from `tria` in `associative_params_single`).
    U: Array
    # Information-filter vector.
    eta: Array
    # Information-filter square-root matrix.
    Z: Array
    # Local log marginal likelihood contribution.
    ell: ScalarArray
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def predict(
    m: ArrayLike,
    chol_P: ArrayLike,
    F: ArrayLike,
    c: ArrayLike,
    chol_Q: ArrayLike,
) -> tuple[Array, Array]:
    """Applies one linear-Gaussian prediction step in square-root form.

    Args:
        m: State mean.
        chol_P: Generalized Cholesky factor of the state covariance.
        F: Transition matrix.
        c: Transition shift.
        chol_Q: Generalized Cholesky factor of the transition noise covariance.

    Returns:
        Mean and generalized Cholesky covariance factor of the predicted state.

    References:
        Paper: G. J. Bierman, Factorization Methods for Discrete Sequential Estimation,
        Code: https://github.com/EEA-sensors/sqrt-parallel-smoothers/tree/main/parsmooth/sequential
    """
    F = jnp.asarray(F)
    chol_Q = jnp.asarray(chol_Q)
    predicted_mean = F @ jnp.asarray(m) + jnp.asarray(c)
    # [F L_P | L_Q] has F P F^T + Q as its Gram matrix; `tria` re-triangularizes
    # it into a square lower factor.
    wide_factor = jnp.concatenate([F @ jnp.asarray(chol_P), chol_Q], axis=1)
    return predicted_mean, tria(wide_factor)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# Note that `update` is aliased as `kalman.filter_update` in `kalman.__init__.py`
def update(
    m: ArrayLike,
    chol_P: ArrayLike,
    H: ArrayLike,
    d: ArrayLike,
    chol_R: ArrayLike,
    y: ArrayLike,
    log_normalizing_constant: ArrayLike = 0.0,
) -> tuple[tuple[Array, Array], Array]:
    """Update the mean and square root covariance with a linear Gaussian observation.

    Args:
        m: Mean of the state.
        chol_P: Generalized Cholesky factor of the covariance of the state.
        H: Observation matrix.
        d: Observation shift.
        chol_R: Generalized Cholesky factor of the observation noise covariance.
        y: Observation. NaN entries are treated as missing dimensions.
        log_normalizing_constant: Optional input of log normalizing constant to be added to
            log normalizing constant of the Bayesian update.

    Returns:
        Updated mean and square root covariance as well as the log marginal likelihood.

    References:
        Paper: G. J. Bierman, Factorization Methods for Discrete Sequential Estimation,
        Code: https://github.com/EEA-sensors/sqrt-parallel-smoothers/tree/main/parsmooth/sequential
    """
    # Handle case where there is no observation: NaN dimensions are moved to the
    # end and zeroed out by `collect_nans_chol` so they do not affect the update.
    flag = jnp.isnan(y)
    flag, chol_R, H, d, y = collect_nans_chol(flag, chol_R, H, d, y)

    m, chol_P = jnp.asarray(m), jnp.asarray(chol_P)
    H, d, chol_R = jnp.asarray(H), jnp.asarray(d), jnp.asarray(chol_R)
    y = jnp.asarray(y)

    n_y, n_x = H.shape

    # Innovation: difference between observation and its prediction.
    y_hat = H @ m + d
    y_diff = y - y_hat

    # Square-root form of the joint (observation, state) covariance; its Gram
    # matrix contains the innovation covariance and the cross terms.
    M = jnp.block(
        [
            [H @ chol_P, chol_R],
            [chol_P, jnp.zeros((n_x, n_y), dtype=chol_P.dtype)],
        ]
    )
    chol_S = tria(M)
    # Bottom-right block: square-root posterior covariance.
    chol_Py = chol_S[n_y:, n_y:]

    Gmat = chol_S[n_y:, :n_y]
    # Top-left block: generalized Cholesky factor of the innovation covariance.
    Imat = chol_S[:n_y, :n_y]

    my = m + Gmat @ solve_triangular(Imat, y_diff, lower=True)

    # Log marginal likelihood of the observation under the predictive distribution.
    ell = multivariate_normal.logpdf(y, y_hat, Imat, nan_support=False)
    return (my, chol_Py), jnp.asarray(ell + log_normalizing_constant)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def associative_params_single(
    F: Array, c: Array, chol_Q: Array, H: Array, d: Array, chol_R: Array, y: Array
) -> FilterScanElement:
    """Single time step for scan element for square root parallel Kalman filter.

    Args:
        F: State transition matrix.
        c: State transition shift vector.
        chol_Q: Generalized Cholesky factor of the state transition noise covariance.
        H: Observation matrix.
        d: Observation shift.
        chol_R: Generalized Cholesky factor of the observation noise covariance.
        y: Observation. NaN entries are treated as missing dimensions.

    Returns:
        Prepared scan element for the square root parallel Kalman filter.
    """
    # Handle case where there is no observation
    flag = jnp.isnan(y)
    flag, chol_R, H, d, y = collect_nans_chol(flag, chol_R, H, d, y)

    ny, nx = H.shape

    # joint over the predictive and the observation
    Psi_ = jnp.block([[H @ chol_Q, chol_R], [chol_Q, jnp.zeros((nx, ny))]])

    Tria_Psi_ = tria(Psi_)

    Psi11 = Tria_Psi_[:ny, :ny]
    Psi21 = Tria_Psi_[ny : ny + nx, :ny]
    U = Tria_Psi_[ny : ny + nx, ny:]

    # pre-compute inverse of Psi11: we apply it to matrices and vectors alike.
    Psi11_inv = solve_triangular(Psi11, jnp.eye(ny), lower=True)

    # predictive model given one observation
    K = Psi21 @ Psi11_inv  # local Kalman gain
    HF = H @ F  # temporary variable
    A = F - K @ HF  # corrected transition matrix

    b = c + K @ (y - H @ c - d)  # corrected transition offset

    # information filter
    Z = HF.T @ Psi11_inv.T
    eta = Psi11_inv @ (y - H @ c - d)
    eta = Z @ eta

    # Bring Z to a square (nx, nx) factor: pad with zero columns when the state
    # dimension exceeds the observation dimension, otherwise re-triangularize.
    if nx > ny:
        Z = jnp.concatenate([Z, jnp.zeros((nx, nx - ny))], axis=1)
    else:
        Z = tria(Z)

    # local log marginal likelihood
    ell = jnp.asarray(
        multivariate_normal.logpdf(y, H @ c + d, Psi11, nan_support=False)
    )

    return FilterScanElement(A, b, U, eta, Z, ell)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def filtering_operator(
    elem_i: FilterScanElement, elem_j: FilterScanElement
) -> FilterScanElement:
    """Binary associative operator for the square root Kalman filter.

    Args:
        elem_i: Filter scan element for the previous time step.
        elem_j: Filter scan element for the current time step.

    Returns:
        FilterScanElement: The output of the associative operator applied to the input elements.
    """
    A1, b1, U1, eta1, Z1, ell1 = elem_i
    A2, b2, U2, eta2, Z2, ell2 = elem_j

    nx = Z2.shape[0]

    # Square-root combination of the covariance and information factors.
    Xi = jnp.block([[U1.T @ Z2, jnp.eye(nx)], [Z2, jnp.zeros_like(A1)]])
    tria_xi = tria(Xi)
    Xi11 = tria_xi[:nx, :nx]
    Xi21 = tria_xi[nx : nx + nx, :nx]
    Xi22 = tria_xi[nx : nx + nx, nx:]

    tmp_1 = solve_triangular(Xi11, U1.T, lower=True).T
    # D_inv plays the role of (I + P1 J2)^{-1} in the parallel filter recursions
    # (computed here purely via the square-root factors).
    D_inv = jnp.eye(nx) - tmp_1 @ Xi21.T
    tmp_2 = D_inv @ (b1 + U1 @ (U1.T @ eta2))

    # Combined affine map, offset, and square-root covariance factor.
    A = A2 @ D_inv @ A1
    b = A2 @ tmp_2 + b2
    U = tria(jnp.concatenate([A2 @ tmp_1, U2], axis=1))
    # Combined information-filter terms.
    eta = A1.T @ (D_inv.T @ (eta2 - Z2 @ (Z2.T @ b1))) + eta1
    Z = tria(jnp.concatenate([A1.T @ Xi22, Z1], axis=1))

    # Accumulate the log marginal likelihood contribution of the merge.
    mu = cho_solve((U1, True), b1)
    t1 = b1 @ mu - (eta2 + mu) @ tmp_2
    ell = ell1 + ell2 - 0.5 * t1 + 0.5 * jnp.linalg.slogdet(D_inv)[1]

    return FilterScanElement(A, b, U, eta, Z, ell)
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""Utilities to generate linear-Gaussian state-space models (LGSSMs)."""
|
|
2
|
+
|
|
3
|
+
from functools import partial
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
from jax import Array, random
|
|
8
|
+
|
|
9
|
+
from cuthbertlib.types import KeyArray
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@partial(jax.jit, static_argnames=("x_dim", "y_dim", "num_time_steps"))
def generate_lgssm(seed: int, x_dim: int, y_dim: int, num_time_steps: int):
    """Generates an LGSSM along with a set of observations.

    The model parameters are redrawn at every time step (time-varying LGSSM).

    Args:
        seed: Integer seed used to build the PRNG key.
        x_dim: State dimension.
        y_dim: Observation dimension.
        num_time_steps: Number of transition steps after time 0.

    Returns:
        Tuple `(m0, chol_P0, Fs, cs, chol_Qs, Hs, ds, chol_Rs, ys)`: initial
        state parameters, stacked transition parameters for steps 1..T, and
        stacked observation parameters and observations for steps 0..T.
    """
    key = random.key(seed)

    key, init_key, sample_key, obs_model_key, obs_key = random.split(key, 5)
    m0, chol_P0 = generate_init_model(init_key, x_dim)
    x0 = m0 + chol_P0 @ random.normal(sample_key, (x_dim,))

    # Generate an observation for time 0
    H0, d0, chol_R0 = generate_obs_model(obs_model_key, x_dim, y_dim)
    obs_noise = chol_R0 @ random.normal(obs_key, (y_dim,))
    y0 = H0 @ x0 + d0 + obs_noise

    def body(_x, _key):
        # One transition + one observation per scan step, with fresh parameters.
        trans_model_key, trans_key, obs_model_key, obs_key = random.split(_key, 4)

        F, c, chol_Q = generate_trans_model(trans_model_key, x_dim)
        state_noise = chol_Q @ random.normal(trans_key, (x_dim,))
        x = F @ _x + c + state_noise

        H, d, chol_R = generate_obs_model(obs_model_key, x_dim, y_dim)
        obs_noise = chol_R @ random.normal(obs_key, (y_dim,))
        y = H @ x + d + obs_noise

        return x, (F, c, chol_Q, H, d, chol_R, y)

    scan_keys = random.split(key, num_time_steps)
    _, (Fs, cs, chol_Qs, Hs, ds, chol_Rs, ys) = jax.lax.scan(body, x0, scan_keys)

    # Prepend the observation model parameters and observation for t=0
    Hs, ds, chol_Rs, ys = jax.tree.map(
        lambda x, xs: jnp.concatenate([x[None], xs]),
        (H0, d0, chol_R0, y0),
        (Hs, ds, chol_Rs, ys),
    )

    return m0, chol_P0, Fs, cs, chol_Qs, Hs, ds, chol_Rs, ys
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def generate_cholesky_factor(key: KeyArray, dim: int) -> Array:
    """Draws a random lower-triangular matrix usable as a Cholesky factor."""
    dense = random.uniform(key, (dim, dim))
    # Zero out the strict upper triangle so only the lower triangle remains.
    return dense.at[jnp.triu_indices(dim, 1)].set(0.0)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def generate_init_model(key: KeyArray, x_dim: int) -> tuple[Array, Array]:
    """Draws a random Gaussian initial condition for an LGSSM."""
    mean_key, chol_key = random.split(key)
    m0 = random.normal(mean_key, (x_dim,))
    chol_P0 = generate_cholesky_factor(chol_key, x_dim)
    return m0, chol_P0
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def generate_trans_model(key: KeyArray, x_dim: int) -> tuple[Array, Array, Array]:
    """Draws a random linear-Gaussian transition model for an LGSSM."""
    f_key, c_key, q_key = random.split(key, 3)
    # Chosen less than one to stop exploding states (in expectation).
    exp_eig_max = 0.75
    F = exp_eig_max * random.normal(f_key, (x_dim, x_dim)) / jnp.sqrt(x_dim)
    c = 0.1 * random.normal(c_key, (x_dim,))
    chol_Q = generate_cholesky_factor(q_key, x_dim)
    return F, c, chol_Q
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def generate_obs_model(
    key: KeyArray, x_dim: int, y_dim: int
) -> tuple[Array, Array, Array]:
    """Draws a random linear-Gaussian observation model for an LGSSM."""
    h_key, d_key, r_key = random.split(key, 3)
    H = random.normal(h_key, (y_dim, x_dim))
    d = random.normal(d_key, (y_dim,))
    chol_R = generate_cholesky_factor(r_key, y_dim)
    return H, d, chol_R
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""Implements the square root parallel Kalman associative operator for sampling.
|
|
2
|
+
|
|
3
|
+
Samples from the smoothing distribution without doing the smoothing scan for means
|
|
4
|
+
and (chol) covariances.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import NamedTuple, Sequence
|
|
8
|
+
|
|
9
|
+
import jax
|
|
10
|
+
import jax.numpy as jnp
|
|
11
|
+
|
|
12
|
+
from cuthbertlib.kalman.smoothing import associative_params_single
|
|
13
|
+
from cuthbertlib.types import Array, ArrayLike
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SamplerScanElement(NamedTuple):
    """Kalman sampling scan element."""

    # Gain matrix; composed multiplicatively by `sampling_operator`.
    gain: Array
    # Sample contribution; combined affinely by `sampling_operator`.
    sample: Array
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def sqrt_associative_params(
    key: ArrayLike,
    ms: Array,
    chol_Ps: Array,
    Fs: Array,
    cs: Array,
    chol_Qs: Array,
    shape: Sequence[int],
) -> SamplerScanElement:
    """Compute the sampler scan elements.

    Args:
        key: PRNG key used to draw the standard normal noise.
        ms: Means, one per time step (leading axis is time).
        chol_Ps: Generalized Cholesky covariance factors, one per time step.
        Fs: Transition matrices for the intermediate steps.
        cs: Transition shifts for the intermediate steps.
        chol_Qs: Generalized Cholesky transition noise covariance factors.
        shape: Extra batch shape of samples drawn per time step.

    Returns:
        Stacked scan elements: one per intermediate step plus the final-step
        element appended along the leading axis.
    """
    shape = tuple(shape)
    # One standard-normal draw per time step and requested batch dimension.
    eps = jax.random.normal(key, ms.shape[:1] + shape + ms.shape[1:])
    interm_elems = jax.vmap(_sqrt_associative_params_interm)(
        ms[:-1], chol_Ps[:-1], Fs, cs, chol_Qs, eps[:-1]
    )
    # The last time step has no outgoing transition: sample its marginal directly.
    last_elem = _sqrt_associative_params_final(ms[-1], chol_Ps[-1], eps[-1])
    return jax.tree.map(
        lambda x, y: jnp.concatenate([x, y[None]]), interm_elems, last_elem
    )
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _sqrt_associative_params_interm(
    m: Array, chol_P: Array, F: Array, c: Array, chol_Q: Array, eps: Array
) -> SamplerScanElement:
    # Scan element for an intermediate time step.
    # Reuses the smoother's associative parameters: offset, gain, and
    # square-root conditional covariance factor for this transition.
    inc_m, gain, L = associative_params_single(m, chol_P, F, c, chol_Q)
    # Noisy increment: conditional offset plus correlated Gaussian noise.
    inc = inc_m + eps @ L.T
    return SamplerScanElement(gain, inc)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _sqrt_associative_params_final(
    m: Array, chol_P: Array, eps: Array
) -> SamplerScanElement:
    # Terminal element: zero gain, sample drawn directly from the marginal.
    draw = m + eps @ chol_P.T
    return SamplerScanElement(jnp.zeros_like(chol_P), draw)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def sampling_operator(
    elem_i: SamplerScanElement, elem_j: SamplerScanElement
) -> SamplerScanElement:
    """Binary associative operator for sampling."""
    gain_i, sample_i = elem_i
    gain_j, sample_j = elem_j
    # Gains compose multiplicatively; samples combine through the later gain.
    combined_gain = gain_j @ gain_i
    combined_sample = sample_i @ gain_j.T + sample_j
    return SamplerScanElement(combined_gain, combined_sample)
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""Implements the square root Rauch–Tung–Striebel (RTS) smoother and associative variant."""
|
|
2
|
+
|
|
3
|
+
from typing import NamedTuple
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
import jax.numpy as jnp
|
|
7
|
+
from jax.scipy.linalg import solve_triangular
|
|
8
|
+
|
|
9
|
+
from cuthbertlib.linalg import tria
|
|
10
|
+
from cuthbertlib.types import Array, ArrayLike
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SmootherScanElement(NamedTuple):
    """Kalman smoother scan element."""

    # Offset vector: g = m - E @ (F @ m + c) in `associative_params_single`.
    g: Array
    # Smoothing gain matrix.
    E: Array
    # Square-root covariance factor (from `tria`).
    D: Array
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Note that `update` is aliased as `kalman.smoother_update` in `kalman.__init__.py`
def update(
    filter_m: ArrayLike,
    filter_chol_P: ArrayLike,
    smoother_m: ArrayLike,
    smoother_chol_P: ArrayLike,
    F: ArrayLike,
    c: ArrayLike,
    chol_Q: ArrayLike,
) -> tuple[tuple[Array, Array], Array]:
    """Single step of the square root Rauch–Tung–Striebel (RTS) smoother.

    Args:
        filter_m: Mean of the filtered state.
        filter_chol_P: Generalized Cholesky factor of the filtering covariance.
        smoother_m: Mean of the smoother state.
        smoother_chol_P: Generalized Cholesky factor of the smoothing covariance.
        F: State transition matrix.
        c: State transition shift vector.
        chol_Q: Generalized Cholesky factor of the state transition noise covariance.

    Returns:
        A tuple `(smooth_state, info)`.
        `smooth_state` contains the smoothed mean and square root covariance.
        `info` contains the smoothing gain matrix.

    References:
        Paper: Park and Kailath (1994) - Square-root RTS smoothing algorithms
        Code: https://github.com/EEA-sensors/sqrt-parallel-smoothers/tree/main/parsmooth/sequential
    """
    filter_m, filter_chol_P = jnp.asarray(filter_m), jnp.asarray(filter_chol_P)
    smoother_m, smoother_chol_P = jnp.asarray(smoother_m), jnp.asarray(smoother_chol_P)
    F, c, chol_Q = jnp.asarray(F), jnp.asarray(c), jnp.asarray(chol_Q)

    nx = F.shape[0]
    # Square-root form of the joint (predicted, filtered) covariance; `tria`
    # produces the blocks needed for the backward gain and residual factor.
    Phi = jnp.block([[F @ filter_chol_P, chol_Q], [filter_chol_P, jnp.zeros_like(F)]])
    tria_Phi = tria(Phi)
    Phi11 = tria_Phi[:nx, :nx]
    Phi21 = tria_Phi[nx:, :nx]
    Phi22 = tria_Phi[nx:, nx:]
    # Smoothing gain: solve against the (lower) predicted square-root factor.
    gain = solve_triangular(Phi11, Phi21.T, trans=True, lower=True).T

    # Correct the filtered mean by the gain applied to the prediction residual.
    mean_diff = smoother_m - (c + F @ filter_m)
    mean = filter_m + gain @ mean_diff
    # Combine the residual factor with the propagated smoothing factor.
    chol = tria(jnp.concatenate([Phi22, gain @ smoother_chol_P], axis=1))
    return (mean, chol), gain
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def associative_params_single(
    m: Array,
    chol_P: Array,
    F: Array,
    c: Array,
    chol_Q: Array,
) -> SmootherScanElement:
    """Single time step for scan element for square root parallel Kalman smoother.

    Args:
        m: Mean of the smoother state.
        chol_P: Generalized Cholesky factor of the smoothing covariance.
        F: State transition matrix.
        c: State transition shift vector.
        chol_Q: Generalized Cholesky factor of the state transition noise covariance.

    Returns:
        SmootherScanElement: The output of the associative operator applied to the input
            elements.
    """
    nx = chol_Q.shape[0]

    # Square-root form of the joint covariance of (predicted, current) state.
    Phi = jnp.block([[F @ chol_P, chol_Q], [chol_P, jnp.zeros_like(chol_Q)]])
    Tria_Phi = tria(Phi)
    Phi11 = Tria_Phi[:nx, :nx]
    Phi21 = Tria_Phi[nx:, :nx]
    # Residual square-root covariance factor.
    D = Tria_Phi[nx:, nx:]

    # Smoothing gain solved from the upper-triangular transpose of Phi11.
    E = jax.scipy.linalg.solve_triangular(Phi11.T, Phi21.T).T
    # Offset so that the element represents the affine map x -> E x + g.
    g = m - E @ (F @ m + c)
    return SmootherScanElement(g, E, D)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def smoothing_operator(
    elem_i: SmootherScanElement, elem_j: SmootherScanElement
) -> SmootherScanElement:
    """Binary associative operator for the square root Kalman smoother.

    Args:
        elem_i: Smoother scan element.
        elem_j: Smoother scan element.

    Returns:
        SmootherScanElement: The output of the associative operator applied to the input elements.
    """
    g_left, E_left, D_left = elem_i
    g_right, E_right, D_right = elem_j

    # Compose the two affine maps and merge their square-root factors.
    combined_g = E_right @ g_left + g_right
    combined_E = E_right @ E_left
    combined_D = tria(jnp.concatenate([E_right @ D_left, D_right], axis=1))

    return SmootherScanElement(combined_g, combined_E, combined_D)
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
from cuthbertlib.linalg.collect_nans_chol import collect_nans_chol
|
|
2
|
+
from cuthbertlib.linalg.marginal_sqrt_cov import marginal_sqrt_cov
|
|
3
|
+
from cuthbertlib.linalg.symmetric_inv_sqrt import (
|
|
4
|
+
chol_cov_with_nans_to_cov,
|
|
5
|
+
symmetric_inv_sqrt,
|
|
6
|
+
)
|
|
7
|
+
from cuthbertlib.linalg.tria import tria
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Implements collection of NaNs and reordering within a Cholesky factor."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from jax import numpy as jnp
|
|
6
|
+
from jax import tree
|
|
7
|
+
|
|
8
|
+
from cuthbertlib.linalg.tria import tria
|
|
9
|
+
from cuthbertlib.types import Array, ArrayLike
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _set_to_zero(flag: ArrayLike, x: ArrayLike) -> Array:
    # Zeroes the slices of `x` (along the first axis) where `flag` is True.
    arr = jnp.asarray(x)
    # Append singleton axes so the flag broadcasts over the trailing dimensions.
    mask = jnp.expand_dims(flag, list(range(1, arr.ndim)))
    return jnp.where(mask, 0.0, arr)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def collect_nans_chol(flag: ArrayLike, chol: ArrayLike, *rest: Any) -> Any:
    """Converts chol into an ordered chol factor with NaNs moved to the bottom right.

    Specifically, converts a generalized Cholesky factor of a covariance matrix with
    NaNs into an ordered generalized Cholesky factor with NaN rows and columns
    moved to the end and with those diagonal elements set to 1/√2π.

    Also reorders the rest of the arguments in the same way along the first axis
    and sets to 0 for dimensions where flag is True.

    Example behavior:
    ```
    flag = jnp.array([False, True, False, True])
    new_flag, new_chol, new_mean = collect_nans_chol(flag, chol, mean)
    ```

    Args:
        flag: Array, boolean array indicating which entries are NaN
            True for NaN entries, False for valid
        chol: Array, Cholesky factor of the covariance matrix
        rest: Any, rest of the arguments to be reordered in the same way
            along the first axis

    Returns:
        flag, chol and rest reordered so that valid entries are first and NaNs are last.
        Diagonal elements of chol are set to 1/√2π so that normalization is correct
    """
    flag = jnp.asarray(flag)
    chol = jnp.asarray(chol)

    # TODO: Can we support batching? I.e. when `chol` is a batch of Cholesky factors,
    # possibly with multiple leading dimensions

    if flag.ndim > 1 or chol.ndim > 2:
        raise ValueError("Batched flag or chol not supported yet")

    # Scalar flag: nothing to reorder, just zero everything out where flagged.
    if not flag.shape:
        return (
            flag,
            _set_to_zero(flag, chol),
            *tree.map(lambda x: _set_to_zero(flag, x), rest),
        )

    # Broadcast a scalar factor to one entry per flag dimension.
    if chol.size == 1:
        chol *= jnp.ones_like(flag, dtype=chol.dtype)

    # group the NaN entries together
    argsort = jnp.argsort(flag, stable=True)

    if chol.ndim == 1:
        # Diagonal (vector) factor: reorder and patch the flagged entries directly.
        chol = chol[argsort]
        flag = flag[argsort]
        chol = jnp.where(flag, 1 / jnp.sqrt(2 * jnp.pi), chol)

    else:
        chol = jnp.where(flag[:, None], 0.0, chol)
        chol = chol[argsort]
        # compute the tria of the covariance matrix with NaNs set to 0
        chol = tria(chol)

        flag = flag[argsort]

        # set the diagonal of chol_cov to 1/√2π where nans were present so that normalization is correct
        diag_chol = jnp.diag(chol)
        diag_chol = jnp.where(flag, 1 / jnp.sqrt(2 * jnp.pi), diag_chol)
        diag_indices = jnp.diag_indices_from(chol)
        chol = chol.at[diag_indices].set(diag_chol)

    # Only reorder non-scalar arrays in rest
    rest = tree.map(lambda x: x[argsort] if jnp.asarray(x).shape else x, rest)
    rest = tree.map(lambda x: _set_to_zero(flag, x), rest)

    return flag, chol, *rest
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Extract marginal square root covariance from a joint square root covariance."""
|
|
2
|
+
|
|
3
|
+
from typing import Sequence
|
|
4
|
+
|
|
5
|
+
from jax import numpy as jnp
|
|
6
|
+
|
|
7
|
+
from cuthbertlib.linalg.tria import tria
|
|
8
|
+
from cuthbertlib.types import Array, ArrayLike
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def marginal_sqrt_cov(chol_cov: ArrayLike, start: int, end: int) -> Array:
    """Extracts square root submatrix from a joint square root matrix.

    Specifically, returns B such that
        B @ B.T = (chol_cov @ chol_cov.T)[start:end, start:end]

    Args:
        chol_cov: Generalized Cholesky factor of the covariance matrix.
        start: Start index of the submatrix.
        end: End index of the submatrix.

    Returns:
        Lower triangular square root matrix of the marginal covariance matrix.

    Raises:
        ValueError: If `chol_cov` is not a square 2D array, or if `start`/`end`
            are out of bounds or select an empty range.
    """
    chol_cov = jnp.asarray(chol_cov)
    # Validate inputs with real exceptions: `assert` is stripped under `python -O`
    # and would silently skip these checks.
    if chol_cov.ndim != 2:
        raise ValueError("chol_cov must be a 2D array")
    if chol_cov.shape[0] != chol_cov.shape[1]:
        raise ValueError("chol_cov must be square")
    if start < 0 or end > chol_cov.shape[0]:
        raise ValueError("start and end must be within the bounds of chol_cov")
    if start >= end:
        raise ValueError("start must be less than end")

    # Row selection preserves the Gram-matrix block; re-triangularize to get a
    # square lower factor.
    chol_cov_select_rows = chol_cov[start:end, :]
    return tria(chol_cov_select_rows)
|