cuthbert 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuthbert/discrete/__init__.py +2 -0
- cuthbert/discrete/filter.py +140 -0
- cuthbert/discrete/smoother.py +123 -0
- cuthbert/discrete/types.py +53 -0
- cuthbert/gaussian/__init__.py +0 -0
- cuthbert/gaussian/kalman.py +337 -0
- cuthbert/gaussian/moments/__init__.py +11 -0
- cuthbert/gaussian/moments/associative_filter.py +180 -0
- cuthbert/gaussian/moments/filter.py +95 -0
- cuthbert/gaussian/moments/non_associative_filter.py +161 -0
- cuthbert/gaussian/moments/smoother.py +118 -0
- cuthbert/gaussian/moments/types.py +51 -0
- cuthbert/gaussian/taylor/__init__.py +14 -0
- cuthbert/gaussian/taylor/associative_filter.py +222 -0
- cuthbert/gaussian/taylor/filter.py +129 -0
- cuthbert/gaussian/taylor/non_associative_filter.py +246 -0
- cuthbert/gaussian/taylor/smoother.py +158 -0
- cuthbert/gaussian/taylor/types.py +86 -0
- cuthbert/gaussian/types.py +57 -0
- cuthbert/gaussian/utils.py +41 -0
- cuthbert/smc/__init__.py +0 -0
- cuthbert/smc/backward_sampler.py +193 -0
- cuthbert/smc/marginal_particle_filter.py +237 -0
- cuthbert/smc/particle_filter.py +234 -0
- cuthbert/smc/types.py +67 -0
- {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/METADATA +2 -2
- cuthbert-0.0.3.dist-info/RECORD +76 -0
- {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/licenses/LICENSE +1 -1
- cuthbertlib/discrete/__init__.py +0 -0
- cuthbertlib/discrete/filtering.py +49 -0
- cuthbertlib/discrete/smoothing.py +35 -0
- cuthbertlib/kalman/__init__.py +4 -0
- cuthbertlib/kalman/filtering.py +213 -0
- cuthbertlib/kalman/generate.py +85 -0
- cuthbertlib/kalman/sampling.py +68 -0
- cuthbertlib/kalman/smoothing.py +121 -0
- cuthbertlib/linalg/__init__.py +7 -0
- cuthbertlib/linalg/collect_nans_chol.py +90 -0
- cuthbertlib/linalg/marginal_sqrt_cov.py +34 -0
- cuthbertlib/linalg/symmetric_inv_sqrt.py +126 -0
- cuthbertlib/linalg/tria.py +21 -0
- cuthbertlib/linearize/__init__.py +7 -0
- cuthbertlib/linearize/log_density.py +175 -0
- cuthbertlib/linearize/moments.py +94 -0
- cuthbertlib/linearize/taylor.py +83 -0
- cuthbertlib/quadrature/__init__.py +4 -0
- cuthbertlib/quadrature/common.py +102 -0
- cuthbertlib/quadrature/cubature.py +73 -0
- cuthbertlib/quadrature/gauss_hermite.py +62 -0
- cuthbertlib/quadrature/linearize.py +143 -0
- cuthbertlib/quadrature/unscented.py +79 -0
- cuthbertlib/quadrature/utils.py +109 -0
- cuthbertlib/resampling/__init__.py +3 -0
- cuthbertlib/resampling/killing.py +79 -0
- cuthbertlib/resampling/multinomial.py +53 -0
- cuthbertlib/resampling/protocols.py +92 -0
- cuthbertlib/resampling/systematic.py +78 -0
- cuthbertlib/resampling/utils.py +82 -0
- cuthbertlib/smc/__init__.py +0 -0
- cuthbertlib/smc/ess.py +24 -0
- cuthbertlib/smc/smoothing/__init__.py +0 -0
- cuthbertlib/smc/smoothing/exact_sampling.py +111 -0
- cuthbertlib/smc/smoothing/mcmc.py +76 -0
- cuthbertlib/smc/smoothing/protocols.py +44 -0
- cuthbertlib/smc/smoothing/tracing.py +45 -0
- cuthbertlib/stats/__init__.py +0 -0
- cuthbertlib/stats/multivariate_normal.py +102 -0
- cuthbert-0.0.1.dist-info/RECORD +0 -12
- {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/WHEEL +0 -0
- {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
"""Implements the associative linearized Taylor Kalman filter."""
|
|
2
|
+
|
|
3
|
+
from jax import eval_shape, tree
|
|
4
|
+
from jax import numpy as jnp
|
|
5
|
+
|
|
6
|
+
from cuthbert.gaussian.taylor.non_associative_filter import process_observation
|
|
7
|
+
from cuthbert.gaussian.taylor.types import (
|
|
8
|
+
GetDynamicsLogDensity,
|
|
9
|
+
GetInitLogDensity,
|
|
10
|
+
GetObservationFunc,
|
|
11
|
+
)
|
|
12
|
+
from cuthbert.gaussian.types import (
|
|
13
|
+
LinearizedKalmanFilterState,
|
|
14
|
+
)
|
|
15
|
+
from cuthbert.utils import dummy_tree_like
|
|
16
|
+
from cuthbertlib.kalman import filtering
|
|
17
|
+
from cuthbertlib.linearize import linearize_log_density
|
|
18
|
+
from cuthbertlib.types import (
|
|
19
|
+
ArrayTreeLike,
|
|
20
|
+
KeyArray,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def init_prepare(
    model_inputs: ArrayTreeLike,
    get_init_log_density: GetInitLogDensity,
    get_observation_func: GetObservationFunc,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
    key: KeyArray | None = None,
) -> LinearizedKalmanFilterState:
    """Build the time-0 state of the associative linearized Taylor Kalman filter.

    The initial log density is linearized into a Gaussian prior, which is then
    conditioned on the time-0 observation (or log potential).

    Args:
        model_inputs: Model inputs for time 0.
        get_init_log_density: Maps `model_inputs` to the initial log density
            log p(x_0) and the point to linearize it around.
        get_observation_func: Maps `(state, model_inputs)` to either
            - a conditional log density log p(y_0 | x_0) together with the
              points x_0 and y_0 to linearize around, or
            - a log potential log G(x_0) together with a linearization point x_0.
        rtol: Relative cutoff for the singular values of precision matrices passed
            to `symmetric_inv_sqrt` during linearization; singular values below
            `rtol * largest_singular_value` are treated as zero. The default
            depends on the floating point precision of the dtype. See
            https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: If True, dimensions whose precision-matrix diagonal
            (found via linearization) is NaN are treated as missing, and all
            their rows and columns are ignored.
        key: JAX random key; unused, accepted for interface compatibility.

    Returns:
        State for the linearized Taylor Kalman filter. Its scan element stores the
        filtered mean (`b`), a generalised Cholesky factor of the covariance (`U`)
        and the log normalizing constant (`ell`).
    """
    inputs = tree.map(jnp.asarray, model_inputs)
    log_p0, x0_lin = get_init_log_density(inputs)

    # The prior has no conditioning variable, so the first lambda argument is
    # ignored and the same point is supplied for both linearization slots.
    _, prior_mean, prior_chol = linearize_log_density(
        lambda _, x: log_p0(x),
        x0_lin,
        x0_lin,
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
    )

    def _marginal_state(mean, chol_cov, log_norm):
        # Scan element with A, eta and Z zeroed; only b, U and ell carry values.
        element = filtering.FilterScanElement(
            A=jnp.zeros_like(chol_cov),
            b=mean,
            U=chol_cov,
            eta=jnp.zeros_like(mean),
            Z=jnp.zeros_like(chol_cov),
            ell=log_norm,
        )
        return LinearizedKalmanFilterState(
            elem=element,
            model_inputs=inputs,
            mean_prev=dummy_tree_like(mean),
        )

    # The prior state is only used to let the observation callback pick its
    # linearization point.
    prior_state = _marginal_state(prior_mean, prior_chol, jnp.array(0.0))
    H, d, chol_R, y = process_observation(
        get_observation_func(prior_state, inputs),
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
    )

    (post_mean, post_chol), log_norm = filtering.update(
        prior_mean, prior_chol, H, d, chol_R, y
    )
    return _marginal_state(post_mean, post_chol, log_norm)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def filter_prepare(
    model_inputs: ArrayTreeLike,
    get_init_log_density: GetInitLogDensity,
    get_dynamics_log_density: GetDynamicsLogDensity,
    get_observation_func: GetObservationFunc,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
    key: KeyArray | None = None,
) -> LinearizedKalmanFilterState:
    """Prepare a state for a linearized Taylor Kalman filter step.

    `associative_scan` is supported but only accurate when `state` is ignored
    in `get_dynamics_log_density` and `get_observation_func`.

    The dynamics and observation models are linearized here, and the results are
    packed into an associative scan element via
    `filtering.associative_params_single`. The user callbacks receive a
    placeholder state, because the true filter state is not available at
    preparation time.

    Args:
        model_inputs: Model inputs.
        get_init_log_density: Function that returns log density log p(x_0)
            and linearization point. Only used to infer shape of mean and chol_cov.
        get_dynamics_log_density: Function to get dynamics log density log p(x_t+1 | x_t)
            and linearization points (for the previous and current time points)
            `associative_scan` only supported when `state` is ignored.
        get_observation_func: Function to get observation function (either conditional
            log density or log potential), linearization point and optional observation
            (not required for log potential functions).
            `associative_scan` only supported when `state` is ignored.
        rtol: The relative tolerance for the singular values of precision matrices
            when passed to `symmetric_inv_sqrt` during linearization.
            Cutoff for small singular values; singular values smaller than
            `rtol * largest_singular_value` are treated as zero.
            The default is determined based on the floating point precision of the dtype.
            See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: Whether to treat dimensions with NaN on the diagonal of the
            precision matrices (found via linearization) as missing and ignore all rows
            and columns associated with them.
        key: JAX random key - not used.

    Returns:
        Prepared state for linearized Taylor Kalman filter.
    """
    model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)

    # Only the shape/dtype of the initial linearization point is needed, so
    # get_init_log_density is traced abstractly rather than evaluated.
    dummy_mean_struct = eval_shape(lambda mi: get_init_log_density(mi)[1], model_inputs)
    dummy_mean = dummy_tree_like(dummy_mean_struct)
    # Placeholder with trailing shape (n, n) and the mean's dtype. An outer
    # product is used instead of `jnp.cov`, which squeezes its output and would
    # give shape () rather than (1, 1) for a one-dimensional state.
    dummy_chol_cov = dummy_tree_like(
        dummy_mean[..., :, None] * dummy_mean[..., None, :]
    )

    # Placeholder state passed to the user callbacks (expected to be ignored
    # when running with associative_scan).
    dummy_state = LinearizedKalmanFilterState(
        elem=filtering.FilterScanElement(
            A=dummy_chol_cov,
            b=dummy_mean,
            U=dummy_chol_cov,
            eta=dummy_mean,
            Z=dummy_chol_cov,
            ell=jnp.array(0.0),
        ),
        model_inputs=model_inputs,
        mean_prev=dummy_mean,
    )

    log_dynamics_density, linearization_point_prev, linearization_point_curr = (
        get_dynamics_log_density(dummy_state, model_inputs)
    )

    F, c, chol_Q = linearize_log_density(
        log_dynamics_density,
        linearization_point_prev,
        linearization_point_curr,
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
    )

    observation_output = get_observation_func(dummy_state, model_inputs)
    H, d, chol_R, observation = process_observation(
        observation_output,
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
    )

    elem = filtering.associative_params_single(F, c, chol_Q, H, d, chol_R, observation)

    return LinearizedKalmanFilterState(
        elem=elem,
        model_inputs=model_inputs,
        mean_prev=dummy_mean,
    )
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def filter_combine(
    state_1: LinearizedKalmanFilterState,
    state_2: LinearizedKalmanFilterState,
) -> LinearizedKalmanFilterState:
    """Merge an earlier filter state with a state prepared from newer model inputs.

    `associative_scan` is supported but only accurate when `state` is ignored
    in `get_dynamics_log_density` and `get_observation_func`.

    Because `filter_prepare` has already linearized the dynamics and observation
    models, combining reduces to the standard associative Kalman filtering
    operator applied to the two scan elements.

    Args:
        state_1: State from the earlier time step(s).
        state_2: Prepared state; its `elem` and `model_inputs` are used.

    Returns:
        Predicted and updated linearized Taylor Kalman filter state.
        Contains mean, chol_cov (generalised Cholesky factor of covariance)
        and log_normalizing_constant.
    """
    return LinearizedKalmanFilterState(
        elem=filtering.filtering_operator(state_1.elem, state_2.elem),
        model_inputs=state_2.model_inputs,
        mean_prev=state_1.mean,
    )
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
r"""Linearized Taylor Kalman filter.
|
|
2
|
+
|
|
3
|
+
Uses automatic differentiation to extract conditionally Gaussian parameters from log
|
|
4
|
+
densities of the dynamics and observation distributions.
|
|
5
|
+
|
|
6
|
+
This differs from `gaussian/moments`, which requires `mean` and `chol_cov`
|
|
7
|
+
functions as input rather than log densities.
|
|
8
|
+
|
|
9
|
+
I.e., we approximate conditional densities as
|
|
10
|
+
|
|
11
|
+
$$
|
|
12
|
+
p(y \mid x) \approx N(y \mid H x + d, L L^T),
|
|
13
|
+
$$
|
|
14
|
+
|
|
15
|
+
and potentials as
|
|
16
|
+
|
|
17
|
+
$$
|
|
18
|
+
G(x) \approx N(x \mid m, L L^T),
|
|
19
|
+
$$
|
|
20
|
+
|
|
21
|
+
where $L$ is the Cholesky factor of the covariance matrix.
|
|
22
|
+
|
|
23
|
+
See `cuthbertlib.linearize` for more details.
|
|
24
|
+
|
|
25
|
+
Parallelism via `associative_scan` is supported, but requires the `state` argument
|
|
26
|
+
to be ignored in `get_dynamics_log_density` and `get_observation_func`.
|
|
27
|
+
I.e. the linearization points are pre-defined or extracted from model inputs.
|
|
28
|
+
"""
|
|
29
|
+
|
|
30
|
+
from functools import partial
|
|
31
|
+
|
|
32
|
+
from cuthbert.gaussian.taylor import associative_filter, non_associative_filter
|
|
33
|
+
from cuthbert.gaussian.taylor.types import (
|
|
34
|
+
GetDynamicsLogDensity,
|
|
35
|
+
GetInitLogDensity,
|
|
36
|
+
GetObservationFunc,
|
|
37
|
+
)
|
|
38
|
+
from cuthbert.inference import Filter
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def build_filter(
    get_init_log_density: GetInitLogDensity,
    get_dynamics_log_density: GetDynamicsLogDensity,
    get_observation_func: GetObservationFunc,
    associative: bool = False,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
) -> Filter:
    """Assemble a linearized Taylor Kalman `Filter` from user-supplied log densities.

    With `associative=True`, every filtering linearization point must be fixed in
    advance or derived from the model inputs. The `state` argument of
    `get_dynamics_log_density` and `get_observation_func` should then be ignored.

    With `associative=False`, linearization points may come from the previous
    filter state (for the dynamics) and from the predicted state (for the
    observation).

    Args:
        get_init_log_density: Returns log p(x_0) and a linearization point;
            receives only `model_inputs`.
        get_dynamics_log_density: Returns log p(x_t+1 | x_t) and linearization
            points for the previous and current time points. Its `state`
            argument should be ignored when `associative` is True.
        get_observation_func: Returns an observation function (conditional log
            density or log potential), a linearization point and, for
            conditional densities only, the observation. Its `state` argument
            should be ignored when `associative` is True.
        associative: True builds a filter usable with an associative scan,
            assuming `state` is ignored by the two callbacks above. False builds
            a sequential filter that may use `state` for linearization.
        rtol: Relative cutoff for the singular values of precision matrices passed
            to `symmetric_inv_sqrt` during linearization; singular values below
            `rtol * largest_singular_value` are treated as zero. The default
            depends on the floating point precision of the dtype. See
            https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: If True, dimensions whose precision-matrix diagonal
            (found via linearization) is NaN are treated as missing, and all
            their rows and columns are ignored.

    Returns:
        Linearized Taylor Kalman filter object.
    """
    # Keyword arguments shared by every stage that performs linearization.
    linearization_opts = {"rtol": rtol, "ignore_nan_dims": ignore_nan_dims}

    if not associative:
        return Filter(
            init_prepare=partial(
                non_associative_filter.init_prepare,
                get_init_log_density=get_init_log_density,
                get_observation_func=get_observation_func,
                **linearization_opts,
            ),
            # Sequential preparation only needs the initial density for shapes.
            filter_prepare=partial(
                non_associative_filter.filter_prepare,
                get_init_log_density=get_init_log_density,
            ),
            filter_combine=partial(
                non_associative_filter.filter_combine,
                get_dynamics_log_density=get_dynamics_log_density,
                get_observation_func=get_observation_func,
                **linearization_opts,
            ),
            associative=False,
        )

    return Filter(
        init_prepare=partial(
            associative_filter.init_prepare,
            get_init_log_density=get_init_log_density,
            get_observation_func=get_observation_func,
            **linearization_opts,
        ),
        filter_prepare=partial(
            associative_filter.filter_prepare,
            get_init_log_density=get_init_log_density,
            get_dynamics_log_density=get_dynamics_log_density,
            get_observation_func=get_observation_func,
            **linearization_opts,
        ),
        filter_combine=associative_filter.filter_combine,
        associative=True,
    )
|
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
"""Implements the non-associative linearized Taylor Kalman filter."""
|
|
2
|
+
|
|
3
|
+
from jax import eval_shape, tree
|
|
4
|
+
from jax import numpy as jnp
|
|
5
|
+
|
|
6
|
+
from cuthbert.gaussian.taylor.types import (
|
|
7
|
+
GetDynamicsLogDensity,
|
|
8
|
+
GetInitLogDensity,
|
|
9
|
+
GetObservationFunc,
|
|
10
|
+
LogConditionalDensity,
|
|
11
|
+
LogPotential,
|
|
12
|
+
)
|
|
13
|
+
from cuthbert.gaussian.types import LinearizedKalmanFilterState
|
|
14
|
+
from cuthbert.gaussian.utils import linearized_kalman_filter_state_dummy_elem
|
|
15
|
+
from cuthbert.utils import dummy_tree_like
|
|
16
|
+
from cuthbertlib.kalman import filtering
|
|
17
|
+
from cuthbertlib.linearize import linearize_log_density, linearize_taylor
|
|
18
|
+
from cuthbertlib.types import (
|
|
19
|
+
Array,
|
|
20
|
+
ArrayTreeLike,
|
|
21
|
+
KeyArray,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def process_observation(
    observation_output: tuple[LogConditionalDensity, Array, Array]
    | tuple[LogPotential, Array],
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
) -> tuple[Array, Array, Array, Array]:
    """Process observation for linearized Taylor Kalman filter.

    Turns the output of a user observation callback into the parameters
    `(H, d, chol_R, observation)` consumed by `filtering.update`.

    Args:
        observation_output: Either `(log_density, linearization_point, observation)`
            for a conditional log density log p(y | x), or
            `(log_potential, linearization_point)` for a log potential log G(x).
        rtol: Relative singular-value cutoff forwarded to the linearization
            routines (see `symmetric_inv_sqrt`).
        ignore_nan_dims: Whether dimensions with NaN on the diagonal of the
            linearized precision are treated as missing.

    Returns:
        Tuple `(H, d, chol_R, observation)`. For a log potential, `H = -I`, and
        `observation` is zero except for NaN in dimensions flagged as missing
        (only when `ignore_nan_dims` is True). Both `H` and `observation` share
        the dtype of `d`.
    """
    if len(observation_output) == 3:
        observation_cond_log_density, linearization_point, observation = (
            observation_output
        )
        H, d, chol_R = linearize_log_density(
            observation_cond_log_density,
            linearization_point,
            observation,
            rtol=rtol,
            ignore_nan_dims=ignore_nan_dims,
        )
    else:
        observation_log_potential, linearization_point = observation_output
        d, chol_R = linearize_taylor(
            observation_log_potential,
            linearization_point,
            rtol=rtol,
            ignore_nan_dims=ignore_nan_dims,
        )
        # dummy mat and observation as potential is unconditional
        # Note the minus sign as linear potential is -0.5 (x - d)^T (R R^T)^{-1} (x - d)
        # and kalman expects -0.5 (y - H @ x - d)^T (R R^T)^{-1} (y - H @ x - d)
        # dtype follows d so that, e.g. under jax_enable_x64, a float32 model is
        # not silently promoted to float64 (which would change scan carry dtypes).
        H = -jnp.eye(d.shape[0], dtype=d.dtype)
        missing = jnp.logical_and(jnp.isnan(jnp.diag(chol_R)), ignore_nan_dims)
        # NaN entries tell cuthbertlib.kalman to skip these dimensions.
        observation = jnp.where(missing, jnp.nan, jnp.zeros_like(d))
    return H, d, chol_R, observation
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def init_prepare(
    model_inputs: ArrayTreeLike,
    get_init_log_density: GetInitLogDensity,
    get_observation_func: GetObservationFunc,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
    key: KeyArray | None = None,
) -> LinearizedKalmanFilterState:
    """Build the time-0 state of the sequential linearized Taylor Kalman filter.

    The initial log density is linearized into a Gaussian prior, which is then
    conditioned on the time-0 observation (or log potential).

    Args:
        model_inputs: Model inputs for time 0.
        get_init_log_density: Maps `model_inputs` to the initial log density
            log p(x_0) and the point to linearize it around.
        get_observation_func: Maps `(state, model_inputs)` to either
            - a conditional log density log p(y_0 | x_0) together with the
              points x_0 and y_0 to linearize around, or
            - a log potential log G(x_0) together with a linearization point x_0.
        rtol: Relative cutoff for the singular values of precision matrices passed
            to `symmetric_inv_sqrt` during linearization; singular values below
            `rtol * largest_singular_value` are treated as zero. The default
            depends on the floating point precision of the dtype. See
            https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: If True, dimensions whose precision-matrix diagonal
            (found via linearization) is NaN are treated as missing, and all
            their rows and columns are ignored.
        key: JAX random key; unused, accepted for interface compatibility.

    Returns:
        State for the linearized Taylor Kalman filter.
        Contains mean, chol_cov (generalised Cholesky factor of covariance)
        and log_normalizing_constant.
    """
    inputs = tree.map(jnp.asarray, model_inputs)
    log_p0, x0_lin = get_init_log_density(inputs)

    # The prior has no conditioning variable, so the first lambda argument is
    # ignored and the same point is supplied for both linearization slots.
    _, prior_mean, prior_chol = linearize_log_density(
        lambda _, x: log_p0(x),
        x0_lin,
        x0_lin,
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
    )

    def _state(mean, chol_cov, log_norm):
        # Wrap a Gaussian (mean, chol_cov) in a filter state at time 0.
        return linearized_kalman_filter_state_dummy_elem(
            mean=mean,
            chol_cov=chol_cov,
            log_normalizing_constant=log_norm,
            model_inputs=inputs,
            mean_prev=dummy_tree_like(mean),
        )

    # The prior state is only used to let the observation callback pick its
    # linearization point.
    prior_state = _state(prior_mean, prior_chol, jnp.array(0.0))
    H, d, chol_R, y = process_observation(
        get_observation_func(prior_state, inputs),
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
    )

    (post_mean, post_chol), log_norm = filtering.update(
        prior_mean, prior_chol, H, d, chol_R, y
    )
    return _state(post_mean, post_chol, log_norm)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def filter_prepare(
    model_inputs: ArrayTreeLike,
    get_init_log_density: GetInitLogDensity,
    key: KeyArray | None = None,
) -> LinearizedKalmanFilterState:
    """Prepare a state for a linearized Taylor Kalman filter step.

    Just passes through model inputs; every other field is a placeholder with
    the shape and dtype of the filter state.

    Args:
        model_inputs: Model inputs.
        get_init_log_density: Function that returns log density log p(x_0)
            and linearization point. Only used to infer shape of mean and chol_cov.
        key: JAX random key - not used.

    Returns:
        Prepared state for linearized Taylor Kalman filter.
    """
    model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)

    # Only the shape/dtype of the initial linearization point is needed, so
    # get_init_log_density is traced abstractly rather than evaluated.
    dummy_mean_struct = eval_shape(lambda mi: get_init_log_density(mi)[1], model_inputs)
    dummy_mean = dummy_tree_like(dummy_mean_struct)
    # Placeholder with trailing shape (n, n) and the mean's dtype. An outer
    # product is used instead of `jnp.cov`, which squeezes its output and would
    # give shape () rather than (1, 1) for a one-dimensional state.
    dummy_chol_cov = dummy_tree_like(
        dummy_mean[..., :, None] * dummy_mean[..., None, :]
    )

    return linearized_kalman_filter_state_dummy_elem(
        mean=dummy_mean,
        chol_cov=dummy_chol_cov,
        log_normalizing_constant=jnp.array(0.0),
        model_inputs=model_inputs,
        mean_prev=dummy_mean,
    )
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def filter_combine(
    state_1: LinearizedKalmanFilterState,
    state_2: LinearizedKalmanFilterState,
    get_dynamics_log_density: GetDynamicsLogDensity,
    get_observation_func: GetObservationFunc,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
) -> LinearizedKalmanFilterState:
    """Advance a filter state by one step using the next prepared model inputs.

    Linearizes the dynamics around points chosen from the previous state. It
    then runs a square-root Kalman predict, linearizes the observation around
    points chosen from the predicted state, and finishes with a square-root
    Kalman update. Because the linearization depends on the incoming state,
    this operator is not suitable for an associative scan.

    Args:
        state_1: Filter state from the previous time step.
        state_2: Prepared state; only its `model_inputs` attribute is read.
        get_dynamics_log_density: Returns log p(x_t+1 | x_t) and linearization
            points for the previous and current time points.
        get_observation_func: Returns an observation function (conditional log
            density or log potential), a linearization point and, for
            conditional densities only, the observation.
        rtol: Relative cutoff for the singular values of precision matrices passed
            to `symmetric_inv_sqrt` during linearization; singular values below
            `rtol * largest_singular_value` are treated as zero. The default
            depends on the floating point precision of the dtype. See
            https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: If True, dimensions whose precision-matrix diagonal
            (found via linearization) is NaN are treated as missing, and all
            their rows and columns are ignored.

    Returns:
        Predicted and updated linearized Taylor Kalman filter state.
        Contains mean, chol_cov (generalised Cholesky factor of covariance)
        and log_normalizing_constant, which accumulates state_1's constant.
    """
    next_inputs = state_2.model_inputs

    # Predict: linearize the transition density and propagate the Gaussian.
    dyn_log_density, x_prev_lin, x_curr_lin = get_dynamics_log_density(
        state_1, next_inputs
    )
    F, c, chol_Q = linearize_log_density(
        dyn_log_density,
        x_prev_lin,
        x_curr_lin,
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
    )
    pred_mean, pred_chol = filtering.predict(
        state_1.mean, state_1.chol_cov, F, c, chol_Q
    )

    # The predicted state lets the observation callback choose its
    # linearization point.
    pred_state = linearized_kalman_filter_state_dummy_elem(
        mean=pred_mean,
        chol_cov=pred_chol,
        log_normalizing_constant=state_1.log_normalizing_constant,
        model_inputs=next_inputs,
        mean_prev=state_1.mean,
    )

    # Update: linearize the observation model and condition on it.
    H, d, chol_R, y = process_observation(
        get_observation_func(pred_state, next_inputs),
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
    )
    (post_mean, post_chol), step_log_norm = filtering.update(
        pred_mean, pred_chol, H, d, chol_R, y
    )

    return linearized_kalman_filter_state_dummy_elem(
        mean=post_mean,
        chol_cov=post_chol,
        log_normalizing_constant=state_1.log_normalizing_constant + step_log_norm,
        model_inputs=next_inputs,
        mean_prev=state_1.mean,
    )
|