cuthbert 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuthbert/discrete/__init__.py +2 -0
- cuthbert/discrete/filter.py +140 -0
- cuthbert/discrete/smoother.py +123 -0
- cuthbert/discrete/types.py +53 -0
- cuthbert/gaussian/__init__.py +0 -0
- cuthbert/gaussian/kalman.py +337 -0
- cuthbert/gaussian/moments/__init__.py +11 -0
- cuthbert/gaussian/moments/associative_filter.py +180 -0
- cuthbert/gaussian/moments/filter.py +95 -0
- cuthbert/gaussian/moments/non_associative_filter.py +161 -0
- cuthbert/gaussian/moments/smoother.py +118 -0
- cuthbert/gaussian/moments/types.py +51 -0
- cuthbert/gaussian/taylor/__init__.py +14 -0
- cuthbert/gaussian/taylor/associative_filter.py +222 -0
- cuthbert/gaussian/taylor/filter.py +129 -0
- cuthbert/gaussian/taylor/non_associative_filter.py +246 -0
- cuthbert/gaussian/taylor/smoother.py +158 -0
- cuthbert/gaussian/taylor/types.py +86 -0
- cuthbert/gaussian/types.py +57 -0
- cuthbert/gaussian/utils.py +41 -0
- cuthbert/smc/__init__.py +0 -0
- cuthbert/smc/backward_sampler.py +193 -0
- cuthbert/smc/marginal_particle_filter.py +237 -0
- cuthbert/smc/particle_filter.py +234 -0
- cuthbert/smc/types.py +67 -0
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/METADATA +1 -1
- cuthbert-0.0.3.dist-info/RECORD +76 -0
- cuthbertlib/discrete/__init__.py +0 -0
- cuthbertlib/discrete/filtering.py +49 -0
- cuthbertlib/discrete/smoothing.py +35 -0
- cuthbertlib/kalman/__init__.py +4 -0
- cuthbertlib/kalman/filtering.py +213 -0
- cuthbertlib/kalman/generate.py +85 -0
- cuthbertlib/kalman/sampling.py +68 -0
- cuthbertlib/kalman/smoothing.py +121 -0
- cuthbertlib/linalg/__init__.py +7 -0
- cuthbertlib/linalg/collect_nans_chol.py +90 -0
- cuthbertlib/linalg/marginal_sqrt_cov.py +34 -0
- cuthbertlib/linalg/symmetric_inv_sqrt.py +126 -0
- cuthbertlib/linalg/tria.py +21 -0
- cuthbertlib/linearize/__init__.py +7 -0
- cuthbertlib/linearize/log_density.py +175 -0
- cuthbertlib/linearize/moments.py +94 -0
- cuthbertlib/linearize/taylor.py +83 -0
- cuthbertlib/quadrature/__init__.py +4 -0
- cuthbertlib/quadrature/common.py +102 -0
- cuthbertlib/quadrature/cubature.py +73 -0
- cuthbertlib/quadrature/gauss_hermite.py +62 -0
- cuthbertlib/quadrature/linearize.py +143 -0
- cuthbertlib/quadrature/unscented.py +79 -0
- cuthbertlib/quadrature/utils.py +109 -0
- cuthbertlib/resampling/__init__.py +3 -0
- cuthbertlib/resampling/killing.py +79 -0
- cuthbertlib/resampling/multinomial.py +53 -0
- cuthbertlib/resampling/protocols.py +92 -0
- cuthbertlib/resampling/systematic.py +78 -0
- cuthbertlib/resampling/utils.py +82 -0
- cuthbertlib/smc/__init__.py +0 -0
- cuthbertlib/smc/ess.py +24 -0
- cuthbertlib/smc/smoothing/__init__.py +0 -0
- cuthbertlib/smc/smoothing/exact_sampling.py +111 -0
- cuthbertlib/smc/smoothing/mcmc.py +76 -0
- cuthbertlib/smc/smoothing/protocols.py +44 -0
- cuthbertlib/smc/smoothing/tracing.py +45 -0
- cuthbertlib/stats/__init__.py +0 -0
- cuthbertlib/stats/multivariate_normal.py +102 -0
- cuthbert-0.0.2.dist-info/RECORD +0 -12
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/WHEEL +0 -0
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/licenses/LICENSE +0 -0
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
"""Parallel-in-time Bayesian filter for discrete hidden Markov models.
|
|
2
|
+
|
|
3
|
+
References:
|
|
4
|
+
- https://ieeexplore.ieee.org/document/9512397
|
|
5
|
+
- https://github.com/EEA-sensors/sequential-parallelization-examples/tree/main/python/temporal-parallelization-inference-in-HMMs
|
|
6
|
+
- https://github.com/probml/dynamax/blob/main/dynamax/hidden_markov_model/parallel_inference.py
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from functools import partial
|
|
10
|
+
from typing import NamedTuple
|
|
11
|
+
|
|
12
|
+
import jax.numpy as jnp
|
|
13
|
+
from jax import tree
|
|
14
|
+
|
|
15
|
+
from cuthbert.discrete.types import (
|
|
16
|
+
GetInitDist,
|
|
17
|
+
GetObsLogLikelihoods,
|
|
18
|
+
GetTransitionMatrix,
|
|
19
|
+
)
|
|
20
|
+
from cuthbert.inference import Filter
|
|
21
|
+
from cuthbertlib.discrete import filtering
|
|
22
|
+
from cuthbertlib.types import Array, ArrayTree, ArrayTreeLike, KeyArray
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class DiscreteFilterState(NamedTuple):
    """State carried by the discrete-HMM filter.

    Attributes:
        elem: Associative-scan element holding the filtering quantities.
        model_inputs: Model inputs that produced this state.
    """

    elem: filtering.FilterScanElement
    model_inputs: ArrayTree

    @property
    def dist(self) -> Array:
        """Filtered state distribution.

        Has shape (K,) for a single step or (T+1, K) for a scanned history,
        where K is the number of possible states.
        """
        # Equivalent to jnp.take(self.elem.f, 0, axis=-2).
        return self.elem.f[..., 0, :]

    @property
    def log_normalizing_constant(self) -> Array:
        """Cumulative log normalizing constant."""
        # Equivalent to jnp.take(self.elem.log_g, 0, axis=-1).
        return self.elem.log_g[..., 0]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def build_filter(
    get_init_dist: GetInitDist,
    get_trans_matrix: GetTransitionMatrix,
    get_obs_lls: GetObsLogLikelihoods,
) -> Filter:
    r"""Build a filter object for discrete hidden Markov models.

    Args:
        get_init_dist: Function to get initial state probabilities $m_i = p(x_0 = i)$.
        get_trans_matrix: Function to get the transition matrix $A_{ij} = p(x_t = j \mid x_{t-1} = i)$.
        get_obs_lls: Function to get observation log likelihoods $b_i = \log p(y_t | x_t = i)$.

    Returns:
        Filter object. Suitable for associative scan.
    """
    # Bind the model-definition callables once; the resulting partials share
    # the uniform (model_inputs, key) calling convention expected by Filter.
    bound_init_prepare = partial(
        init_prepare, get_init_dist=get_init_dist, get_obs_lls=get_obs_lls
    )
    bound_filter_prepare = partial(
        filter_prepare, get_trans_matrix=get_trans_matrix, get_obs_lls=get_obs_lls
    )
    return Filter(
        init_prepare=bound_init_prepare,
        filter_prepare=bound_filter_prepare,
        filter_combine=filter_combine,
        associative=True,
    )
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def init_prepare(
    model_inputs: ArrayTreeLike,
    get_init_dist: GetInitDist,
    get_obs_lls: GetObsLogLikelihoods,
    key: KeyArray | None = None,
) -> DiscreteFilterState:
    """Prepare the initial state for the filter.

    Args:
        model_inputs: Model inputs.
        get_init_dist: Function to get initial state probabilities m_i = p(x_0 = i).
        get_obs_lls: Function to get observation log likelihoods b_i = log p(y_t | x_t = i).
        key: JAX random key - not used.

    Returns:
        Prepared state for the filter.
    """
    model_inputs = tree.map(jnp.asarray, model_inputs)
    init_dist = get_init_dist(model_inputs)
    obs_lls = get_obs_lls(model_inputs)
    f, log_g = filtering.condition_on_obs(init_dist, obs_lls)
    # Broadcast to the shapes used by the scan element: f -> (N, N),
    # log_g -> (N,), so the initial element matches later elements.
    n_states = init_dist.shape[0]
    f = f * jnp.ones((n_states, n_states))
    log_g = log_g * jnp.ones(n_states)
    elem = filtering.FilterScanElement(f, log_g)
    return DiscreteFilterState(elem=elem, model_inputs=model_inputs)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def filter_prepare(
    model_inputs: ArrayTreeLike,
    get_trans_matrix: GetTransitionMatrix,
    get_obs_lls: GetObsLogLikelihoods,
    key: KeyArray | None = None,
) -> DiscreteFilterState:
    """Prepare a state for a filter step.

    Args:
        model_inputs: Model inputs.
        get_trans_matrix: Function to get the transition matrix A_{ij} = p(x_t = j | x_{t-1} = i).
        get_obs_lls: Function to get observation log likelihoods b_i = log p(y_t | x_t = i).
        key: JAX random key - not used.

    Returns:
        Prepared state for the filter.
    """
    model_inputs = tree.map(jnp.asarray, model_inputs)
    trans_matrix = get_trans_matrix(model_inputs)
    obs_lls = get_obs_lls(model_inputs)
    # Fold the observation likelihoods into the transition kernel.
    f, log_g = filtering.condition_on_obs(trans_matrix, obs_lls)
    elem = filtering.FilterScanElement(f, log_g)
    return DiscreteFilterState(elem=elem, model_inputs=model_inputs)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def filter_combine(
    state_1: DiscreteFilterState, state_2: DiscreteFilterState
) -> DiscreteFilterState:
    """Combine previous filter state with state prepared with latest model inputs.

    Args:
        state_1: State from previous time step.
        state_2: State prepared (only access model_inputs attribute).

    Returns:
        Combined filter state. Contains distribution and log_normalizing_constant.
    """
    merged_elem = filtering.filtering_operator(state_1.elem, state_2.elem)
    return DiscreteFilterState(elem=merged_elem, model_inputs=state_2.model_inputs)
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
"""Parallel-in-time Bayesian smoother for discrete hidden Markov models.
|
|
2
|
+
|
|
3
|
+
References:
|
|
4
|
+
- https://ieeexplore.ieee.org/document/9512397
|
|
5
|
+
- https://github.com/EEA-sensors/sequential-parallelization-examples/tree/main/python/temporal-parallelization-inference-in-HMMs
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from functools import partial
|
|
9
|
+
from typing import NamedTuple
|
|
10
|
+
|
|
11
|
+
import jax.numpy as jnp
|
|
12
|
+
from jax import tree
|
|
13
|
+
|
|
14
|
+
from cuthbert.discrete.filter import DiscreteFilterState
|
|
15
|
+
from cuthbert.discrete.types import GetTransitionMatrix
|
|
16
|
+
from cuthbert.inference import Smoother
|
|
17
|
+
from cuthbert.utils import dummy_tree_like
|
|
18
|
+
from cuthbertlib.discrete.smoothing import get_reverse_kernel, smoothing_operator
|
|
19
|
+
from cuthbertlib.types import Array, ArrayTree, ArrayTreeLike, KeyArray
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DiscreteSmootherState(NamedTuple):
    """State carried by the discrete-HMM smoother.

    Attributes:
        a: Reverse-kernel / smoothing matrix for the associative scan.
        model_inputs: Model inputs that produced this state.
    """

    a: Array
    model_inputs: ArrayTree

    @property
    def dist(self):
        """Smoothed state distribution.

        Has shape (K,) or (T+1, K) where K is the number of possible states.
        """
        # Equivalent to jnp.take(self.a, 0, axis=-2).
        return self.a[..., 0, :]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def build_smoother(get_trans_matrix: GetTransitionMatrix) -> Smoother:
    r"""Build a smoother object for discrete hidden Markov models.

    Args:
        get_trans_matrix: Function to get the transition matrix $A_{ij} = p(x_t = j \mid x_{t-1} = i)$.

    Returns:
        Smoother object. Suitable for associative scan.
    """
    bound_prepare = partial(smoother_prepare, get_trans_matrix=get_trans_matrix)
    return Smoother(
        convert_filter_to_smoother_state=convert_filter_to_smoother_state,
        smoother_prepare=bound_prepare,
        smoother_combine=smoother_combine,
        associative=True,
    )
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def smoother_prepare(
    filter_state: DiscreteFilterState,
    get_trans_matrix: GetTransitionMatrix,
    model_inputs: ArrayTreeLike,
    key: KeyArray | None = None,
) -> DiscreteSmootherState:
    """Prepare a state for a smoother step.

    Args:
        filter_state: State generated by the filter at time t.
        get_trans_matrix: Function to get the transition matrix A_{ij} = p(x_{t+1} = j | x_{t} = i).
        model_inputs: Model inputs for the transition from t to t+1.
        key: JAX random key - not used.

    Returns:
        Prepared state for the smoother.
    """
    model_inputs = tree.map(jnp.asarray, model_inputs)
    trans_matrix = get_trans_matrix(model_inputs)
    # Reverse-time kernel p(x_t | x_{t+1}, y_{1:t}) built from the filtered
    # distribution and the forward transition matrix.
    reverse_kernel = get_reverse_kernel(filter_state.dist, trans_matrix)
    return DiscreteSmootherState(a=reverse_kernel, model_inputs=model_inputs)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def convert_filter_to_smoother_state(
    filter_state: DiscreteFilterState,
    model_inputs: ArrayTreeLike | None = None,
    key: KeyArray | None = None,
) -> DiscreteSmootherState:
    """Convert a filter state to a smoother state.

    Useful for the final filter state which is equivalent to the final smoother state.

    Args:
        filter_state: Filter state.
        model_inputs: Only used to create an empty model_inputs tree (the values are ignored).
            Useful so that the final smoother state has the same structure as the rest.
            By default, filter_state.model_inputs is used. So this is only needed if the
            smoother model_inputs have a different tree structure to filter_state.model_inputs.
        key: JAX random key - not used.

    Returns:
        Smoother state, same data as filter state just different structure.
        Note that the model_inputs are set to dummy values.
    """
    template = filter_state.model_inputs if model_inputs is None else model_inputs
    dummy_model_inputs = dummy_tree_like(template)

    filter_dist = filter_state.dist
    # Repeat the filtered distribution along rows so the element has the
    # (K, K) shape expected by the smoothing scan.
    n_states = filter_dist.shape[0]
    a = jnp.tile(filter_dist, (n_states, 1))
    return DiscreteSmootherState(a=a, model_inputs=dummy_model_inputs)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def smoother_combine(
    state_1: DiscreteSmootherState, state_2: DiscreteSmootherState
) -> DiscreteSmootherState:
    """Combine smoother state from next time point with state prepared with latest model inputs.

    Remember smoothing iterates backwards in time.

    Args:
        state_1: State prepared with model inputs at time t.
        state_2: Smoother state at time t + 1.

    Returns:
        Combined smoother state.
    """
    combined = smoothing_operator(state_1.a, state_2.a)
    return DiscreteSmootherState(a=combined, model_inputs=state_1.model_inputs)
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
"""Provides types for representing discrete HMMs."""
|
|
2
|
+
|
|
3
|
+
from typing import Protocol
|
|
4
|
+
|
|
5
|
+
from cuthbertlib.types import Array, ArrayTreeLike
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class GetInitDist(Protocol):
    """Protocol for specifying the initial distribution."""

    def __call__(self, model_inputs: ArrayTreeLike) -> Array:
        r"""Compute the initial state distribution from model inputs.

        Args:
            model_inputs: Model inputs.

        Returns:
            An array $m$ of shape (N,), N being the number of states,
            where $m_i = p(x_0 = i)$.
        """
        ...
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class GetTransitionMatrix(Protocol):
    """Protocol for specifying the transition matrix."""

    def __call__(self, model_inputs: ArrayTreeLike) -> Array:
        r"""Compute the state-transition matrix from model inputs.

        Args:
            model_inputs: Model inputs.

        Returns:
            An array $A$ of shape (N, N), N being the number of states,
            where $A_{ij} = p(x_t = j \mid x_{t-1} = i)$.
        """
        ...
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class GetObsLogLikelihoods(Protocol):
    """Protocol for specifying the observation log likelihoods."""

    def __call__(self, model_inputs: ArrayTreeLike) -> Array:
        r"""Compute the per-state observation log likelihoods from model inputs.

        Args:
            model_inputs: Model inputs.

        Returns:
            An array $b$ of shape (N,), N being the number of states,
            where $b_i = \log p(y_t \mid x_t = i)$.
        """
        ...
|
|
File without changes
|
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
"""Implements the square-root, parallel-in-time Kalman filter for linear Gaussian SSMs.
|
|
2
|
+
|
|
3
|
+
See [Yaghoobi et. al. (2025)](https://doi.org/10.1137/23M156121X).
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from functools import partial
|
|
7
|
+
from typing import NamedTuple
|
|
8
|
+
|
|
9
|
+
from jax import numpy as jnp
|
|
10
|
+
from jax import tree
|
|
11
|
+
|
|
12
|
+
from cuthbert.gaussian.types import (
|
|
13
|
+
GetDynamicsParams,
|
|
14
|
+
GetInitParams,
|
|
15
|
+
GetObservationParams,
|
|
16
|
+
)
|
|
17
|
+
from cuthbert.inference import Filter, Smoother
|
|
18
|
+
from cuthbert.utils import dummy_tree_like
|
|
19
|
+
from cuthbertlib.kalman import filtering, smoothing
|
|
20
|
+
from cuthbertlib.types import Array, ArrayTree, ArrayTreeLike, KeyArray
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class KalmanFilterState(NamedTuple):
    """State carried by the square-root Kalman filter.

    Attributes:
        elem: Associative-scan element holding the filtering quantities.
        model_inputs: Model inputs that produced this state.
    """

    elem: filtering.FilterScanElement
    model_inputs: ArrayTree

    @property
    def mean(self) -> Array:
        """Mean of the filtering distribution."""
        return self.elem.b

    @property
    def chol_cov(self) -> Array:
        """Generalised Cholesky factor of the filtering covariance."""
        return self.elem.U

    @property
    def log_normalizing_constant(self) -> Array:
        """Cumulative log normalizing constant."""
        return self.elem.ell
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class KalmanSmootherState(NamedTuple):
    """State carried by the square-root Kalman smoother.

    Attributes:
        elem: Associative-scan element holding the smoothing quantities.
        model_inputs: Model inputs that produced this state.
        gain: Smoother gain matrix; populated only when requested.
        chol_cov_given_next: Generalised Cholesky factor of the covariance of
            x_t given x_{t+1}; populated only when requested.
    """

    elem: smoothing.SmootherScanElement
    model_inputs: ArrayTree
    gain: Array | None = None
    chol_cov_given_next: Array | None = None

    @property
    def mean(self) -> Array:
        """Mean of the smoothing distribution."""
        return self.elem.g

    @property
    def chol_cov(self) -> Array:
        """Generalised Cholesky factor of the smoothing covariance."""
        return self.elem.D
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def build_filter(
    get_init_params: GetInitParams,
    get_dynamics_params: GetDynamicsParams,
    get_observation_params: GetObservationParams,
) -> Filter:
    """Builds an exact Kalman filter object for linear Gaussian SSMs.

    Args:
        get_init_params: Function to get m0, chol_P0 to initialize filter state,
            given model inputs sufficient to define p(x_0) = N(m0, chol_P0 @ chol_P0^T).
        get_dynamics_params: Function to get dynamics parameters, F, c, chol_Q
            given model inputs sufficient to define
            p(x_t | x_{t-1}) = N(F @ x_{t-1} + c, chol_Q @ chol_Q^T).
        get_observation_params: Function to get observation parameters, H, d, chol_R, y
            given model inputs sufficient to define
            p(y_t | x_t) = N(H @ x_t + d, chol_R @ chol_R^T).

    Returns:
        Filter object for exact Kalman filter. Suitable for associative scan.
    """
    # Bind the model callables once so the resulting partials share the
    # uniform (model_inputs, key) calling convention expected by Filter.
    bound_init_prepare = partial(
        init_prepare,
        get_init_params=get_init_params,
        get_observation_params=get_observation_params,
    )
    bound_filter_prepare = partial(
        filter_prepare,
        get_dynamics_params=get_dynamics_params,
        get_observation_params=get_observation_params,
    )
    return Filter(
        init_prepare=bound_init_prepare,
        filter_prepare=bound_filter_prepare,
        filter_combine=filter_combine,
        associative=True,
    )
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def build_smoother(
    get_dynamics_params: GetDynamicsParams,
    store_gain: bool = False,
    store_chol_cov_given_next: bool = False,
) -> Smoother:
    """Builds an exact Kalman smoother object for linear Gaussian SSMs.

    Args:
        get_dynamics_params: Function to get dynamics parameters, F, c, chol_Q
            given model inputs sufficient to define
            p(x_t | x_{t-1}) = N(F @ x_{t-1} + c, chol_Q @ chol_Q^T).
        store_gain: Whether to store the gain matrix in the smoother state.
        store_chol_cov_given_next: Whether to store the chol_cov_given_next matrix
            in the smoother state.

    Returns:
        Smoother object for exact Kalman smoother. Suitable for associative scan.
    """
    bound_convert = partial(
        convert_filter_to_smoother_state,
        store_gain=store_gain,
        store_chol_cov_given_next=store_chol_cov_given_next,
    )
    bound_prepare = partial(
        smoother_prepare,
        get_dynamics_params=get_dynamics_params,
        store_gain=store_gain,
        store_chol_cov_given_next=store_chol_cov_given_next,
    )
    return Smoother(
        convert_filter_to_smoother_state=bound_convert,
        smoother_prepare=bound_prepare,
        smoother_combine=smoother_combine,
        associative=True,
    )
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def init_prepare(
    model_inputs: ArrayTreeLike,
    get_init_params: GetInitParams,
    get_observation_params: GetObservationParams,
    key: KeyArray | None = None,
) -> KalmanFilterState:
    """Prepare the initial state for the Kalman filter.

    Args:
        model_inputs: Model inputs.
        get_init_params: Function to get m0, chol_P0 from model inputs.
        get_observation_params: Function to get observation parameters, H, d, chol_R, y.
        key: JAX random key - not used.

    Returns:
        State for the Kalman filter.
        Contains mean and chol_cov (generalised Cholesky factor of covariance).
    """
    model_inputs = tree.map(jnp.asarray, model_inputs)
    m0, chol_P0 = get_init_params(model_inputs)
    H, d, chol_R, y = get_observation_params(model_inputs)

    # Condition the prior on the observation at time 0.
    (m, chol_P), ell = filtering.update(m0, chol_P0, H, d, chol_R, y)
    # A, eta and Z are zero for the initial element, leaving only the
    # conditioned initial distribution (b, U) and its log-normalizer.
    zero_mat = jnp.zeros_like(chol_P)
    elem = filtering.FilterScanElement(
        A=zero_mat,
        b=m,
        U=chol_P,
        eta=jnp.zeros_like(m),
        Z=zero_mat,
        ell=ell,
    )
    return KalmanFilterState(elem=elem, model_inputs=model_inputs)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def filter_prepare(
    model_inputs: ArrayTreeLike,
    get_dynamics_params: GetDynamicsParams,
    get_observation_params: GetObservationParams,
    key: KeyArray | None = None,
) -> KalmanFilterState:
    """Prepare a state for an exact Kalman filter step.

    Args:
        model_inputs: Model inputs.
        get_dynamics_params: Function to get dynamics parameters, F, c, chol_Q.
        get_observation_params: Function to get observation parameters, H, d, chol_R, y.
        key: JAX random key - not used.

    Returns:
        Prepared state for Kalman filter.
    """
    model_inputs = tree.map(jnp.asarray, model_inputs)
    F, c, chol_Q = get_dynamics_params(model_inputs)
    H, d, chol_R, y = get_observation_params(model_inputs)
    scan_elem = filtering.associative_params_single(F, c, chol_Q, H, d, chol_R, y)
    return KalmanFilterState(elem=scan_elem, model_inputs=model_inputs)
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def filter_combine(
    state_1: KalmanFilterState,
    state_2: KalmanFilterState,
) -> KalmanFilterState:
    """Combine previous filter state with state prepared with latest model inputs.

    Applies exact Kalman predict + filter update in covariance square root form.
    Suitable for associative scan (as well as sequential scan).

    Args:
        state_1: State from previous time step.
        state_2: State prepared with latest model inputs.

    Returns:
        Combined Kalman filter state.
        Contains mean, chol_cov (generalised Cholesky factor of covariance)
        and log_normalizing_constant.
    """
    merged_elem = filtering.filtering_operator(state_1.elem, state_2.elem)
    return KalmanFilterState(elem=merged_elem, model_inputs=state_2.model_inputs)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def smoother_prepare(
    filter_state: KalmanFilterState,
    get_dynamics_params: GetDynamicsParams,
    model_inputs: ArrayTreeLike,
    store_gain: bool = False,
    store_chol_cov_given_next: bool = False,
    key: KeyArray | None = None,
) -> KalmanSmootherState:
    """Prepare a state for an exact Kalman smoother step.

    Note that the model_inputs here are different to filter_state.model_inputs.
    The model_inputs required here are for the transition from t to t+1.
    filter_state.model_inputs represents the transition from t-1 to t.

    Args:
        filter_state: State generated by the Kalman filter at time t.
        get_dynamics_params: Function to get dynamics parameters, F, c, chol_Q,
            from model inputs.
        model_inputs: Model inputs for the transition from t to t+1.
        store_gain: Whether to store the gain matrix in the smoother state.
        store_chol_cov_given_next: Whether to store the chol_cov_given_next matrix
            in the smoother state.
        key: JAX random key - not used.

    Returns:
        Prepared state for the Kalman smoother.
    """
    model_inputs = tree.map(jnp.asarray, model_inputs)
    F, c, chol_Q = get_dynamics_params(model_inputs)
    scan_elem = smoothing.associative_params_single(
        filter_state.mean, filter_state.chol_cov, F, c, chol_Q
    )
    # E is the smoother gain, D the conditional Cholesky factor; only kept
    # when explicitly requested to avoid carrying unused arrays.
    return KalmanSmootherState(
        elem=scan_elem,
        gain=scan_elem.E if store_gain else None,
        chol_cov_given_next=scan_elem.D if store_chol_cov_given_next else None,
        model_inputs=model_inputs,
    )
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def smoother_combine(
    state_1: KalmanSmootherState,
    state_2: KalmanSmootherState,
) -> KalmanSmootherState:
    """Combine smoother state from next time point with state prepared with latest model inputs.

    Remember smoothing iterates backwards in time.

    Applies exact Kalman smoother update in covariance square root form.
    Suitable for associative scan (as well as sequential scan).

    Args:
        state_1: State prepared with model inputs at time t.
        state_2: Smoother state at time t + 1.

    Returns:
        Combined Kalman smoother state.
        Contains mean, chol_cov (generalised Cholesky factor of covariance)
        and gain (which can be used to compute temporal cross-covariance).
    """
    # Note the argument order: the operator takes the later (t+1) element first.
    combined_elem = smoothing.smoothing_operator(state_2.elem, state_1.elem)
    return KalmanSmootherState(
        elem=combined_elem,
        gain=state_1.gain,
        chol_cov_given_next=state_1.chol_cov_given_next,
        model_inputs=state_1.model_inputs,
    )
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def convert_filter_to_smoother_state(
    filter_state: KalmanFilterState,
    model_inputs: ArrayTreeLike | None = None,
    store_gain: bool = False,
    store_chol_cov_given_next: bool = False,
    key: KeyArray | None = None,
) -> KalmanSmootherState:
    """Convert the filter state to a smoother state.

    Useful for the final filter state which is equivalent to the final smoother state.

    Args:
        filter_state: Filter state.
        model_inputs: Only used to create an empty model_inputs tree
            (the values are ignored).
            Useful so that the final smoother state has the same structure as the rest.
            By default, filter_state.model_inputs is used. So this
            is only needed if the smoother model_inputs have a different tree
            structure to filter_state.model_inputs.
        store_gain: Whether to store the gain matrix in the smoother state.
        store_chol_cov_given_next: Whether to store the chol_cov_given_next matrix
            in the smoother state.
        key: JAX random key - not used.

    Returns:
        Smoother state, same data as filter state just different structure.
        Note that the model_inputs are set to dummy values.
    """
    # NOTE: the annotation was previously ArrayTreeLike, but this function
    # reads .model_inputs, .mean and .chol_cov, so a KalmanFilterState is
    # required (matching the discrete-module counterpart).
    if model_inputs is None:
        model_inputs = filter_state.model_inputs

    dummy_model_inputs = dummy_tree_like(model_inputs)

    # The final smoother element carries the filter moments directly;
    # a zero gain (E) makes it the identity element of the backward scan.
    elem = smoothing.SmootherScanElement(
        g=filter_state.mean,
        D=filter_state.chol_cov,
        E=jnp.zeros_like(filter_state.chol_cov),
    )
    return KalmanSmootherState(
        elem=elem,
        gain=dummy_tree_like(filter_state.chol_cov) if store_gain else None,
        chol_cov_given_next=dummy_tree_like(filter_state.chol_cov)
        if store_chol_cov_given_next
        else None,
        model_inputs=dummy_model_inputs,
    )
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from cuthbert.gaussian.moments import (
|
|
2
|
+
associative_filter,
|
|
3
|
+
non_associative_filter,
|
|
4
|
+
smoother,
|
|
5
|
+
)
|
|
6
|
+
from cuthbert.gaussian.moments.filter import build_filter
|
|
7
|
+
from cuthbert.gaussian.moments.smoother import build_smoother
|
|
8
|
+
from cuthbert.gaussian.moments.types import (
|
|
9
|
+
GetDynamicsMoments,
|
|
10
|
+
GetObservationMoments,
|
|
11
|
+
)
|