cuthbert 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. cuthbert/discrete/__init__.py +2 -0
  2. cuthbert/discrete/filter.py +140 -0
  3. cuthbert/discrete/smoother.py +123 -0
  4. cuthbert/discrete/types.py +53 -0
  5. cuthbert/gaussian/__init__.py +0 -0
  6. cuthbert/gaussian/kalman.py +337 -0
  7. cuthbert/gaussian/moments/__init__.py +11 -0
  8. cuthbert/gaussian/moments/associative_filter.py +180 -0
  9. cuthbert/gaussian/moments/filter.py +95 -0
  10. cuthbert/gaussian/moments/non_associative_filter.py +161 -0
  11. cuthbert/gaussian/moments/smoother.py +118 -0
  12. cuthbert/gaussian/moments/types.py +51 -0
  13. cuthbert/gaussian/taylor/__init__.py +14 -0
  14. cuthbert/gaussian/taylor/associative_filter.py +222 -0
  15. cuthbert/gaussian/taylor/filter.py +129 -0
  16. cuthbert/gaussian/taylor/non_associative_filter.py +246 -0
  17. cuthbert/gaussian/taylor/smoother.py +158 -0
  18. cuthbert/gaussian/taylor/types.py +86 -0
  19. cuthbert/gaussian/types.py +57 -0
  20. cuthbert/gaussian/utils.py +41 -0
  21. cuthbert/smc/__init__.py +0 -0
  22. cuthbert/smc/backward_sampler.py +193 -0
  23. cuthbert/smc/marginal_particle_filter.py +237 -0
  24. cuthbert/smc/particle_filter.py +234 -0
  25. cuthbert/smc/types.py +67 -0
  26. {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/METADATA +2 -2
  27. cuthbert-0.0.3.dist-info/RECORD +76 -0
  28. {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/licenses/LICENSE +1 -1
  29. cuthbertlib/discrete/__init__.py +0 -0
  30. cuthbertlib/discrete/filtering.py +49 -0
  31. cuthbertlib/discrete/smoothing.py +35 -0
  32. cuthbertlib/kalman/__init__.py +4 -0
  33. cuthbertlib/kalman/filtering.py +213 -0
  34. cuthbertlib/kalman/generate.py +85 -0
  35. cuthbertlib/kalman/sampling.py +68 -0
  36. cuthbertlib/kalman/smoothing.py +121 -0
  37. cuthbertlib/linalg/__init__.py +7 -0
  38. cuthbertlib/linalg/collect_nans_chol.py +90 -0
  39. cuthbertlib/linalg/marginal_sqrt_cov.py +34 -0
  40. cuthbertlib/linalg/symmetric_inv_sqrt.py +126 -0
  41. cuthbertlib/linalg/tria.py +21 -0
  42. cuthbertlib/linearize/__init__.py +7 -0
  43. cuthbertlib/linearize/log_density.py +175 -0
  44. cuthbertlib/linearize/moments.py +94 -0
  45. cuthbertlib/linearize/taylor.py +83 -0
  46. cuthbertlib/quadrature/__init__.py +4 -0
  47. cuthbertlib/quadrature/common.py +102 -0
  48. cuthbertlib/quadrature/cubature.py +73 -0
  49. cuthbertlib/quadrature/gauss_hermite.py +62 -0
  50. cuthbertlib/quadrature/linearize.py +143 -0
  51. cuthbertlib/quadrature/unscented.py +79 -0
  52. cuthbertlib/quadrature/utils.py +109 -0
  53. cuthbertlib/resampling/__init__.py +3 -0
  54. cuthbertlib/resampling/killing.py +79 -0
  55. cuthbertlib/resampling/multinomial.py +53 -0
  56. cuthbertlib/resampling/protocols.py +92 -0
  57. cuthbertlib/resampling/systematic.py +78 -0
  58. cuthbertlib/resampling/utils.py +82 -0
  59. cuthbertlib/smc/__init__.py +0 -0
  60. cuthbertlib/smc/ess.py +24 -0
  61. cuthbertlib/smc/smoothing/__init__.py +0 -0
  62. cuthbertlib/smc/smoothing/exact_sampling.py +111 -0
  63. cuthbertlib/smc/smoothing/mcmc.py +76 -0
  64. cuthbertlib/smc/smoothing/protocols.py +44 -0
  65. cuthbertlib/smc/smoothing/tracing.py +45 -0
  66. cuthbertlib/stats/__init__.py +0 -0
  67. cuthbertlib/stats/multivariate_normal.py +102 -0
  68. cuthbert-0.0.1.dist-info/RECORD +0 -12
  69. {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/WHEEL +0 -0
  70. {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,222 @@
1
+ """Implements the associative linearized Taylor Kalman filter."""
2
+
3
+ from jax import eval_shape, tree
4
+ from jax import numpy as jnp
5
+
6
+ from cuthbert.gaussian.taylor.non_associative_filter import process_observation
7
+ from cuthbert.gaussian.taylor.types import (
8
+ GetDynamicsLogDensity,
9
+ GetInitLogDensity,
10
+ GetObservationFunc,
11
+ )
12
+ from cuthbert.gaussian.types import (
13
+ LinearizedKalmanFilterState,
14
+ )
15
+ from cuthbert.utils import dummy_tree_like
16
+ from cuthbertlib.kalman import filtering
17
+ from cuthbertlib.linearize import linearize_log_density
18
+ from cuthbertlib.types import (
19
+ ArrayTreeLike,
20
+ KeyArray,
21
+ )
22
+
23
+
24
def init_prepare(
    model_inputs: ArrayTreeLike,
    get_init_log_density: GetInitLogDensity,
    get_observation_func: GetObservationFunc,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
    key: KeyArray | None = None,
) -> LinearizedKalmanFilterState:
    """Build the time-zero state for the associative linearized Taylor Kalman filter.

    The initial log density is linearized into a Gaussian prior. That prior is
    then conditioned on the time-zero observation (or log potential) returned by
    `get_observation_func`.

    Args:
        model_inputs: Model inputs. Every leaf is converted with `jnp.asarray`.
        get_init_log_density: Function that returns log density log p(x_0)
            and linearization point.
        get_observation_func: Function that returns either
            - An observation log density
              function log p(y_0 | x_0) as well as points x_0 and y_0
              to linearize around.
            - A log potential function log G(x_0) and a linearization point x_0.
        rtol: The relative tolerance for the singular values of precision matrices
            when passed to `symmetric_inv_sqrt` during linearization.
            Cutoff for small singular values; singular values smaller than
            `rtol * largest_singular_value` are treated as zero.
            The default is determined based on the floating point precision of the dtype.
            See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: Whether to treat dimensions with NaN on the diagonal of the
            precision matrices (found via linearization) as missing and ignore all rows
            and columns associated with them.
        key: JAX random key - not used.

    Returns:
        State for the linearized Taylor Kalman filter.
        Contains mean, chol_cov (generalised Cholesky factor of covariance)
        and log_normalizing_constant.
    """
    model_inputs = tree.map(jnp.asarray, model_inputs)

    def _as_state(mean, chol_cov, log_norm_const) -> LinearizedKalmanFilterState:
        # Time-zero scan element: A, eta and Z are all zeros, so the element only
        # carries the Gaussian (b = mean, U = chol_cov) and its ell term.
        zero_matrix = jnp.zeros_like(chol_cov)
        return LinearizedKalmanFilterState(
            elem=filtering.FilterScanElement(
                A=zero_matrix,
                b=mean,
                U=chol_cov,
                eta=jnp.zeros_like(mean),
                Z=zero_matrix,
                ell=log_norm_const,
            ),
            model_inputs=model_inputs,
            mean_prev=dummy_tree_like(mean),
        )

    log_init_density, init_point = get_init_log_density(model_inputs)

    # The initial density is linearized as a "conditional" density that ignores
    # its conditioning argument; only the offset and Cholesky factor are kept.
    _, prior_mean, prior_chol = linearize_log_density(
        lambda _, x: log_init_density(x),
        init_point,
        init_point,
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
    )
    prior_state = _as_state(prior_mean, prior_chol, jnp.array(0.0))

    H, d, chol_R, observation = process_observation(
        get_observation_func(prior_state, model_inputs),
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
    )
    (post_mean, post_chol), log_norm_const = filtering.update(
        prior_mean, prior_chol, H, d, chol_R, observation
    )
    return _as_state(post_mean, post_chol, log_norm_const)
106
+
107
+
108
def filter_prepare(
    model_inputs: ArrayTreeLike,
    get_init_log_density: GetInitLogDensity,
    get_dynamics_log_density: GetDynamicsLogDensity,
    get_observation_func: GetObservationFunc,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
    key: KeyArray | None = None,
) -> LinearizedKalmanFilterState:
    """Prepare a state for a linearized Taylor Kalman filter step.

    `associative_scan` is supported but only accurate when `state` is ignored
    in `get_dynamics_log_density` and `get_observation_func`. Both functions are
    called here with a placeholder state built via `dummy_tree_like`, not with
    a real filter state.

    Args:
        model_inputs: Model inputs. Every leaf is converted with `jnp.asarray`.
        get_init_log_density: Function that returns log density log p(x_0)
            and linearization point. Only used to infer shape of mean and chol_cov.
        get_dynamics_log_density: Function to get dynamics log density log p(x_t+1 | x_t)
            and linearization points (for the previous and current time points).
            `associative_scan` only supported when `state` is ignored.
        get_observation_func: Function to get observation function (either conditional
            log density or log potential), linearization point and optional observation
            (not required for log potential functions).
            `associative_scan` only supported when `state` is ignored.
        rtol: The relative tolerance for the singular values of precision matrices
            when passed to `symmetric_inv_sqrt` during linearization.
            Cutoff for small singular values; singular values smaller than
            `rtol * largest_singular_value` are treated as zero.
            The default is determined based on the floating point precision of the dtype.
            See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: Whether to treat dimensions with NaN on the diagonal of the
            precision matrices (found via linearization) as missing and ignore all rows
            and columns associated with them.
        key: JAX random key - not used.

    Returns:
        Prepared state for linearized Taylor Kalman filter. Its `elem` is the
        element built by `filtering.associative_params_single` from the
        linearized dynamics and observation parameters.
    """
    model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)

    # Only the shape/dtype of the initial linearization point is needed, so the
    # init function is traced abstractly with eval_shape rather than evaluated.
    dummy_mean_struct = eval_shape(lambda mi: get_init_log_density(mi)[1], model_inputs)
    dummy_mean = dummy_tree_like(dummy_mean_struct)
    # Square (dim, dim) placeholder for the matrix-valued fields. jnp.outer keeps
    # the (1, 1) shape for a one-dimensional state, whereas jnp.cov squeezes its
    # result and would return a 0-d array in that case.
    dummy_chol_cov = dummy_tree_like(jnp.outer(dummy_mean, dummy_mean))

    dummy_state = LinearizedKalmanFilterState(
        elem=filtering.FilterScanElement(
            A=dummy_chol_cov,
            b=dummy_mean,
            U=dummy_chol_cov,
            eta=dummy_mean,
            Z=dummy_chol_cov,
            ell=jnp.array(0.0),
        ),
        model_inputs=model_inputs,
        mean_prev=dummy_mean,
    )

    log_dynamics_density, linearization_point_prev, linearization_point_curr = (
        get_dynamics_log_density(dummy_state, model_inputs)
    )

    F, c, chol_Q = linearize_log_density(
        log_dynamics_density,
        linearization_point_prev,
        linearization_point_curr,
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
    )

    observation_output = get_observation_func(dummy_state, model_inputs)
    H, d, chol_R, observation = process_observation(
        observation_output,
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
    )

    # Fold the linearized transition and observation models into one element
    # that filter_combine merges with filtering.filtering_operator.
    elem = filtering.associative_params_single(F, c, chol_Q, H, d, chol_R, observation)

    return LinearizedKalmanFilterState(
        elem=elem,
        model_inputs=model_inputs,
        mean_prev=dummy_mean,
    )
191
+
192
+
193
def filter_combine(
    state_1: LinearizedKalmanFilterState,
    state_2: LinearizedKalmanFilterState,
) -> LinearizedKalmanFilterState:
    """Merge two associative filter states with the standard Kalman scan operator.

    `associative_scan` is supported but only accurate when `state` is ignored
    in `get_dynamics_log_density` and `get_observation_func`.

    No linearization happens here. The dynamics and observation parameters were
    already folded into each state's `elem` by `filter_prepare` (or
    `init_prepare`), so only `filtering.filtering_operator` is applied.

    Args:
        state_1: Earlier (left) state; its `elem` and `mean` are read.
        state_2: Later (right) prepared state; its `elem` and `model_inputs`
            are read.

    Returns:
        Predicted and updated linearized Taylor Kalman filter state.
        Contains mean, chol_cov (generalised Cholesky factor of covariance)
        and log_normalizing_constant. `mean_prev` is set to `state_1.mean`.
    """
    merged = filtering.filtering_operator(state_1.elem, state_2.elem)
    return LinearizedKalmanFilterState(
        elem=merged,
        model_inputs=state_2.model_inputs,
        mean_prev=state_1.mean,
    )
@@ -0,0 +1,129 @@
1
+ r"""Linearized Taylor Kalman filter.
2
+
3
+ Uses automatic differentiation to extract conditionally Gaussian parameters from log
4
+ densities of the dynamics and observation distributions.
5
+
6
+ This differs from `gaussian/moments`, which requires `mean` and `chol_cov`
7
+ functions as input rather than log densities.
8
+
9
+ I.e., we approximate conditional densities as
10
+
11
+ $$
12
+ p(y \mid x) \approx N(y \mid H x + d, L L^T),
13
+ $$
14
+
15
+ and potentials as
16
+
17
+ $$
18
+ G(x) \approx N(x \mid m, L L^T),
19
+ $$
20
+
21
+ where $L$ is the Cholesky factor of the covariance matrix.
22
+
23
+ See `cuthbertlib.linearize` for more details.
24
+
25
+ Parallelism via `associative_scan` is supported, but requires the `state` argument
26
+ to be ignored in `get_dynamics_log_density` and `get_observation_func`.
27
+ I.e. the linearization points are pre-defined or extracted from model inputs.
28
+ """
29
+
30
+ from functools import partial
31
+
32
+ from cuthbert.gaussian.taylor import associative_filter, non_associative_filter
33
+ from cuthbert.gaussian.taylor.types import (
34
+ GetDynamicsLogDensity,
35
+ GetInitLogDensity,
36
+ GetObservationFunc,
37
+ )
38
+ from cuthbert.inference import Filter
39
+
40
+
41
def build_filter(
    get_init_log_density: GetInitLogDensity,
    get_dynamics_log_density: GetDynamicsLogDensity,
    get_observation_func: GetObservationFunc,
    associative: bool = False,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
) -> Filter:
    """Assemble a linearized Taylor Kalman `Filter` from user model functions.

    With `associative=True` every linearization point must be pre-defined or
    extracted from model inputs. In that case the `state` argument should be
    ignored in `get_dynamics_log_density` and `get_observation_func`, because
    linearization happens in `filter_prepare` before any real state exists.

    With `associative=False` linearization happens inside `filter_combine`.
    The points may therefore use the previous filter state (dynamics) and the
    predicted state (observation).

    Args:
        get_init_log_density: Function to get log density log p(x_0)
            and linearization point.
            Only takes `model_inputs` as input.
        get_dynamics_log_density: Function to get dynamics log density log p(x_t+1 | x_t)
            and linearization points (for the previous and current time points).
            If `associative` is True, the `state` argument should be ignored.
        get_observation_func: Function to get observation function (either conditional
            log density or log potential), linearization point and optional observation
            (not required for log potential functions).
            If `associative` is True, the `state` argument should be ignored.
        associative: If True, then the filter is suitable for associative scan, but
            assumes that the `state` is ignored in `get_dynamics_log_density` and
            `get_observation_func`.
            If False, then the filter is suitable for non-associative scan, but
            the user is free to use the `state` to extract the linearization points.
        rtol: The relative tolerance for the singular values of precision matrices
            when passed to `symmetric_inv_sqrt` during linearization.
            Cutoff for small singular values; singular values smaller than
            `rtol * largest_singular_value` are treated as zero.
            The default is determined based on the floating point precision of the dtype.
            See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: Whether to treat dimensions with NaN on the diagonal of the
            precision matrices (found via linearization) as missing and ignore all rows
            and columns associated with them.

    Returns:
        Linearized Taylor Kalman filter object.
    """
    linearization_kwargs = {"rtol": rtol, "ignore_nan_dims": ignore_nan_dims}
    backend = associative_filter if associative else non_associative_filter

    # Both back ends expose init_prepare with the same keyword interface.
    init_prepare = partial(
        backend.init_prepare,
        get_init_log_density=get_init_log_density,
        get_observation_func=get_observation_func,
        **linearization_kwargs,
    )

    if not associative:
        # Sequential variant: filter_prepare only forwards model inputs, and all
        # dynamics/observation linearization is deferred to filter_combine.
        return Filter(
            init_prepare=init_prepare,
            filter_prepare=partial(
                non_associative_filter.filter_prepare,
                get_init_log_density=get_init_log_density,
            ),
            filter_combine=partial(
                non_associative_filter.filter_combine,
                get_dynamics_log_density=get_dynamics_log_density,
                get_observation_func=get_observation_func,
                **linearization_kwargs,
            ),
            associative=False,
        )

    # Parallel variant: filter_prepare linearizes up front, and filter_combine
    # is a plain associative operator on the prepared elements.
    return Filter(
        init_prepare=init_prepare,
        filter_prepare=partial(
            associative_filter.filter_prepare,
            get_init_log_density=get_init_log_density,
            get_dynamics_log_density=get_dynamics_log_density,
            get_observation_func=get_observation_func,
            **linearization_kwargs,
        ),
        filter_combine=associative_filter.filter_combine,
        associative=True,
    )
@@ -0,0 +1,246 @@
1
+ """Implements the non-associative linearized Taylor Kalman filter."""
2
+
3
+ from jax import eval_shape, tree
4
+ from jax import numpy as jnp
5
+
6
+ from cuthbert.gaussian.taylor.types import (
7
+ GetDynamicsLogDensity,
8
+ GetInitLogDensity,
9
+ GetObservationFunc,
10
+ LogConditionalDensity,
11
+ LogPotential,
12
+ )
13
+ from cuthbert.gaussian.types import LinearizedKalmanFilterState
14
+ from cuthbert.gaussian.utils import linearized_kalman_filter_state_dummy_elem
15
+ from cuthbert.utils import dummy_tree_like
16
+ from cuthbertlib.kalman import filtering
17
+ from cuthbertlib.linearize import linearize_log_density, linearize_taylor
18
+ from cuthbertlib.types import (
19
+ Array,
20
+ ArrayTreeLike,
21
+ KeyArray,
22
+ )
23
+
24
+
25
def process_observation(
    observation_output: tuple[LogConditionalDensity, Array, Array]
    | tuple[LogPotential, Array],
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
) -> tuple[Array, Array, Array, Array]:
    """Turn the output of `get_observation_func` into Kalman update parameters.

    Two output forms are accepted:

    - ``(log_density, linearization_point, observation)``: a conditional log
      density log p(y | x). It is linearized with `linearize_log_density` at
      (linearization_point, observation), and the observation is passed through.
    - ``(log_potential, linearization_point)``: an unconditional log potential
      log G(x). It is linearized with `linearize_taylor` and then recast as a
      zero pseudo-observation with ``H = -I``, so that
      -0.5 (0 - H x - d)^T (R R^T)^{-1} (0 - H x - d)
      equals the potential's quadratic form
      -0.5 (x - d)^T (R R^T)^{-1} (x - d).

    Args:
        observation_output: 3-tuple (conditional density form) or 2-tuple
            (potential form) as described above.
        rtol: Relative singular-value tolerance forwarded to the linearization.
        ignore_nan_dims: Whether dimensions with NaN on the diagonal of the
            linearized precision are treated as missing.

    Returns:
        Tuple ``(H, d, chol_R, observation)`` for `cuthbertlib.kalman.filtering`.
    """
    if len(observation_output) != 3:
        log_potential, lin_point = observation_output
        offset, chol_cov = linearize_taylor(
            log_potential,
            lin_point,
            rtol=rtol,
            ignore_nan_dims=ignore_nan_dims,
        )
        obs_matrix = -jnp.eye(offset.shape[0])
        # A NaN entry in the pseudo-observation tells cuthbertlib.kalman to skip
        # that dimension; this only happens when ignore_nan_dims is enabled.
        skip_dims = jnp.isnan(jnp.diag(chol_cov)) * ignore_nan_dims
        pseudo_observation = jnp.where(skip_dims, jnp.nan, 0.0)
        return obs_matrix, offset, chol_cov, pseudo_observation

    cond_log_density, lin_point, observation = observation_output
    obs_matrix, offset, chol_cov = linearize_log_density(
        cond_log_density,
        lin_point,
        observation,
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
    )
    return obs_matrix, offset, chol_cov, observation
59
+
60
+
61
def init_prepare(
    model_inputs: ArrayTreeLike,
    get_init_log_density: GetInitLogDensity,
    get_observation_func: GetObservationFunc,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
    key: KeyArray | None = None,
) -> LinearizedKalmanFilterState:
    """Build the time-zero state for the sequential linearized Taylor Kalman filter.

    The initial log density is linearized into a Gaussian prior. That prior is
    then updated with the time-zero observation (or log potential) returned by
    `get_observation_func`, which receives the prior state.

    Args:
        model_inputs: Model inputs. Every leaf is converted with `jnp.asarray`.
        get_init_log_density: Function that returns log density log p(x_0)
            and linearization point.
        get_observation_func: Function that returns either
            - An observation log density
              function log p(y_0 | x_0) as well as points x_0 and y_0
              to linearize around.
            - A log potential function log G(x_0) and a linearization point x_0.
        rtol: The relative tolerance for the singular values of precision matrices
            when passed to `symmetric_inv_sqrt` during linearization.
            Cutoff for small singular values; singular values smaller than
            `rtol * largest_singular_value` are treated as zero.
            The default is determined based on the floating point precision of the dtype.
            See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: Whether to treat dimensions with NaN on the diagonal of the
            precision matrices (found via linearization) as missing and ignore all rows
            and columns associated with them.
        key: JAX random key - not used.

    Returns:
        State for the linearized Taylor Kalman filter.
        Contains mean, chol_cov (generalised Cholesky factor of covariance)
        and log_normalizing_constant.
    """
    model_inputs = tree.map(jnp.asarray, model_inputs)

    def _as_state(mean, chol_cov, log_norm_const) -> LinearizedKalmanFilterState:
        # No previous step exists at time zero, so mean_prev is a placeholder.
        return linearized_kalman_filter_state_dummy_elem(
            mean=mean,
            chol_cov=chol_cov,
            log_normalizing_constant=log_norm_const,
            model_inputs=model_inputs,
            mean_prev=dummy_tree_like(mean),
        )

    log_init_density, init_point = get_init_log_density(model_inputs)

    # Linearize p(x_0) as a "conditional" density that ignores its conditioning
    # argument; only the offset and Cholesky factor outputs are kept.
    _, prior_mean, prior_chol = linearize_log_density(
        lambda _, x: log_init_density(x),
        init_point,
        init_point,
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
    )
    prior_state = _as_state(prior_mean, prior_chol, jnp.array(0.0))

    H, d, chol_R, observation = process_observation(
        get_observation_func(prior_state, model_inputs),
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
    )
    (post_mean, post_chol), log_norm_const = filtering.update(
        prior_mean, prior_chol, H, d, chol_R, observation
    )
    return _as_state(post_mean, post_chol, log_norm_const)
132
+
133
+
134
def filter_prepare(
    model_inputs: ArrayTreeLike,
    get_init_log_density: GetInitLogDensity,
    key: KeyArray | None = None,
) -> LinearizedKalmanFilterState:
    """Prepare a state for a linearized Taylor Kalman filter step.

    Just passes through model inputs. The mean, chol_cov and mean_prev fields
    are placeholders shaped like the filter state; all linearization for the
    sequential filter happens later, in `filter_combine`.

    Args:
        model_inputs: Model inputs. Every leaf is converted with `jnp.asarray`.
        get_init_log_density: Function that returns log density log p(x_0)
            and linearization point. Only used to infer shape of mean and chol_cov.
        key: JAX random key - not used.

    Returns:
        Prepared state for linearized Taylor Kalman filter.
    """
    model_inputs = tree.map(lambda x: jnp.asarray(x), model_inputs)

    # Only the shape/dtype of the initial linearization point is needed, so the
    # init function is traced abstractly with eval_shape rather than evaluated.
    dummy_mean_struct = eval_shape(lambda mi: get_init_log_density(mi)[1], model_inputs)
    dummy_mean = dummy_tree_like(dummy_mean_struct)
    # Square (dim, dim) placeholder for chol_cov. jnp.outer keeps the (1, 1)
    # shape for a one-dimensional state, whereas jnp.cov squeezes its result
    # and would return a 0-d array in that case.
    dummy_chol_cov = dummy_tree_like(jnp.outer(dummy_mean, dummy_mean))

    return linearized_kalman_filter_state_dummy_elem(
        mean=dummy_mean,
        chol_cov=dummy_chol_cov,
        log_normalizing_constant=jnp.array(0.0),
        model_inputs=model_inputs,
        mean_prev=dummy_mean,
    )
164
+
165
+
166
def filter_combine(
    state_1: LinearizedKalmanFilterState,
    state_2: LinearizedKalmanFilterState,
    get_dynamics_log_density: GetDynamicsLogDensity,
    get_observation_func: GetObservationFunc,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
) -> LinearizedKalmanFilterState:
    """Advance the sequential filter by one step: linearize, predict, then update.

    Applies linearized Taylor Kalman predict + filter update in covariance square
    root form. The dynamics are linearized using the previous state, and the
    observation using the predicted state, so this is not suitable for
    associative scan.

    Args:
        state_1: State from previous time step.
        state_2: State prepared (only access model_inputs attribute).
        get_dynamics_log_density: Function to get dynamics log density log p(x_t+1 | x_t)
            and linearization points (for the previous and current time points).
        get_observation_func: Function to get observation function (either conditional
            log density or log potential), linearization point and optional observation
            (not required for log potential functions).
        rtol: The relative tolerance for the singular values of precision matrices
            when passed to `symmetric_inv_sqrt` during linearization.
            Cutoff for small singular values; singular values smaller than
            `rtol * largest_singular_value` are treated as zero.
            The default is determined based on the floating point precision of the dtype.
            See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: Whether to treat dimensions with NaN on the diagonal of the
            precision matrices (found via linearization) as missing and ignore all rows
            and columns associated with them.

    Returns:
        Predicted and updated linearized Taylor Kalman filter state.
        Contains mean, chol_cov (generalised Cholesky factor of covariance)
        and log_normalizing_constant (accumulated across steps).
    """
    linearization_kwargs = {"rtol": rtol, "ignore_nan_dims": ignore_nan_dims}
    step_inputs = state_2.model_inputs

    # --- Predict: linearize p(x_t | x_{t-1}) around points chosen from state_1.
    dynamics_log_density, point_prev, point_curr = get_dynamics_log_density(
        state_1, step_inputs
    )
    F, c, chol_Q = linearize_log_density(
        dynamics_log_density, point_prev, point_curr, **linearization_kwargs
    )
    pred_mean, pred_chol = filtering.predict(
        state_1.mean, state_1.chol_cov, F, c, chol_Q
    )

    # --- Update: linearize the observation around points chosen from the
    # predicted state, then condition on it.
    predicted = linearized_kalman_filter_state_dummy_elem(
        mean=pred_mean,
        chol_cov=pred_chol,
        log_normalizing_constant=state_1.log_normalizing_constant,
        model_inputs=step_inputs,
        mean_prev=state_1.mean,
    )
    H, d, chol_R, observation = process_observation(
        get_observation_func(predicted, step_inputs), **linearization_kwargs
    )
    (post_mean, post_chol), step_log_norm = filtering.update(
        pred_mean, pred_chol, H, d, chol_R, observation
    )

    # The running log normalizing constant accumulates the per-step increment.
    return linearized_kalman_filter_state_dummy_elem(
        mean=post_mean,
        chol_cov=post_chol,
        log_normalizing_constant=state_1.log_normalizing_constant + step_log_norm,
        model_inputs=step_inputs,
        mean_prev=state_1.mean,
    )