cuthbert 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. cuthbert/discrete/__init__.py +2 -0
  2. cuthbert/discrete/filter.py +140 -0
  3. cuthbert/discrete/smoother.py +123 -0
  4. cuthbert/discrete/types.py +53 -0
  5. cuthbert/gaussian/__init__.py +0 -0
  6. cuthbert/gaussian/kalman.py +337 -0
  7. cuthbert/gaussian/moments/__init__.py +11 -0
  8. cuthbert/gaussian/moments/associative_filter.py +180 -0
  9. cuthbert/gaussian/moments/filter.py +95 -0
  10. cuthbert/gaussian/moments/non_associative_filter.py +161 -0
  11. cuthbert/gaussian/moments/smoother.py +118 -0
  12. cuthbert/gaussian/moments/types.py +51 -0
  13. cuthbert/gaussian/taylor/__init__.py +14 -0
  14. cuthbert/gaussian/taylor/associative_filter.py +222 -0
  15. cuthbert/gaussian/taylor/filter.py +129 -0
  16. cuthbert/gaussian/taylor/non_associative_filter.py +246 -0
  17. cuthbert/gaussian/taylor/smoother.py +158 -0
  18. cuthbert/gaussian/taylor/types.py +86 -0
  19. cuthbert/gaussian/types.py +57 -0
  20. cuthbert/gaussian/utils.py +41 -0
  21. cuthbert/smc/__init__.py +0 -0
  22. cuthbert/smc/backward_sampler.py +193 -0
  23. cuthbert/smc/marginal_particle_filter.py +237 -0
  24. cuthbert/smc/particle_filter.py +234 -0
  25. cuthbert/smc/types.py +67 -0
  26. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/METADATA +1 -1
  27. cuthbert-0.0.3.dist-info/RECORD +76 -0
  28. cuthbertlib/discrete/__init__.py +0 -0
  29. cuthbertlib/discrete/filtering.py +49 -0
  30. cuthbertlib/discrete/smoothing.py +35 -0
  31. cuthbertlib/kalman/__init__.py +4 -0
  32. cuthbertlib/kalman/filtering.py +213 -0
  33. cuthbertlib/kalman/generate.py +85 -0
  34. cuthbertlib/kalman/sampling.py +68 -0
  35. cuthbertlib/kalman/smoothing.py +121 -0
  36. cuthbertlib/linalg/__init__.py +7 -0
  37. cuthbertlib/linalg/collect_nans_chol.py +90 -0
  38. cuthbertlib/linalg/marginal_sqrt_cov.py +34 -0
  39. cuthbertlib/linalg/symmetric_inv_sqrt.py +126 -0
  40. cuthbertlib/linalg/tria.py +21 -0
  41. cuthbertlib/linearize/__init__.py +7 -0
  42. cuthbertlib/linearize/log_density.py +175 -0
  43. cuthbertlib/linearize/moments.py +94 -0
  44. cuthbertlib/linearize/taylor.py +83 -0
  45. cuthbertlib/quadrature/__init__.py +4 -0
  46. cuthbertlib/quadrature/common.py +102 -0
  47. cuthbertlib/quadrature/cubature.py +73 -0
  48. cuthbertlib/quadrature/gauss_hermite.py +62 -0
  49. cuthbertlib/quadrature/linearize.py +143 -0
  50. cuthbertlib/quadrature/unscented.py +79 -0
  51. cuthbertlib/quadrature/utils.py +109 -0
  52. cuthbertlib/resampling/__init__.py +3 -0
  53. cuthbertlib/resampling/killing.py +79 -0
  54. cuthbertlib/resampling/multinomial.py +53 -0
  55. cuthbertlib/resampling/protocols.py +92 -0
  56. cuthbertlib/resampling/systematic.py +78 -0
  57. cuthbertlib/resampling/utils.py +82 -0
  58. cuthbertlib/smc/__init__.py +0 -0
  59. cuthbertlib/smc/ess.py +24 -0
  60. cuthbertlib/smc/smoothing/__init__.py +0 -0
  61. cuthbertlib/smc/smoothing/exact_sampling.py +111 -0
  62. cuthbertlib/smc/smoothing/mcmc.py +76 -0
  63. cuthbertlib/smc/smoothing/protocols.py +44 -0
  64. cuthbertlib/smc/smoothing/tracing.py +45 -0
  65. cuthbertlib/stats/__init__.py +0 -0
  66. cuthbertlib/stats/multivariate_normal.py +102 -0
  67. cuthbert-0.0.2.dist-info/RECORD +0 -12
  68. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/WHEEL +0 -0
  69. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/licenses/LICENSE +0 -0
  70. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,180 @@
1
+ """Implements the associative linearized moments Kalman filter."""
2
+
3
+ from jax import eval_shape, tree
4
+ from jax import numpy as jnp
5
+
6
+ from cuthbert.gaussian.kalman import GetInitParams
7
+ from cuthbert.gaussian.moments.types import GetDynamicsMoments, GetObservationMoments
8
+ from cuthbert.gaussian.types import (
9
+ LinearizedKalmanFilterState,
10
+ )
11
+ from cuthbert.utils import dummy_tree_like
12
+ from cuthbertlib.kalman import filtering
13
+ from cuthbertlib.linearize import linearize_moments
14
+ from cuthbertlib.types import (
15
+ ArrayTreeLike,
16
+ KeyArray,
17
+ )
18
+
19
+
20
def init_prepare(
    model_inputs: ArrayTreeLike,
    get_init_params: GetInitParams,
    get_observation_params: GetObservationMoments,
    key: KeyArray | None = None,
) -> LinearizedKalmanFilterState:
    """Build the first filter state from the prior and the first observation.

    The prior ``(m0, chol_P0)`` is wrapped in a state and passed to
    ``get_observation_params``. The returned observation model is linearized
    and used for a single square-root Kalman update.

    Args:
        model_inputs: Model inputs for the initial time step.
        get_init_params: Callable returning ``(m0, chol_P0)`` from model inputs.
        get_observation_params: Callable returning the observation conditional
            mean and (generalised) Cholesky covariance function, the
            linearization point and the observation.
        key: JAX random key - unused.

    Returns:
        State for the linearized moments Kalman filter. Its element holds the
        updated mean, chol_cov (generalised Cholesky factor of covariance) and
        the log normalizing constant of the first update.
    """
    inputs = tree.map(jnp.asarray, model_inputs)

    def as_elem(mean, chol_cov, log_norm):
        # Only `b`, `U` and `ell` are populated; `A`, `eta` and `Z` are zeros.
        return filtering.FilterScanElement(
            A=jnp.zeros_like(chol_cov),
            b=mean,
            U=chol_cov,
            eta=jnp.zeros_like(mean),
            Z=jnp.zeros_like(chol_cov),
            ell=log_norm,
        )

    m0, chol_P0 = get_init_params(inputs)
    prior_state = LinearizedKalmanFilterState(
        elem=as_elem(m0, chol_P0, jnp.array(0.0)),
        model_inputs=inputs,
        mean_prev=dummy_tree_like(m0),
    )

    # The observation model may depend on the prior state (e.g. its mean).
    obs_func, obs_point, y = get_observation_params(prior_state, inputs)
    H, d, chol_R = linearize_moments(obs_func, obs_point)
    (m, chol_P), ell = filtering.update(m0, chol_P0, H, d, chol_R, y)

    return LinearizedKalmanFilterState(
        elem=as_elem(m, chol_P, ell),
        model_inputs=inputs,
        mean_prev=dummy_tree_like(m),
    )
78
+
79
+
80
def filter_prepare(
    model_inputs: ArrayTreeLike,
    get_init_params: GetInitParams,
    get_dynamics_params: GetDynamicsMoments,
    get_observation_params: GetObservationMoments,
    key: KeyArray | None = None,
) -> LinearizedKalmanFilterState:
    """Prepare a state for a linearized moments Kalman filter step.

    Linearizes the dynamics and observation moments for this time step. The
    result is turned into an associative filtering element via
    `filtering.associative_params_single`, ready to be combined with
    `filter_combine`.

    `get_dynamics_params` and `get_observation_params` receive a placeholder
    state built from `dummy_tree_like` values. So `associative_scan` is
    supported but only accurate when `state` is ignored in both functions.

    Args:
        model_inputs: Model inputs.
        get_init_params: Function to get m0, chol_P0 from model inputs.
            Only used to infer shape of mean and chol_cov.
        get_dynamics_params: Function to get dynamics conditional mean and
            (generalised) Cholesky covariance from linearization point and model inputs.
            `associative_scan` only supported when `state` is ignored.
        get_observation_params: Function to get observation conditional mean,
            (generalised) Cholesky covariance and observation from linearization point
            and model inputs.
            `associative_scan` only supported when `state` is ignored.
        key: JAX random key - not used.

    Returns:
        Prepared state for linearized moments Kalman filter. Its `elem` is the
        associative filtering element for this step and `mean_prev` is a
        placeholder.
    """
    model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)
    dummy_mean_struct = eval_shape(lambda mi: get_init_params(mi)[0], model_inputs)
    dummy_mean = dummy_tree_like(dummy_mean_struct)
    # Placeholder shaped like a covariance factor for `dummy_mean`:
    # (..., n) -> (..., n, n). `jnp.cov` is not used for this because it
    # squeezes singleton dimensions, which would yield a scalar instead of a
    # (1, 1) matrix for a 1-D state.
    dummy_chol_cov = dummy_tree_like(
        jnp.zeros(dummy_mean.shape + dummy_mean.shape[-1:], dtype=dummy_mean.dtype)
    )

    dummy_state = LinearizedKalmanFilterState(
        elem=filtering.FilterScanElement(
            A=dummy_chol_cov,
            b=dummy_mean,
            U=dummy_chol_cov,
            eta=dummy_mean,
            Z=dummy_chol_cov,
            ell=jnp.array(0.0),
        ),
        model_inputs=model_inputs,
        mean_prev=dummy_mean,
    )

    dynamics_mean_and_chol_cov_func, dynamics_linearization_point = get_dynamics_params(
        dummy_state, model_inputs
    )
    F, c, chol_Q = linearize_moments(
        dynamics_mean_and_chol_cov_func, dynamics_linearization_point
    )

    observation_mean_and_chol_cov_func, observation_linearization_point, y = (
        get_observation_params(dummy_state, model_inputs)
    )
    H, d, chol_R = linearize_moments(
        observation_mean_and_chol_cov_func, observation_linearization_point
    )

    elem = filtering.associative_params_single(F, c, chol_Q, H, d, chol_R, y)

    return LinearizedKalmanFilterState(
        elem=elem,
        model_inputs=model_inputs,
        mean_prev=dummy_mean,
    )
149
+
150
+
151
def filter_combine(
    state_1: LinearizedKalmanFilterState,
    state_2: LinearizedKalmanFilterState,
) -> LinearizedKalmanFilterState:
    """Combine two states with the associative Kalman filtering operator.

    `associative_scan` is supported but only accurate when `state` is ignored
    in `get_dynamics_params` and `get_observation_params`.

    Dynamics and observation parameters are already linearized in
    `filter_prepare`. This step therefore only applies
    `filtering.filtering_operator` to the two elements.

    Args:
        state_1: Earlier (left) state, e.g. the filter state from the previous
            time step.
        state_2: Later (right) prepared state; both its `elem` and
            `model_inputs` are used.

    Returns:
        Combined linearized moments Kalman filter state. `elem` is the result
        of the filtering operator, `model_inputs` comes from `state_2` and
        `mean_prev` is `state_1.mean`.
    """
    combined_elem = filtering.filtering_operator(
        state_1.elem,
        state_2.elem,
    )
    return LinearizedKalmanFilterState(
        elem=combined_elem,
        model_inputs=state_2.model_inputs,
        mean_prev=state_1.mean,
    )
@@ -0,0 +1,95 @@
1
+ r"""Linearized moments Kalman filter.
2
+
3
+ Takes user-provided conditional `mean` and `chol_cov` functions to define a
4
+ conditionally linear Gaussian state space model.
5
+
6
+ I.e., we approximate conditional densities as
7
+
8
+ $$
9
+ p(y \mid x) \approx N(y \mid \mathrm{mean}(x), \mathrm{chol\_cov}(x) \, \mathrm{chol\_cov}(x)^\top).
10
+ $$
11
+
12
+ See `cuthbertlib.linearize` for more details.
13
+
14
+ Parallelism via `associative_scan` is supported, but requires the `state` argument
15
+ to be ignored in `get_dynamics_params` and `get_observation_params`.
16
+ I.e. the linearization points are pre-defined or extracted from model inputs.
17
+ """
18
+
19
+ from functools import partial
20
+
21
+ from cuthbert.gaussian.moments import associative_filter, non_associative_filter
22
+ from cuthbert.gaussian.moments.types import GetDynamicsMoments, GetObservationMoments
23
+ from cuthbert.gaussian.types import GetInitParams
24
+ from cuthbert.inference import Filter
25
+
26
+
27
def build_filter(
    get_init_params: GetInitParams,
    get_dynamics_params: GetDynamicsMoments,
    get_observation_params: GetObservationMoments,
    associative: bool = False,
) -> Filter:
    """Build linearized moments Kalman inference filter.

    If `associative` is True, all filtering linearization points must be
    pre-defined or extracted from model inputs. In that case the `state`
    argument should be ignored in `get_dynamics_params` and
    `get_observation_params`.

    If `associative` is False, the linearization points can be extracted from
    the previous filter state (for dynamics parameters) and from the predicted
    state (for observation parameters).

    Args:
        get_init_params: Function to get m0, chol_P0 from model inputs.
        get_dynamics_params: Function to get the dynamics conditional mean and
            (generalised) Cholesky covariance function, together with its
            linearization point, from the state and model inputs.
            If `associative` is True, the `state` argument should be ignored.
        get_observation_params: Function to get the observation conditional
            mean and (generalised) Cholesky covariance function, its
            linearization point and the observation, from the state and model
            inputs.
            If `associative` is True, the `state` argument should be ignored.
        associative: If True, the filter is suitable for associative scan but
            assumes that `state` is ignored in `get_dynamics_params` and
            `get_observation_params`. If False, the filter is suitable for
            sequential (non-associative) scan and the user is free to use
            `state` to choose the linearization points.

    Returns:
        Linearized moments Kalman filter object.
    """
    if associative:
        return Filter(
            init_prepare=partial(
                associative_filter.init_prepare,
                get_init_params=get_init_params,
                get_observation_params=get_observation_params,
            ),
            filter_prepare=partial(
                associative_filter.filter_prepare,
                get_init_params=get_init_params,
                get_dynamics_params=get_dynamics_params,
                get_observation_params=get_observation_params,
            ),
            filter_combine=associative_filter.filter_combine,
            associative=True,
        )
    else:
        return Filter(
            init_prepare=partial(
                non_associative_filter.init_prepare,
                get_init_params=get_init_params,
                get_observation_params=get_observation_params,
            ),
            filter_prepare=partial(
                non_associative_filter.filter_prepare,
                get_init_params=get_init_params,
            ),
            filter_combine=partial(
                non_associative_filter.filter_combine,
                get_dynamics_params=get_dynamics_params,
                get_observation_params=get_observation_params,
            ),
            associative=False,
        )
@@ -0,0 +1,161 @@
1
+ """Implements the non-associative linearized moments Kalman filter."""
2
+
3
+ from jax import eval_shape, tree
4
+ from jax import numpy as jnp
5
+
6
+ from cuthbert.gaussian.moments.types import GetDynamicsMoments, GetObservationMoments
7
+ from cuthbert.gaussian.types import GetInitParams, LinearizedKalmanFilterState
8
+ from cuthbert.gaussian.utils import linearized_kalman_filter_state_dummy_elem
9
+ from cuthbert.utils import dummy_tree_like
10
+ from cuthbertlib.kalman import filtering
11
+ from cuthbertlib.linearize import linearize_moments
12
+ from cuthbertlib.types import ArrayTreeLike, KeyArray
13
+
14
+
15
def init_prepare(
    model_inputs: ArrayTreeLike,
    get_init_params: GetInitParams,
    get_observation_params: GetObservationMoments,
    key: KeyArray | None = None,
) -> LinearizedKalmanFilterState:
    """Condition the prior on the first observation to start the filter.

    The prior ``(m0, chol_P0)`` is packed into a state and handed to
    ``get_observation_params``. The observation model it returns is
    linearized and used for one square-root Kalman update.

    Args:
        model_inputs: Model inputs for the initial time step.
        get_init_params: Callable returning ``(m0, chol_P0)`` from model inputs.
        get_observation_params: Callable returning the observation conditional
            mean and (generalised) Cholesky covariance function, the
            linearization point and the observation.
        key: JAX random key - unused.

    Returns:
        State for the linearized moments Kalman filter, carrying the updated
        mean, chol_cov (generalised Cholesky factor of covariance) and
        log_normalizing_constant.
    """
    inputs = tree.map(jnp.asarray, model_inputs)
    m0, chol_P0 = get_init_params(inputs)

    # `A`, `eta` and `Z` are zero-filled; only `b`, `U` and `ell` carry the prior.
    prior_elem = filtering.FilterScanElement(
        A=jnp.zeros_like(chol_P0),
        b=m0,
        U=chol_P0,
        eta=jnp.zeros_like(m0),
        Z=jnp.zeros_like(chol_P0),
        ell=jnp.array(0.0),
    )
    prior_state = LinearizedKalmanFilterState(
        elem=prior_elem,
        model_inputs=inputs,
        mean_prev=dummy_tree_like(m0),
    )

    obs_func, obs_point, y = get_observation_params(prior_state, inputs)
    H, d, chol_R = linearize_moments(obs_func, obs_point)
    (post_mean, post_chol), post_ell = filtering.update(m0, chol_P0, H, d, chol_R, y)

    return linearized_kalman_filter_state_dummy_elem(
        mean=post_mean,
        chol_cov=post_chol,
        log_normalizing_constant=post_ell,
        model_inputs=inputs,
        mean_prev=dummy_tree_like(post_mean),
    )
65
+
66
+
67
def filter_prepare(
    model_inputs: ArrayTreeLike,
    get_init_params: GetInitParams,
    key: KeyArray | None = None,
) -> LinearizedKalmanFilterState:
    """Prepare a state for a linearized moments Kalman filter step.

    Just passes through model inputs. The mean, chol_cov and
    log_normalizing_constant are placeholders with the right shapes; the
    actual predict/update happens in `filter_combine`.

    Args:
        model_inputs: Model inputs.
        get_init_params: Function to get m0, chol_P0 from model inputs.
            Only used to infer shape of mean and chol_cov.
        key: JAX random key - not used.

    Returns:
        Prepared state for linearized moments Kalman filter.
    """
    model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)
    dummy_mean_struct = eval_shape(lambda mi: get_init_params(mi)[0], model_inputs)
    dummy_mean = dummy_tree_like(dummy_mean_struct)
    # Placeholder shaped like a covariance factor for `dummy_mean`:
    # (..., n) -> (..., n, n). `jnp.cov` is not used for this because it
    # squeezes singleton dimensions, which would yield a scalar instead of a
    # (1, 1) matrix for a 1-D state.
    dummy_chol_cov = dummy_tree_like(
        jnp.zeros(dummy_mean.shape + dummy_mean.shape[-1:], dtype=dummy_mean.dtype)
    )

    return linearized_kalman_filter_state_dummy_elem(
        mean=dummy_mean,
        chol_cov=dummy_chol_cov,
        log_normalizing_constant=jnp.array(0.0),
        model_inputs=model_inputs,
        mean_prev=dummy_mean,
    )
97
+
98
+
99
def filter_combine(
    state_1: LinearizedKalmanFilterState,
    state_2: LinearizedKalmanFilterState,
    get_dynamics_params: GetDynamicsMoments,
    get_observation_params: GetObservationMoments,
) -> LinearizedKalmanFilterState:
    """Run one square-root predict + update step of the linearized moments filter.

    Intended for a sequential scan only; it is not an associative operator.

    Args:
        state_1: Filter state from the previous time step.
        state_2: Prepared state; only its `model_inputs` are read.
        get_dynamics_params: Callable returning the dynamics conditional mean
            and (generalised) Cholesky covariance function and its
            linearization point, given the previous state and model inputs.
        get_observation_params: Callable returning the observation conditional
            mean and (generalised) Cholesky covariance function, its
            linearization point and the observation, given the predicted state
            and model inputs.

    Returns:
        Updated filter state with mean, chol_cov (generalised Cholesky factor
        of covariance) and the accumulated log_normalizing_constant.
    """
    inputs = state_2.model_inputs
    prev_mean = state_1.mean
    prev_log_norm = state_1.log_normalizing_constant

    # Predict: dynamics are linearized at a point chosen from the previous state.
    dyn_func, dyn_point = get_dynamics_params(state_1, inputs)
    F, c, chol_Q = linearize_moments(dyn_func, dyn_point)
    pred_mean, pred_chol = filtering.predict(prev_mean, state_1.chol_cov, F, c, chol_Q)

    # Update: the observation model is queried with the predicted state.
    pred_state = linearized_kalman_filter_state_dummy_elem(
        mean=pred_mean,
        chol_cov=pred_chol,
        log_normalizing_constant=prev_log_norm,
        model_inputs=inputs,
        mean_prev=prev_mean,
    )
    obs_func, obs_point, y = get_observation_params(pred_state, inputs)
    H, d, chol_R = linearize_moments(obs_func, obs_point)
    (upd_mean, upd_chol), step_log_norm = filtering.update(
        pred_mean, pred_chol, H, d, chol_R, y
    )

    return linearized_kalman_filter_state_dummy_elem(
        mean=upd_mean,
        chol_cov=upd_chol,
        log_normalizing_constant=prev_log_norm + step_log_norm,
        model_inputs=inputs,
        mean_prev=prev_mean,
    )
@@ -0,0 +1,118 @@
1
+ r"""Linearized moments Kalman smoother.
2
+
3
+ Takes user-provided conditional `mean` and `chol_cov` functions to define a
4
+ conditionally linear Gaussian state space model.
5
+
6
+ I.e., we approximate conditional densities as
7
+
8
+ $$
9
+ p(y \mid x) \approx N(y \mid \mathrm{mean}(x), \mathrm{chol\_cov}(x) \, \mathrm{chol\_cov}(x)^\top).
10
+ $$
11
+
12
+ See `cuthbertlib.linearize` for more details.
13
+
14
+ Parallelism via `associative_scan` is supported, but requires the `state` argument
15
+ to be ignored in `get_dynamics_params`.
16
+ I.e. the linearization points are pre-defined or extracted from model inputs.
17
+ """
18
+
19
+ from functools import partial
20
+
21
+ from jax import numpy as jnp
22
+ from jax import tree
23
+
24
+ from cuthbert.gaussian.kalman import (
25
+ KalmanSmootherState,
26
+ convert_filter_to_smoother_state,
27
+ smoother_combine,
28
+ )
29
+ from cuthbert.gaussian.moments.types import GetDynamicsMoments
30
+ from cuthbert.gaussian.types import LinearizedKalmanFilterState
31
+ from cuthbert.inference import Smoother
32
+ from cuthbertlib.kalman import smoothing
33
+ from cuthbertlib.linearize import linearize_moments
34
+ from cuthbertlib.types import ArrayTreeLike, KeyArray
35
+
36
+
37
def build_smoother(
    get_dynamics_params: GetDynamicsMoments,
    store_gain: bool = False,
    store_chol_cov_given_next: bool = False,
) -> Smoother:
    """Assemble the linearized moments Kalman smoother for conditionally Gaussian SSMs.

    Args:
        get_dynamics_params: Callable returning the dynamics conditional mean
            and (generalised) Cholesky covariance function and its
            linearization point, from the state and model inputs.
        store_gain: Whether the smoother state keeps the gain matrix.
        store_chol_cov_given_next: Whether the smoother state keeps the
            chol_cov_given_next matrix.

    Returns:
        Smoother object usable with an associative scan.
    """
    # Both prepare and conversion steps share the same storage flags.
    storage_flags = {
        "store_gain": store_gain,
        "store_chol_cov_given_next": store_chol_cov_given_next,
    }
    prepare = partial(
        smoother_prepare, get_dynamics_params=get_dynamics_params, **storage_flags
    )
    convert = partial(convert_filter_to_smoother_state, **storage_flags)
    return Smoother(
        smoother_prepare=prepare,
        smoother_combine=smoother_combine,
        convert_filter_to_smoother_state=convert,
        associative=True,
    )
69
+
70
+
71
def smoother_prepare(
    filter_state: LinearizedKalmanFilterState,
    get_dynamics_params: GetDynamicsMoments,
    model_inputs: ArrayTreeLike,
    store_gain: bool = False,
    store_chol_cov_given_next: bool = False,
    key: KeyArray | None = None,
) -> KalmanSmootherState:
    """Prepare a state for a linearized moments Kalman smoother step.

    Note that the model_inputs here are different to filter_state.model_inputs.
    The model_inputs required here are for the transition from t to t+1.
    filter_state.model_inputs represents the transition from t-1 to t.

    Args:
        filter_state: State generated by the linearized moments Kalman filter
            at time t.
        get_dynamics_params: Function to get dynamics conditional mean and
            (generalised) Cholesky covariance function, and its linearization
            point, from the state and model inputs.
        model_inputs: Model inputs for the transition from t to t+1.
        store_gain: Whether to store the gain matrix in the smoother state.
        store_chol_cov_given_next: Whether to store the chol_cov_given_next matrix
            in the smoother state.
        key: JAX random key - not used.

    Returns:
        Prepared state for the Kalman smoother. `gain` and
        `chol_cov_given_next` are None unless the corresponding flag is set.
    """
    model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)
    filter_mean = filter_state.mean
    filter_chol_cov = filter_state.chol_cov

    dynamics_mean_and_chol_cov_func, dynamics_linearization_point = get_dynamics_params(
        filter_state, model_inputs
    )

    F, c, chol_Q = linearize_moments(
        dynamics_mean_and_chol_cov_func, dynamics_linearization_point
    )

    state = smoothing.associative_params_single(
        filter_mean, filter_chol_cov, F, c, chol_Q
    )
    return KalmanSmootherState(
        elem=state,
        gain=state.E if store_gain else None,
        chol_cov_given_next=state.D if store_chol_cov_given_next else None,
        model_inputs=model_inputs,
    )
@@ -0,0 +1,51 @@
1
+ """Provides types for the moment-based linearization of Gaussian state-space models."""
2
+
3
+ from typing import Protocol
4
+
5
+ from cuthbert.gaussian.types import LinearizedKalmanFilterState
6
+ from cuthbertlib.linearize.moments import MeanAndCholCovFunc
7
+ from cuthbertlib.types import Array, ArrayTreeLike
8
+
9
+
10
class GetDynamicsMoments(Protocol):
    """Callable protocol supplying the dynamics model to be linearized.

    Implementations map a filter state and model inputs to the dynamics
    conditional mean/chol_cov function together with a linearization point.
    """

    def __call__(
        self,
        state: LinearizedKalmanFilterState,
        model_inputs: ArrayTreeLike,
    ) -> tuple[MeanAndCholCovFunc, Array]:
        """Return the dynamics mean-and-chol_cov function and its linearization point.

        Compatible with `associative_scan` only when `state` is not used.

        Args:
            state: NamedTuple exposing `mean` and `mean_prev` attributes.
            model_inputs: Model inputs for this step.

        Returns:
            Pair of the function giving the dynamics conditional mean and
            (generalised) Cholesky covariance, and the linearization point.
        """
        ...
31
+
32
+
33
class GetObservationMoments(Protocol):
    """Callable protocol supplying the observation model to be linearized."""

    def __call__(
        self,
        state: LinearizedKalmanFilterState,
        model_inputs: ArrayTreeLike,
    ) -> tuple[MeanAndCholCovFunc, Array, Array]:
        """Return the observation mean-and-chol_cov function, linearization point and data.

        Compatible with `associative_scan` only when the `state` input is not used.

        Args:
            state: NamedTuple exposing `mean` and `mean_prev` attributes.
            model_inputs: Model inputs for this step.

        Returns:
            Triple of the conditional mean and chol_cov function, the
            linearization point, and the observation.
        """
        ...
@@ -0,0 +1,14 @@
1
+ from cuthbert.gaussian.taylor import (
2
+ associative_filter,
3
+ non_associative_filter,
4
+ smoother,
5
+ )
6
+ from cuthbert.gaussian.taylor.filter import build_filter
7
+ from cuthbert.gaussian.taylor.smoother import build_smoother
8
+ from cuthbert.gaussian.taylor.types import (
9
+ GetDynamicsLogDensity,
10
+ GetInitLogDensity,
11
+ GetObservationFunc,
12
+ LogPotential,
13
+ )
14
+ from cuthbert.gaussian.types import LinearizedKalmanFilterState