cuthbert 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuthbert/discrete/__init__.py +2 -0
- cuthbert/discrete/filter.py +140 -0
- cuthbert/discrete/smoother.py +123 -0
- cuthbert/discrete/types.py +53 -0
- cuthbert/gaussian/__init__.py +0 -0
- cuthbert/gaussian/kalman.py +337 -0
- cuthbert/gaussian/moments/__init__.py +11 -0
- cuthbert/gaussian/moments/associative_filter.py +180 -0
- cuthbert/gaussian/moments/filter.py +95 -0
- cuthbert/gaussian/moments/non_associative_filter.py +161 -0
- cuthbert/gaussian/moments/smoother.py +118 -0
- cuthbert/gaussian/moments/types.py +51 -0
- cuthbert/gaussian/taylor/__init__.py +14 -0
- cuthbert/gaussian/taylor/associative_filter.py +222 -0
- cuthbert/gaussian/taylor/filter.py +129 -0
- cuthbert/gaussian/taylor/non_associative_filter.py +246 -0
- cuthbert/gaussian/taylor/smoother.py +158 -0
- cuthbert/gaussian/taylor/types.py +86 -0
- cuthbert/gaussian/types.py +57 -0
- cuthbert/gaussian/utils.py +41 -0
- cuthbert/smc/__init__.py +0 -0
- cuthbert/smc/backward_sampler.py +193 -0
- cuthbert/smc/marginal_particle_filter.py +237 -0
- cuthbert/smc/particle_filter.py +234 -0
- cuthbert/smc/types.py +67 -0
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/METADATA +1 -1
- cuthbert-0.0.3.dist-info/RECORD +76 -0
- cuthbertlib/discrete/__init__.py +0 -0
- cuthbertlib/discrete/filtering.py +49 -0
- cuthbertlib/discrete/smoothing.py +35 -0
- cuthbertlib/kalman/__init__.py +4 -0
- cuthbertlib/kalman/filtering.py +213 -0
- cuthbertlib/kalman/generate.py +85 -0
- cuthbertlib/kalman/sampling.py +68 -0
- cuthbertlib/kalman/smoothing.py +121 -0
- cuthbertlib/linalg/__init__.py +7 -0
- cuthbertlib/linalg/collect_nans_chol.py +90 -0
- cuthbertlib/linalg/marginal_sqrt_cov.py +34 -0
- cuthbertlib/linalg/symmetric_inv_sqrt.py +126 -0
- cuthbertlib/linalg/tria.py +21 -0
- cuthbertlib/linearize/__init__.py +7 -0
- cuthbertlib/linearize/log_density.py +175 -0
- cuthbertlib/linearize/moments.py +94 -0
- cuthbertlib/linearize/taylor.py +83 -0
- cuthbertlib/quadrature/__init__.py +4 -0
- cuthbertlib/quadrature/common.py +102 -0
- cuthbertlib/quadrature/cubature.py +73 -0
- cuthbertlib/quadrature/gauss_hermite.py +62 -0
- cuthbertlib/quadrature/linearize.py +143 -0
- cuthbertlib/quadrature/unscented.py +79 -0
- cuthbertlib/quadrature/utils.py +109 -0
- cuthbertlib/resampling/__init__.py +3 -0
- cuthbertlib/resampling/killing.py +79 -0
- cuthbertlib/resampling/multinomial.py +53 -0
- cuthbertlib/resampling/protocols.py +92 -0
- cuthbertlib/resampling/systematic.py +78 -0
- cuthbertlib/resampling/utils.py +82 -0
- cuthbertlib/smc/__init__.py +0 -0
- cuthbertlib/smc/ess.py +24 -0
- cuthbertlib/smc/smoothing/__init__.py +0 -0
- cuthbertlib/smc/smoothing/exact_sampling.py +111 -0
- cuthbertlib/smc/smoothing/mcmc.py +76 -0
- cuthbertlib/smc/smoothing/protocols.py +44 -0
- cuthbertlib/smc/smoothing/tracing.py +45 -0
- cuthbertlib/stats/__init__.py +0 -0
- cuthbertlib/stats/multivariate_normal.py +102 -0
- cuthbert-0.0.2.dist-info/RECORD +0 -12
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/WHEEL +0 -0
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/licenses/LICENSE +0 -0
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/top_level.txt +0 -0
cuthbert/gaussian/moments/associative_filter.py
@@ -0,0 +1,180 @@
+"""Implements the associative linearized moments Kalman filter."""
+
+from jax import eval_shape, tree
+from jax import numpy as jnp
+
+from cuthbert.gaussian.kalman import GetInitParams
+from cuthbert.gaussian.moments.types import GetDynamicsMoments, GetObservationMoments
+from cuthbert.gaussian.types import (
+    LinearizedKalmanFilterState,
+)
+from cuthbert.utils import dummy_tree_like
+from cuthbertlib.kalman import filtering
+from cuthbertlib.linearize import linearize_moments
+from cuthbertlib.types import (
+    ArrayTreeLike,
+    KeyArray,
+)
+
+
+def init_prepare(
+    model_inputs: ArrayTreeLike,
+    get_init_params: GetInitParams,
+    get_observation_params: GetObservationMoments,
+    key: KeyArray | None = None,
+) -> LinearizedKalmanFilterState:
+    """Prepare the initial state for the linearized moments Kalman filter.
+
+    Args:
+        model_inputs: Model inputs.
+        get_init_params: Function to get m0, chol_P0 from model inputs.
+        get_observation_params: Function to get observation conditional mean,
+            (generalised) Cholesky covariance function, linearization point and
+            observation.
+        key: JAX random key - not used.
+
+    Returns:
+        State for the linearized moments Kalman filter.
+        Contains mean, chol_cov (generalised Cholesky factor of covariance)
+        and log_normalizing_constant.
+    """
+    model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)
+    m0, chol_P0 = get_init_params(model_inputs)
+
+    prior_state = LinearizedKalmanFilterState(
+        elem=filtering.FilterScanElement(
+            A=jnp.zeros_like(chol_P0),
+            b=m0,
+            U=chol_P0,
+            eta=jnp.zeros_like(m0),
+            Z=jnp.zeros_like(chol_P0),
+            ell=jnp.array(0.0),
+        ),
+        model_inputs=model_inputs,
+        mean_prev=dummy_tree_like(m0),
+    )
+
+    mean_and_chol_cov_func, linearization_point, y = get_observation_params(
+        prior_state, model_inputs
+    )
+
+    H, d, chol_R = linearize_moments(mean_and_chol_cov_func, linearization_point)
+    (m, chol_P), ell = filtering.update(m0, chol_P0, H, d, chol_R, y)
+
+    elem = filtering.FilterScanElement(
+        A=jnp.zeros_like(chol_P),
+        b=m,
+        U=chol_P,
+        eta=jnp.zeros_like(m),
+        Z=jnp.zeros_like(chol_P),
+        ell=ell,
+    )
+
+    return LinearizedKalmanFilterState(
+        elem=elem,
+        model_inputs=model_inputs,
+        mean_prev=dummy_tree_like(m),
+    )
+
+
+def filter_prepare(
+    model_inputs: ArrayTreeLike,
+    get_init_params: GetInitParams,
+    get_dynamics_params: GetDynamicsMoments,
+    get_observation_params: GetObservationMoments,
+    key: KeyArray | None = None,
+) -> LinearizedKalmanFilterState:
+    """Prepare a state for a linearized moments Kalman filter step.
+
+    Just passes through model inputs.
+
+    `associative_scan` is supported but only accurate when `state` is ignored
+    in `get_dynamics_params` and `get_observation_params`.
+
+    Args:
+        model_inputs: Model inputs.
+        get_init_params: Function to get m0, chol_P0 from model inputs.
+            Only used to infer shape of mean and chol_cov.
+        get_dynamics_params: Function to get dynamics conditional mean and
+            (generalised) Cholesky covariance from linearization point and model inputs.
+            `associative_scan` only supported when `state` is ignored.
+        get_observation_params: Function to get observation conditional mean,
+            (generalised) Cholesky covariance and observation from linearization point
+            and model inputs.
+            `associative_scan` only supported when `state` is ignored.
+        key: JAX random key - not used.
+
+    Returns:
+        Prepared state for linearized moments Kalman filter.
+    """
+    model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)
+    dummy_mean_struct = eval_shape(lambda mi: get_init_params(mi)[0], model_inputs)
+    dummy_mean = dummy_tree_like(dummy_mean_struct)
+    dummy_chol_cov = dummy_tree_like(jnp.cov(dummy_mean[..., None]))
+
+    dummy_state = LinearizedKalmanFilterState(
+        elem=filtering.FilterScanElement(
+            A=dummy_chol_cov,
+            b=dummy_mean,
+            U=dummy_chol_cov,
+            eta=dummy_mean,
+            Z=dummy_chol_cov,
+            ell=jnp.array(0.0),
+        ),
+        model_inputs=model_inputs,
+        mean_prev=dummy_mean,
+    )
+
+    dynamics_mean_and_chol_cov_func, dynamics_linearization_point = get_dynamics_params(
+        dummy_state, model_inputs
+    )
+    F, c, chol_Q = linearize_moments(
+        dynamics_mean_and_chol_cov_func, dynamics_linearization_point
+    )
+
+    observation_mean_and_chol_cov_func, observation_linearization_point, y = (
+        get_observation_params(dummy_state, model_inputs)
+    )
+    H, d, chol_R = linearize_moments(
+        observation_mean_and_chol_cov_func, observation_linearization_point
+    )
+
+    elem = filtering.associative_params_single(F, c, chol_Q, H, d, chol_R, y)
+
+    return LinearizedKalmanFilterState(
+        elem=elem,
+        model_inputs=model_inputs,
+        mean_prev=dummy_mean,
+    )
+
+
+def filter_combine(
+    state_1: LinearizedKalmanFilterState,
+    state_2: LinearizedKalmanFilterState,
+) -> LinearizedKalmanFilterState:
+    """Combine previous filter state with state prepared with latest model inputs.
+
+    `associative_scan` is supported but only accurate when `state` is ignored
+    in `get_dynamics_params` and `get_observation_params`.
+
+    Applies standard associative Kalman filtering operator since dynamics and observation
+    parameters are extracted in filter_prepare.
+
+    Args:
+        state_1: State from previous time step.
+        state_2: State prepared (only access model_inputs attribute).
+
+    Returns:
+        Predicted and updated linearized moments Kalman filter state.
+        Contains mean, chol_cov (generalised Cholesky factor of covariance)
+        and log_normalizing_constant.
+    """
+    combined_elem = filtering.filtering_operator(
+        state_1.elem,
+        state_2.elem,
+    )
+    return LinearizedKalmanFilterState(
+        elem=combined_elem,
+        model_inputs=state_2.model_inputs,
+        mean_prev=state_1.mean,
+    )
cuthbert/gaussian/moments/filter.py
@@ -0,0 +1,95 @@
+r"""Linearized moments Kalman filter.
+
+Takes a user provided conditional `mean` and `chol_cov` functions to define a
+conditionally linear Gaussian state space model.
+
+I.e., we approximate conditional densities as
+
+$$
+p(y \mid x) \approx N(y \mid \mathrm{mean}(x), \mathrm{chol\_cov}(x) @ \mathrm{chol\_cov}(x)^\top).
+$$
+
+See `cuthbertlib.linearize` for more details.
+
+Parallelism via `associative_scan` is supported, but requires the `state` argument
+to be ignored in `get_dynamics_params` and `get_observation_params`.
+I.e. the linearization points are pre-defined or extracted from model inputs.
+"""
+
+from functools import partial
+
+from cuthbert.gaussian.moments import associative_filter, non_associative_filter
+from cuthbert.gaussian.moments.types import GetDynamicsMoments, GetObservationMoments
+from cuthbert.gaussian.types import GetInitParams
+from cuthbert.inference import Filter
+
+
+def build_filter(
+    get_init_params: GetInitParams,
+    get_dynamics_params: GetDynamicsMoments,
+    get_observation_params: GetObservationMoments,
+    associative: bool = False,
+) -> Filter:
+    """Build linearized moments Kalman inference filter.
+
+    If `associative` is True all filtering linearization points are pre-defined or
+    extracted from model inputs. The `state` argument should be ignored in
+    `get_dynamics_params` and `get_observation_params`.
+
+    If `associative` is False the linearization points can be extracted from the
+    previous filter state for dynamics parameters and the predict state for
+    observation parameters.
+
+    Args:
+        get_init_params: Function to get m0, chol_P0 from model inputs.
+        get_dynamics_params: Function to get dynamics conditional mean and
+            (generalised) Cholesky covariance from linearization point and model inputs.
+            and linearization points (for the previous and current time points)
+            If `associative` is True, the `state` argument should be ignored.
+        get_observation_params: Function to get observation conditional mean,
+            (generalised) Cholesky covariance and observation from linearization point
+            and model inputs.
+            If `associative` is True, the `state` argument should be ignored.
+        associative: If True, then the filter is suitable for associative scan, but
+            assumes that the `state` is ignored in `get_dynamics_params` and
+            `get_observation_params`.
+            If False, then the filter is suitable for non-associative scan, but
+            the user is free to use the `state` to extract the linearization points.
+
+    Returns:
+        Linearized moments Kalman filter object.
+    """
+    if associative:
+        return Filter(
+            init_prepare=partial(
+                associative_filter.init_prepare,
+                get_init_params=get_init_params,
+                get_observation_params=get_observation_params,
+            ),
+            filter_prepare=partial(
+                associative_filter.filter_prepare,
+                get_init_params=get_init_params,
+                get_dynamics_params=get_dynamics_params,
+                get_observation_params=get_observation_params,
+            ),
+            filter_combine=associative_filter.filter_combine,
+            associative=True,
+        )
+    else:
+        return Filter(
+            init_prepare=partial(
+                non_associative_filter.init_prepare,
+                get_init_params=get_init_params,
+                get_observation_params=get_observation_params,
+            ),
+            filter_prepare=partial(
+                non_associative_filter.filter_prepare,
+                get_init_params=get_init_params,
+            ),
+            filter_combine=partial(
+                non_associative_filter.filter_combine,
+                get_dynamics_params=get_dynamics_params,
+                get_observation_params=get_observation_params,
+            ),
+            associative=False,
+        )
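The `build_filter` factory above only wires prepare/combine callables into a `Filter`; a rough usage sketch for the non-associative path is given below. Everything in the sketch is an assumption for illustration: the toy random-walk model, the dict-shaped `model_inputs`, the convention that a `MeanAndCholCovFunc` maps `x` to a `(mean, chol_cov)` pair, and the manual sequential driving loop (this diff does not show cuthbert's own inference driver).

# Illustrative sketch only -- not code from the cuthbert package.
import jax.numpy as jnp

from cuthbert.gaussian.moments.filter import build_filter

dim = 2


def get_init_params(model_inputs):
    # m0 and (generalised) Cholesky factor of P0.
    return jnp.zeros(dim), jnp.eye(dim)


def get_dynamics_params(state, model_inputs):
    # p(x_t | x_{t-1}) as a random walk; `state` is ignored here, so this choice
    # would also be compatible with associative=True.
    def mean_and_chol_cov(x):
        return x, 0.1 * jnp.eye(dim)

    return mean_and_chol_cov, jnp.zeros(dim)  # linearization point


def get_observation_params(state, model_inputs):
    # p(y_t | x_t) observes the state directly with noise scale 0.5.
    def mean_and_chol_cov(x):
        return x, 0.5 * jnp.eye(dim)

    return mean_and_chol_cov, jnp.zeros(dim), model_inputs["y"]


filt = build_filter(
    get_init_params, get_dynamics_params, get_observation_params, associative=False
)

# Hypothetical sequential driver: initialise with the first observation, then
# prepare a state from each new model input and combine it with the running state.
ys = jnp.ones((5, dim))
state = filt.init_prepare({"y": ys[0]})
for y in ys[1:]:
    state = filt.filter_combine(state, filt.filter_prepare({"y": y}))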
cuthbert/gaussian/moments/non_associative_filter.py
@@ -0,0 +1,161 @@
+"""Implements the non-associative linearized moments Kalman filter."""
+
+from jax import eval_shape, tree
+from jax import numpy as jnp
+
+from cuthbert.gaussian.moments.types import GetDynamicsMoments, GetObservationMoments
+from cuthbert.gaussian.types import GetInitParams, LinearizedKalmanFilterState
+from cuthbert.gaussian.utils import linearized_kalman_filter_state_dummy_elem
+from cuthbert.utils import dummy_tree_like
+from cuthbertlib.kalman import filtering
+from cuthbertlib.linearize import linearize_moments
+from cuthbertlib.types import ArrayTreeLike, KeyArray
+
+
+def init_prepare(
+    model_inputs: ArrayTreeLike,
+    get_init_params: GetInitParams,
+    get_observation_params: GetObservationMoments,
+    key: KeyArray | None = None,
+) -> LinearizedKalmanFilterState:
+    """Prepare the initial state for the linearized moments Kalman filter.
+
+    Args:
+        model_inputs: Model inputs.
+        get_init_params: Function to get m0, chol_P0 from model inputs.
+        get_observation_params: Function to get observation conditional mean,
+            (generalised) Cholesky covariance function, linearization point and
+            observation.
+        key: JAX random key - not used.
+
+    Returns:
+        State for the linearized moments Kalman filter.
+        Contains mean, chol_cov (generalised Cholesky factor of covariance)
+        and log_normalizing_constant.
+    """
+    model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)
+    m0, chol_P0 = get_init_params(model_inputs)
+
+    prior_state = LinearizedKalmanFilterState(
+        elem=filtering.FilterScanElement(
+            A=jnp.zeros_like(chol_P0),
+            b=m0,
+            U=chol_P0,
+            eta=jnp.zeros_like(m0),
+            Z=jnp.zeros_like(chol_P0),
+            ell=jnp.array(0.0),
+        ),
+        model_inputs=model_inputs,
+        mean_prev=dummy_tree_like(m0),
+    )
+
+    mean_and_chol_cov_func, linearization_point, y = get_observation_params(
+        prior_state, model_inputs
+    )
+
+    H, d, chol_R = linearize_moments(mean_and_chol_cov_func, linearization_point)
+    (m, chol_P), ell = filtering.update(m0, chol_P0, H, d, chol_R, y)
+    return linearized_kalman_filter_state_dummy_elem(
+        mean=m,
+        chol_cov=chol_P,
+        log_normalizing_constant=ell,
+        model_inputs=model_inputs,
+        mean_prev=dummy_tree_like(m),
+    )
+
+
+def filter_prepare(
+    model_inputs: ArrayTreeLike,
+    get_init_params: GetInitParams,
+    key: KeyArray | None = None,
+) -> LinearizedKalmanFilterState:
+    """Prepare a state for a linearized moments Kalman filter step.
+
+    Just passes through model inputs.
+
+    Args:
+        model_inputs: Model inputs.
+        get_init_params: Function to get m0, chol_P0 from model inputs.
+            Only used to infer shape of mean and chol_cov.
+        key: JAX random key - not used.
+
+    Returns:
+        Prepared state for linearized moments Kalman filter.
+    """
+    model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)
+    dummy_mean_struct = eval_shape(lambda mi: get_init_params(mi)[0], model_inputs)
+    dummy_mean = dummy_tree_like(dummy_mean_struct)
+    dummy_chol_cov = dummy_tree_like(jnp.cov(dummy_mean[..., None]))
+
+    return linearized_kalman_filter_state_dummy_elem(
+        mean=dummy_mean,
+        chol_cov=dummy_chol_cov,
+        log_normalizing_constant=jnp.array(0.0),
+        model_inputs=model_inputs,
+        mean_prev=dummy_mean,
+    )
+
+
+def filter_combine(
+    state_1: LinearizedKalmanFilterState,
+    state_2: LinearizedKalmanFilterState,
+    get_dynamics_params: GetDynamicsMoments,
+    get_observation_params: GetObservationMoments,
+) -> LinearizedKalmanFilterState:
+    """Combine previous filter state with state prepared with latest model inputs.
+
+    Applies linearized moments Kalman predict + filter update in covariance square
+    root form.
+    Not suitable for associative scan.
+
+    Args:
+        state_1: State from previous time step.
+        state_2: State prepared (only access model_inputs attribute).
+        get_dynamics_params: Function to get dynamics conditional mean and
+            (generalised) Cholesky covariance from linearization point and model inputs.
+        get_observation_params: Function to get observation conditional mean,
+            (generalised) Cholesky covariance and observation from linearization point
+            and model inputs.
+
+    Returns:
+        Predicted and updated linearized moments Kalman filter state.
+        Contains mean, chol_cov (generalised Cholesky factor of covariance)
+        and log_normalizing_constant.
+    """
+    dynamics_mean_and_chol_cov_func, dynamics_linearization_point = get_dynamics_params(
+        state_1, state_2.model_inputs
+    )
+    F, c, chol_Q = linearize_moments(
+        dynamics_mean_and_chol_cov_func, dynamics_linearization_point
+    )
+
+    predict_mean, predict_chol_cov = filtering.predict(
+        state_1.mean, state_1.chol_cov, F, c, chol_Q
+    )
+    predict_state = linearized_kalman_filter_state_dummy_elem(
+        mean=predict_mean,
+        chol_cov=predict_chol_cov,
+        log_normalizing_constant=state_1.log_normalizing_constant,
+        model_inputs=state_2.model_inputs,
+        mean_prev=state_1.mean,
+    )
+
+    observation_mean_and_chol_cov_func, observation_linearization_point, y = (
+        get_observation_params(predict_state, state_2.model_inputs)
+    )
+    H, d, chol_R = linearize_moments(
+        observation_mean_and_chol_cov_func, observation_linearization_point
+    )
+
+    (update_mean, update_chol_cov), log_normalizing_constant = filtering.update(
+        predict_mean, predict_chol_cov, H, d, chol_R, y
+    )
+
+    return linearized_kalman_filter_state_dummy_elem(
+        mean=update_mean,
+        chol_cov=update_chol_cov,
+        log_normalizing_constant=state_1.log_normalizing_constant
+        + log_normalizing_constant,
+        model_inputs=state_2.model_inputs,
+        mean_prev=state_1.mean,
+    )
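The `filter_combine` above delegates the "covariance square root form" predict and update to `cuthbertlib.kalman.filtering`. As background (the standard square-root Kalman formulation, stated here rather than read off this diff), the predict step it refers to propagates the mean and the Cholesky factor as

$$
m^- = F\,m + c, \qquad L^- = \mathrm{tria}\!\left(\begin{bmatrix} F L & L_Q \end{bmatrix}\right),
$$

so that $L^- (L^-)^\top = F L L^\top F^\top + L_Q L_Q^\top$, where $L$ and $L_Q$ are the (generalised) Cholesky factors of the filtering covariance and of the dynamics noise, and $\mathrm{tria}$ denotes a QR-based triangularisation (this release adds `cuthbertlib/linalg/tria.py`, presumably the routine used for this step).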
cuthbert/gaussian/moments/smoother.py
@@ -0,0 +1,118 @@
+r"""Linearized moments Kalman smoother.
+
+Takes a user provided conditional `mean` and `chol_cov` functions to define a
+conditionally linear Gaussian state space model.
+
+I.e., we approximate conditional densities as
+
+$$
+p(y \mid x) \approx N(y \mid \mathrm{mean}(x), \mathrm{chol\_cov}(x) @ \mathrm{chol\_cov}(x)^\top).
+$$
+
+See `cuthbertlib.linearize` for more details.
+
+Parallelism via `associative_scan` is supported, but requires the `state` argument
+to be ignored in `get_dynamics_params`.
+I.e. the linearization points are pre-defined or extracted from model inputs.
+"""
+
+from functools import partial
+
+from jax import numpy as jnp
+from jax import tree
+
+from cuthbert.gaussian.kalman import (
+    KalmanSmootherState,
+    convert_filter_to_smoother_state,
+    smoother_combine,
+)
+from cuthbert.gaussian.moments.types import GetDynamicsMoments
+from cuthbert.gaussian.types import LinearizedKalmanFilterState
+from cuthbert.inference import Smoother
+from cuthbertlib.kalman import smoothing
+from cuthbertlib.linearize import linearize_moments
+from cuthbertlib.types import ArrayTreeLike, KeyArray
+
+
+def build_smoother(
+    get_dynamics_params: GetDynamicsMoments,
+    store_gain: bool = False,
+    store_chol_cov_given_next: bool = False,
+) -> Smoother:
+    """Build linearized moments Kalman inference smoother for conditionally Gaussian SSMs.
+
+    Args:
+        get_dynamics_params: Function to get dynamics conditional mean and
+            (generalised) Cholesky covariance from linearization point and model inputs.
+        store_gain: Whether to store the gain matrix in the smoother state.
+        store_chol_cov_given_next: Whether to store the chol_cov_given_next matrix
+            in the smoother state.
+
+    Returns:
+        Linearized moments Kalman smoother object, suitable for associative scan.
+    """
+    return Smoother(
+        smoother_prepare=partial(
+            smoother_prepare,
+            get_dynamics_params=get_dynamics_params,
+            store_gain=store_gain,
+            store_chol_cov_given_next=store_chol_cov_given_next,
+        ),
+        smoother_combine=smoother_combine,
+        convert_filter_to_smoother_state=partial(
+            convert_filter_to_smoother_state,
+            store_gain=store_gain,
+            store_chol_cov_given_next=store_chol_cov_given_next,
+        ),
+        associative=True,
+    )
+
+
+def smoother_prepare(
+    filter_state: LinearizedKalmanFilterState,
+    get_dynamics_params: GetDynamicsMoments,
+    model_inputs: ArrayTreeLike,
+    store_gain: bool = False,
+    store_chol_cov_given_next: bool = False,
+    key: KeyArray | None = None,
+) -> KalmanSmootherState:
+    """Prepare a state for an extended Kalman smoother step.
+
+    Note that the model_inputs here are different to filter_state.model_inputs.
+    The model_inputs required here are for the transition from t to t+1.
+    filter_state.model_inputs represents the transition from t-1 to t.
+
+    Args:
+        filter_state: State generated by the extended Kalman filter at time t.
+        get_dynamics_params: Function to get dynamics conditional mean and
+            (generalised) Cholesky covariance from linearization point and model inputs.
+        model_inputs: Model inputs for the transition from t to t+1.
+        store_gain: Whether to store the gain matrix in the smoother state.
+        store_chol_cov_given_next: Whether to store the chol_cov_given_next matrix
+            in the smoother state.
+        key: JAX random key - not used.
+
+    Returns:
+        Prepared state for the Kalman smoother.
+    """
+    model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)
+    filter_mean = filter_state.mean
+    filter_chol_cov = filter_state.chol_cov
+
+    dynamics_mean_and_chol_cov_func, dynamics_linearization_point = get_dynamics_params(
+        filter_state, model_inputs
+    )
+
+    F, c, chol_Q = linearize_moments(
+        dynamics_mean_and_chol_cov_func, dynamics_linearization_point
+    )
+
+    state = smoothing.associative_params_single(
+        filter_mean, filter_chol_cov, F, c, chol_Q
+    )
+    return KalmanSmootherState(
+        elem=state,
+        gain=state.E if store_gain else None,
+        chol_cov_given_next=state.D if store_chol_cov_given_next else None,
+        model_inputs=model_inputs,
+    )
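Continuing the illustrative sketch given after the filter module above (again an assumption, not taken from this diff), the smoother can be built from the same dynamics moments function; only construction is shown, since the diff does not include the driver that runs the backward pass.

# Illustrative sketch only -- reuses the hypothetical get_dynamics_params from the
# build_filter example above.
from cuthbert.gaussian.moments.smoother import build_smoother

smoother = build_smoother(get_dynamics_params, store_gain=True)

# Per the constructor above: smoother.smoother_prepare(filter_state, model_inputs=...)
# builds one associative smoothing element, smoother.smoother_combine merges two
# elements, and smoother.convert_filter_to_smoother_state presumably seeds the
# recursion from a filter state; the smoother is flagged associative=True, so it is
# suitable for associative_scan.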
cuthbert/gaussian/moments/types.py
@@ -0,0 +1,51 @@
+"""Provides types for the moment-based linearization of Gaussian state-space models."""
+
+from typing import Protocol
+
+from cuthbert.gaussian.types import LinearizedKalmanFilterState
+from cuthbertlib.linearize.moments import MeanAndCholCovFunc
+from cuthbertlib.types import Array, ArrayTreeLike
+
+
+class GetDynamicsMoments(Protocol):
+    """Protocol for extracting the dynamics specifications."""
+
+    def __call__(
+        self,
+        state: LinearizedKalmanFilterState,
+        model_inputs: ArrayTreeLike,
+    ) -> tuple[MeanAndCholCovFunc, Array]:
+        """Get dynamics conditional mean and chol_cov function and linearization point.
+
+        `associative_scan` only supported when `state` is ignored.
+
+        Args:
+            state: NamedTuple containing `mean` and `mean_prev` attributes.
+            model_inputs: Model inputs.
+
+        Returns:
+            Tuple with dynamics conditional mean and (generalised) Cholesky covariance
+            function and linearization point.
+        """
+        ...
+
+
+class GetObservationMoments(Protocol):
+    """Protocol for extracting the observation specifications."""
+
+    def __call__(
+        self, state: LinearizedKalmanFilterState, model_inputs: ArrayTreeLike
+    ) -> tuple[MeanAndCholCovFunc, Array, Array]:
+        """Get conditional mean and chol_cov function, linearization point and observation.
+
+        `associative_scan` only supported when `state` input is ignored.
+
+        Args:
+            state: NamedTuple containing `mean` and `mean_prev` attributes.
+            model_inputs: Model inputs.
+
+        Returns:
+            Tuple with conditional mean and chol_cov function, linearization point
+            and observation.
+        """
+        ...
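The two protocols above deliberately leave the linearization-point strategy to the caller. A hypothetical pair of `GetDynamicsMoments` implementations (the `jnp.sin` dynamics, the noise scale, and the "x_lin" key are illustrative assumptions, not from the package) shows the distinction the docstrings draw between state-dependent and state-independent linearization:

# Illustrative sketch only: two ways to satisfy GetDynamicsMoments.
import jax.numpy as jnp


def dynamics_moments_from_state(state, model_inputs):
    # Re-linearizes around the previous filter mean, so it relies on `state`
    # and is only valid with the sequential (non-associative) filter.
    def mean_and_chol_cov(x):
        return jnp.sin(x), 0.1 * jnp.eye(x.shape[-1])

    return mean_and_chol_cov, state.mean


def dynamics_moments_from_inputs(state, model_inputs):
    # Ignores `state`; the linearization point comes from the model inputs,
    # which keeps the filter compatible with associative_scan.
    def mean_and_chol_cov(x):
        return jnp.sin(x), 0.1 * jnp.eye(x.shape[-1])

    return mean_and_chol_cov, model_inputs["x_lin"]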
cuthbert/gaussian/taylor/__init__.py
@@ -0,0 +1,14 @@
+from cuthbert.gaussian.taylor import (
+    associative_filter,
+    non_associative_filter,
+    smoother,
+)
+from cuthbert.gaussian.taylor.filter import build_filter
+from cuthbert.gaussian.taylor.smoother import build_smoother
+from cuthbert.gaussian.taylor.types import (
+    GetDynamicsLogDensity,
+    GetInitLogDensity,
+    GetObservationFunc,
+    LogPotential,
+)
+from cuthbert.gaussian.types import LinearizedKalmanFilterState