cuthbert 0.0.2__tar.gz → 0.0.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. {cuthbert-0.0.2 → cuthbert-0.0.3}/PKG-INFO +1 -1
  2. cuthbert-0.0.3/cuthbert/discrete/__init__.py +2 -0
  3. cuthbert-0.0.3/cuthbert/discrete/filter.py +140 -0
  4. cuthbert-0.0.3/cuthbert/discrete/smoother.py +123 -0
  5. cuthbert-0.0.3/cuthbert/discrete/types.py +53 -0
  6. cuthbert-0.0.3/cuthbert/gaussian/kalman.py +337 -0
  7. cuthbert-0.0.3/cuthbert/gaussian/moments/__init__.py +11 -0
  8. cuthbert-0.0.3/cuthbert/gaussian/moments/associative_filter.py +180 -0
  9. cuthbert-0.0.3/cuthbert/gaussian/moments/filter.py +95 -0
  10. cuthbert-0.0.3/cuthbert/gaussian/moments/non_associative_filter.py +161 -0
  11. cuthbert-0.0.3/cuthbert/gaussian/moments/smoother.py +118 -0
  12. cuthbert-0.0.3/cuthbert/gaussian/moments/types.py +51 -0
  13. cuthbert-0.0.3/cuthbert/gaussian/taylor/__init__.py +14 -0
  14. cuthbert-0.0.3/cuthbert/gaussian/taylor/associative_filter.py +222 -0
  15. cuthbert-0.0.3/cuthbert/gaussian/taylor/filter.py +129 -0
  16. cuthbert-0.0.3/cuthbert/gaussian/taylor/non_associative_filter.py +246 -0
  17. cuthbert-0.0.3/cuthbert/gaussian/taylor/smoother.py +158 -0
  18. cuthbert-0.0.3/cuthbert/gaussian/taylor/types.py +86 -0
  19. cuthbert-0.0.3/cuthbert/gaussian/types.py +57 -0
  20. cuthbert-0.0.3/cuthbert/gaussian/utils.py +41 -0
  21. cuthbert-0.0.3/cuthbert/smc/__init__.py +0 -0
  22. cuthbert-0.0.3/cuthbert/smc/backward_sampler.py +193 -0
  23. cuthbert-0.0.3/cuthbert/smc/marginal_particle_filter.py +237 -0
  24. cuthbert-0.0.3/cuthbert/smc/particle_filter.py +234 -0
  25. cuthbert-0.0.3/cuthbert/smc/types.py +67 -0
  26. {cuthbert-0.0.2 → cuthbert-0.0.3}/cuthbert.egg-info/PKG-INFO +1 -1
  27. cuthbert-0.0.3/cuthbert.egg-info/SOURCES.txt +80 -0
  28. cuthbert-0.0.3/cuthbertlib/__init__.py +0 -0
  29. cuthbert-0.0.3/cuthbertlib/discrete/__init__.py +0 -0
  30. cuthbert-0.0.3/cuthbertlib/discrete/filtering.py +49 -0
  31. cuthbert-0.0.3/cuthbertlib/discrete/smoothing.py +35 -0
  32. cuthbert-0.0.3/cuthbertlib/kalman/__init__.py +4 -0
  33. cuthbert-0.0.3/cuthbertlib/kalman/filtering.py +213 -0
  34. cuthbert-0.0.3/cuthbertlib/kalman/generate.py +85 -0
  35. cuthbert-0.0.3/cuthbertlib/kalman/sampling.py +68 -0
  36. cuthbert-0.0.3/cuthbertlib/kalman/smoothing.py +121 -0
  37. cuthbert-0.0.3/cuthbertlib/linalg/__init__.py +7 -0
  38. cuthbert-0.0.3/cuthbertlib/linalg/collect_nans_chol.py +90 -0
  39. cuthbert-0.0.3/cuthbertlib/linalg/marginal_sqrt_cov.py +34 -0
  40. cuthbert-0.0.3/cuthbertlib/linalg/symmetric_inv_sqrt.py +126 -0
  41. cuthbert-0.0.3/cuthbertlib/linalg/tria.py +21 -0
  42. cuthbert-0.0.3/cuthbertlib/linearize/__init__.py +7 -0
  43. cuthbert-0.0.3/cuthbertlib/linearize/log_density.py +175 -0
  44. cuthbert-0.0.3/cuthbertlib/linearize/moments.py +94 -0
  45. cuthbert-0.0.3/cuthbertlib/linearize/taylor.py +83 -0
  46. cuthbert-0.0.3/cuthbertlib/quadrature/__init__.py +4 -0
  47. cuthbert-0.0.3/cuthbertlib/quadrature/common.py +102 -0
  48. cuthbert-0.0.3/cuthbertlib/quadrature/cubature.py +73 -0
  49. cuthbert-0.0.3/cuthbertlib/quadrature/gauss_hermite.py +62 -0
  50. cuthbert-0.0.3/cuthbertlib/quadrature/linearize.py +143 -0
  51. cuthbert-0.0.3/cuthbertlib/quadrature/unscented.py +79 -0
  52. cuthbert-0.0.3/cuthbertlib/quadrature/utils.py +109 -0
  53. cuthbert-0.0.3/cuthbertlib/resampling/__init__.py +3 -0
  54. cuthbert-0.0.3/cuthbertlib/resampling/killing.py +79 -0
  55. cuthbert-0.0.3/cuthbertlib/resampling/multinomial.py +53 -0
  56. cuthbert-0.0.3/cuthbertlib/resampling/protocols.py +92 -0
  57. cuthbert-0.0.3/cuthbertlib/resampling/systematic.py +78 -0
  58. cuthbert-0.0.3/cuthbertlib/resampling/utils.py +82 -0
  59. cuthbert-0.0.3/cuthbertlib/smc/__init__.py +0 -0
  60. cuthbert-0.0.3/cuthbertlib/smc/ess.py +24 -0
  61. cuthbert-0.0.3/cuthbertlib/smc/smoothing/__init__.py +0 -0
  62. cuthbert-0.0.3/cuthbertlib/smc/smoothing/exact_sampling.py +111 -0
  63. cuthbert-0.0.3/cuthbertlib/smc/smoothing/mcmc.py +76 -0
  64. cuthbert-0.0.3/cuthbertlib/smc/smoothing/protocols.py +44 -0
  65. cuthbert-0.0.3/cuthbertlib/smc/smoothing/tracing.py +45 -0
  66. cuthbert-0.0.3/cuthbertlib/stats/__init__.py +0 -0
  67. cuthbert-0.0.3/cuthbertlib/stats/multivariate_normal.py +102 -0
  68. {cuthbert-0.0.2 → cuthbert-0.0.3}/pyproject.toml +2 -2
  69. cuthbert-0.0.2/cuthbert.egg-info/SOURCES.txt +0 -16
  70. {cuthbert-0.0.2 → cuthbert-0.0.3}/LICENSE +0 -0
  71. {cuthbert-0.0.2 → cuthbert-0.0.3}/README.md +0 -0
  72. {cuthbert-0.0.2 → cuthbert-0.0.3}/cuthbert/__init__.py +0 -0
  73. {cuthbert-0.0.2 → cuthbert-0.0.3}/cuthbert/filtering.py +0 -0
  74. {cuthbert-0.0.2/cuthbertlib → cuthbert-0.0.3/cuthbert/gaussian}/__init__.py +0 -0
  75. {cuthbert-0.0.2 → cuthbert-0.0.3}/cuthbert/inference.py +0 -0
  76. {cuthbert-0.0.2 → cuthbert-0.0.3}/cuthbert/smoothing.py +0 -0
  77. {cuthbert-0.0.2 → cuthbert-0.0.3}/cuthbert/utils.py +0 -0
  78. {cuthbert-0.0.2 → cuthbert-0.0.3}/cuthbert.egg-info/dependency_links.txt +0 -0
  79. {cuthbert-0.0.2 → cuthbert-0.0.3}/cuthbert.egg-info/requires.txt +0 -0
  80. {cuthbert-0.0.2 → cuthbert-0.0.3}/cuthbert.egg-info/top_level.txt +0 -0
  81. {cuthbert-0.0.2 → cuthbert-0.0.3}/cuthbertlib/types.py +0 -0
  82. {cuthbert-0.0.2 → cuthbert-0.0.3}/setup.cfg +0 -0
  83. {cuthbert-0.0.2 → cuthbert-0.0.3}/tests/test_examples_scripts.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cuthbert
3
- Version: 0.0.2
3
+ Version: 0.0.3
4
4
  Summary: State-space model inference with JAX
5
5
  Author-email: Sam Duffield <s@mduffield.com>, Sahel Iqbal <sahel13miqbal@proton.me>, Adrien Corenflos <adrien.corenflos.stats@gmail.com>
6
6
  License: Apache-2.0
@@ -0,0 +1,2 @@
1
+ from cuthbert.discrete.filter import build_filter
2
+ from cuthbert.discrete.smoother import build_smoother
@@ -0,0 +1,140 @@
1
+ """Parallel-in-time Bayesian filter for discrete hidden Markov models.
2
+
3
+ References:
4
+ - https://ieeexplore.ieee.org/document/9512397
5
+ - https://github.com/EEA-sensors/sequential-parallelization-examples/tree/main/python/temporal-parallelization-inference-in-HMMs
6
+ - https://github.com/probml/dynamax/blob/main/dynamax/hidden_markov_model/parallel_inference.py
7
+ """
8
+
9
+ from functools import partial
10
+ from typing import NamedTuple
11
+
12
+ import jax.numpy as jnp
13
+ from jax import tree
14
+
15
+ from cuthbert.discrete.types import (
16
+ GetInitDist,
17
+ GetObsLogLikelihoods,
18
+ GetTransitionMatrix,
19
+ )
20
+ from cuthbert.inference import Filter
21
+ from cuthbertlib.discrete import filtering
22
+ from cuthbertlib.types import Array, ArrayTree, ArrayTreeLike, KeyArray
23
+
24
+
25
class DiscreteFilterState(NamedTuple):
    """State carried through the discrete HMM filter.

    Attributes:
        elem: Filter scan element, holding the conditional probabilities `f`
            and the log normalizers `log_g`.
        model_inputs: Model inputs this state was prepared with.
    """

    elem: filtering.FilterScanElement
    model_inputs: ArrayTree

    @property
    def dist(self) -> Array:
        """Filtering distribution over the K possible states.

        Row 0 along the second-to-last axis of `elem.f`, giving shape (K,)
        for a single state or (T+1, K) for states stacked over time.
        """
        conditional_probs = self.elem.f
        return jnp.take(conditional_probs, 0, axis=-2)

    @property
    def log_normalizing_constant(self) -> Array:
        """Cumulative log normalizing constant (entry 0 along the last axis of `elem.log_g`)."""
        log_normalizers = self.elem.log_g
        return jnp.take(log_normalizers, 0, axis=-1)
43
+
44
+
45
def build_filter(
    get_init_dist: GetInitDist,
    get_trans_matrix: GetTransitionMatrix,
    get_obs_lls: GetObsLogLikelihoods,
) -> Filter:
    r"""Construct a parallel-in-time filter for a discrete hidden Markov model.

    Args:
        get_init_dist: Returns the initial state probabilities
            $m_i = p(x_0 = i)$ from model inputs.
        get_trans_matrix: Returns the transition matrix
            $A_{ij} = p(x_t = j \mid x_{t-1} = i)$ from model inputs.
        get_obs_lls: Returns the observation log likelihoods
            $b_i = \log p(y_t \mid x_t = i)$ from model inputs.

    Returns:
        A `Filter` flagged as associative, so it is suitable for an
        associative scan.
    """
    # Bind the model callbacks up front. The remaining arguments
    # (model_inputs, key) are left for the caller to supply.
    prepare_initial = partial(
        init_prepare, get_init_dist=get_init_dist, get_obs_lls=get_obs_lls
    )
    prepare_step = partial(
        filter_prepare, get_trans_matrix=get_trans_matrix, get_obs_lls=get_obs_lls
    )
    return Filter(
        init_prepare=prepare_initial,
        filter_prepare=prepare_step,
        filter_combine=filter_combine,
        associative=True,
    )
70
+
71
+
72
def init_prepare(
    model_inputs: ArrayTreeLike,
    get_init_dist: GetInitDist,
    get_obs_lls: GetObsLogLikelihoods,
    key: KeyArray | None = None,
) -> DiscreteFilterState:
    """Build the initial filter state by conditioning p(x_0) on the first observation.

    Args:
        model_inputs: Model inputs for time 0. Leaves are converted to JAX arrays.
        get_init_dist: Returns initial state probabilities m_i = p(x_0 = i).
        get_obs_lls: Returns observation log likelihoods b_i = log p(y_t | x_t = i).
        key: JAX random key - not used.

    Returns:
        Initial `DiscreteFilterState`.
    """
    inputs = tree.map(jnp.asarray, model_inputs)
    initial_probs = get_init_dist(inputs)
    log_likelihoods = get_obs_lls(inputs)
    cond_probs, log_norm = filtering.condition_on_obs(initial_probs, log_likelihoods)

    num_states = initial_probs.shape[0]
    # Broadcast to an (N, N) array whose rows all equal the conditioned
    # initial distribution, and expand log_norm to length N. `dist` reads
    # row 0 along axis -2 and `log_normalizing_constant` reads entry 0 along
    # axis -1. Presumably this also matches the layout of the elements
    # produced by filter_prepare.
    cond_probs = cond_probs * jnp.ones((num_states, num_states))
    log_norm = log_norm * jnp.ones(num_states)

    element = filtering.FilterScanElement(cond_probs, log_norm)
    return DiscreteFilterState(elem=element, model_inputs=inputs)
99
+
100
+
101
def filter_prepare(
    model_inputs: ArrayTreeLike,
    get_trans_matrix: GetTransitionMatrix,
    get_obs_lls: GetObsLogLikelihoods,
    key: KeyArray | None = None,
) -> DiscreteFilterState:
    """Build the scan element for one filter step (transition followed by observation).

    Args:
        model_inputs: Model inputs for this step. Leaves are converted to JAX arrays.
        get_trans_matrix: Returns the transition matrix A_{ij} = p(x_t = j | x_{t-1} = i).
        get_obs_lls: Returns observation log likelihoods b_i = log p(y_t | x_t = i).
        key: JAX random key - not used.

    Returns:
        `DiscreteFilterState` holding the prepared scan element.
    """
    inputs = tree.map(jnp.asarray, model_inputs)
    transition = get_trans_matrix(inputs)
    log_likelihoods = get_obs_lls(inputs)
    element = filtering.FilterScanElement(
        *filtering.condition_on_obs(transition, log_likelihoods)
    )
    return DiscreteFilterState(elem=element, model_inputs=inputs)
125
+
126
+
127
def filter_combine(
    state_1: DiscreteFilterState, state_2: DiscreteFilterState
) -> DiscreteFilterState:
    """Combine previous filter state with state prepared with latest model inputs.

    Applies the associative filtering operator to the two scan elements.

    Args:
        state_1: State from the previous time step(s).
        state_2: State prepared with the latest model inputs. Both its scan
            element (`elem`) and its `model_inputs` are used.

    Returns:
        Combined filter state carrying `state_2.model_inputs`. Its `dist` and
        `log_normalizing_constant` properties give the filtering distribution
        and the cumulative log normalizing constant.
    """
    combined_elem = filtering.filtering_operator(state_1.elem, state_2.elem)
    return DiscreteFilterState(elem=combined_elem, model_inputs=state_2.model_inputs)
@@ -0,0 +1,123 @@
1
+ """Parallel-in-time Bayesian smoother for discrete hidden Markov models.
2
+
3
+ References:
4
+ - https://ieeexplore.ieee.org/document/9512397
5
+ - https://github.com/EEA-sensors/sequential-parallelization-examples/tree/main/python/temporal-parallelization-inference-in-HMMs
6
+ """
7
+
8
+ from functools import partial
9
+ from typing import NamedTuple
10
+
11
+ import jax.numpy as jnp
12
+ from jax import tree
13
+
14
+ from cuthbert.discrete.filter import DiscreteFilterState
15
+ from cuthbert.discrete.types import GetTransitionMatrix
16
+ from cuthbert.inference import Smoother
17
+ from cuthbert.utils import dummy_tree_like
18
+ from cuthbertlib.discrete.smoothing import get_reverse_kernel, smoothing_operator
19
+ from cuthbertlib.types import Array, ArrayTree, ArrayTreeLike, KeyArray
20
+
21
+
22
class DiscreteSmootherState(NamedTuple):
    """Discrete smoother state.

    Attributes:
        a: Smoother scan element. Row 0 along its second-to-last axis is the
            smoothed distribution, so it is (K, K) for a single state or
            (T+1, K, K) when stacked over time.
        model_inputs: Model inputs associated with this state.
    """

    a: Array
    model_inputs: ArrayTree

    @property
    def dist(self) -> Array:
        """The smoothed distribution.

        Has shape (K,) or (T+1, K) where K is the number of possible states.
        """
        return jnp.take(self.a, 0, axis=-2)
35
+
36
+
37
def build_smoother(get_trans_matrix: GetTransitionMatrix) -> Smoother:
    r"""Construct a parallel-in-time smoother for a discrete hidden Markov model.

    Args:
        get_trans_matrix: Returns the transition matrix
            $A_{ij} = p(x_t = j \mid x_{t-1} = i)$ from model inputs.

    Returns:
        A `Smoother` flagged as associative, so it is suitable for an
        associative scan.
    """
    prepare_step = partial(smoother_prepare, get_trans_matrix=get_trans_matrix)
    return Smoother(
        convert_filter_to_smoother_state=convert_filter_to_smoother_state,
        smoother_prepare=prepare_step,
        smoother_combine=smoother_combine,
        associative=True,
    )
52
+
53
+
54
def smoother_prepare(
    filter_state: DiscreteFilterState,
    get_trans_matrix: GetTransitionMatrix,
    model_inputs: ArrayTreeLike,
    key: KeyArray | None = None,
) -> DiscreteSmootherState:
    """Build the smoother scan element for time t from the filter output at t.

    Args:
        filter_state: Filter state at time t.
        get_trans_matrix: Returns the transition matrix A_{ij} = p(x_{t+1} = j | x_{t} = i).
        model_inputs: Model inputs for the transition from t to t+1. Leaves are
            converted to JAX arrays.
        key: JAX random key - not used.

    Returns:
        `DiscreteSmootherState` whose `a` is the output of `get_reverse_kernel`
        applied to the filtering distribution and the transition matrix.
    """
    inputs = tree.map(jnp.asarray, model_inputs)
    transition = get_trans_matrix(inputs)
    reverse_kernel = get_reverse_kernel(filter_state.dist, transition)
    return DiscreteSmootherState(a=reverse_kernel, model_inputs=inputs)
75
+
76
+
77
def convert_filter_to_smoother_state(
    filter_state: DiscreteFilterState,
    model_inputs: ArrayTreeLike | None = None,
    key: KeyArray | None = None,
) -> DiscreteSmootherState:
    """Convert a filter state to a smoother state.

    Useful for the final filter state, which is equivalent to the final
    smoother state.

    Args:
        filter_state: Filter state. Its `dist` has shape (K,), or may carry
            leading batch dimensions such as (T+1, K).
        model_inputs: Only used to create an empty model_inputs tree (the
            values are ignored). This keeps the final smoother state's
            structure the same as the rest. Defaults to
            `filter_state.model_inputs`, so it is only needed if the smoother
            model_inputs have a different tree structure.
        key: JAX random key - not used.

    Returns:
        Smoother state whose `a` has shape (..., K, K). Every row along the
        second-to-last axis equals the filtering distribution, so its `dist`
        property recovers that distribution. The model_inputs are set to
        dummy values.
    """
    if model_inputs is None:
        model_inputs = filter_state.model_inputs

    dummy_model_inputs = dummy_tree_like(model_inputs)

    filter_dist = filter_state.dist
    # The number of states is the LAST axis of `dist`. The leading axis only
    # coincides with it when `dist` is 1-D, so it must not be used here.
    num_states = filter_dist.shape[-1]
    a = jnp.broadcast_to(
        filter_dist[..., None, :],
        (*filter_dist.shape[:-1], num_states, num_states),
    )
    return DiscreteSmootherState(a=a, model_inputs=dummy_model_inputs)
106
+
107
+
108
def smoother_combine(
    state_1: DiscreteSmootherState, state_2: DiscreteSmootherState
) -> DiscreteSmootherState:
    """Merge the element prepared at time t with the smoother state from t + 1.

    Smoothing runs backwards in time, so `state_2` is the later state.

    Args:
        state_1: State prepared with model inputs at time t.
        state_2: Smoother state at time t + 1.

    Returns:
        Combined smoother state, carrying `state_1.model_inputs`.
    """
    current, later = state_1, state_2
    merged = smoothing_operator(current.a, later.a)
    return DiscreteSmootherState(a=merged, model_inputs=current.model_inputs)
@@ -0,0 +1,53 @@
1
+ """Provides types for representing discrete HMMs."""
2
+
3
+ from typing import Protocol
4
+
5
+ from cuthbertlib.types import Array, ArrayTreeLike
6
+
7
+
8
class GetInitDist(Protocol):
    """Callable protocol mapping model inputs to the initial state distribution."""

    def __call__(self, model_inputs: ArrayTreeLike) -> Array:
        """Return the probability of each possible initial state.

        Args:
            model_inputs: Model inputs.

        Returns:
            Array $m$ of shape (N,), N being the number of states, where
            $m_i = p(x_0 = i)$.
        """
        ...
22
+
23
+
24
class GetTransitionMatrix(Protocol):
    """Callable protocol mapping model inputs to the state transition matrix."""

    def __call__(self, model_inputs: ArrayTreeLike) -> Array:
        r"""Return the row-wise transition probabilities between states.

        Args:
            model_inputs: Model inputs.

        Returns:
            Array $A$ of shape (N, N), N being the number of states, where
            $A_{ij} = p(x_t = j \mid x_{t-1} = i)$.
        """
        ...
38
+
39
+
40
class GetObsLogLikelihoods(Protocol):
    """Callable protocol mapping model inputs to per-state observation log likelihoods."""

    def __call__(self, model_inputs: ArrayTreeLike) -> Array:
        r"""Return the log likelihood of the current observation under each state.

        Args:
            model_inputs: Model inputs.

        Returns:
            Array $b$ of shape (N,), N being the number of states, where
            $b_i = \log p(y_t \mid x_t = i)$.
        """
        ...
@@ -0,0 +1,337 @@
1
+ """Implements the square-root, parallel-in-time Kalman filter for linear Gaussian SSMs.
2
+
3
+ See [Yaghoobi et al. (2025)](https://doi.org/10.1137/23M156121X).
4
+ """
5
+
6
+ from functools import partial
7
+ from typing import NamedTuple
8
+
9
+ from jax import numpy as jnp
10
+ from jax import tree
11
+
12
+ from cuthbert.gaussian.types import (
13
+ GetDynamicsParams,
14
+ GetInitParams,
15
+ GetObservationParams,
16
+ )
17
+ from cuthbert.inference import Filter, Smoother
18
+ from cuthbert.utils import dummy_tree_like
19
+ from cuthbertlib.kalman import filtering, smoothing
20
+ from cuthbertlib.types import Array, ArrayTree, ArrayTreeLike, KeyArray
21
+
22
+
23
class KalmanFilterState(NamedTuple):
    """State carried through the square-root Kalman filter.

    Attributes:
        elem: Filter scan element. Its `b`, `U` and `ell` fields hold the
            filtering mean, the generalised Cholesky factor of the covariance,
            and the cumulative log normalizing constant.
        model_inputs: Model inputs this state was prepared with.
    """

    elem: filtering.FilterScanElement
    model_inputs: ArrayTree

    @property
    def mean(self) -> Array:
        """Filtering mean (`elem.b`)."""
        element = self.elem
        return element.b

    @property
    def chol_cov(self) -> Array:
        """Filtering generalised Cholesky factor of the covariance (`elem.U`)."""
        element = self.elem
        return element.U

    @property
    def log_normalizing_constant(self) -> Array:
        """Cumulative log normalizing constant (`elem.ell`)."""
        element = self.elem
        return element.ell
43
+
44
+
45
class KalmanSmootherState(NamedTuple):
    """State carried through the square-root Kalman smoother.

    Attributes:
        elem: Smoother scan element. Its `g` and `D` fields hold the smoothing
            mean and the generalised Cholesky factor of the covariance.
        model_inputs: Model inputs associated with this state.
        gain: Smoother gain, only populated when requested at build time.
        chol_cov_given_next: Cholesky factor of the covariance given the next
            state, only populated when requested at build time.
    """

    elem: smoothing.SmootherScanElement
    model_inputs: ArrayTree
    gain: Array | None = None
    chol_cov_given_next: Array | None = None

    @property
    def mean(self) -> Array:
        """Smoothing mean (`elem.g`)."""
        element = self.elem
        return element.g

    @property
    def chol_cov(self) -> Array:
        """Smoothing generalised Cholesky factor of the covariance (`elem.D`)."""
        element = self.elem
        return element.D
62
+
63
+
64
def build_filter(
    get_init_params: GetInitParams,
    get_dynamics_params: GetDynamicsParams,
    get_observation_params: GetObservationParams,
) -> Filter:
    """Construct an exact square-root Kalman filter for linear Gaussian SSMs.

    Args:
        get_init_params: Maps model inputs to (m0, chol_P0), defining
            p(x_0) = N(m0, chol_P0 @ chol_P0^T).
        get_dynamics_params: Maps model inputs to (F, c, chol_Q), defining
            p(x_t | x_{t-1}) = N(F @ x_{t-1} + c, chol_Q @ chol_Q^T).
        get_observation_params: Maps model inputs to (H, d, chol_R, y),
            defining p(y_t | x_t) = N(H @ x_t + d, chol_R @ chol_R^T).

    Returns:
        `Filter` implementing the exact Kalman filter, flagged as associative
        so it is suitable for an associative scan.
    """
    prepare_initial = partial(
        init_prepare,
        get_init_params=get_init_params,
        get_observation_params=get_observation_params,
    )
    prepare_step = partial(
        filter_prepare,
        get_dynamics_params=get_dynamics_params,
        get_observation_params=get_observation_params,
    )
    return Filter(
        init_prepare=prepare_initial,
        filter_prepare=prepare_step,
        filter_combine=filter_combine,
        associative=True,
    )
98
+
99
+
100
def build_smoother(
    get_dynamics_params: GetDynamicsParams,
    store_gain: bool = False,
    store_chol_cov_given_next: bool = False,
) -> Smoother:
    """Construct an exact square-root Kalman smoother for linear Gaussian SSMs.

    Args:
        get_dynamics_params: Maps model inputs to (F, c, chol_Q), defining
            p(x_t | x_{t-1}) = N(F @ x_{t-1} + c, chol_Q @ chol_Q^T).
        store_gain: If True, smoother states keep the gain matrix.
        store_chol_cov_given_next: If True, smoother states keep the
            chol_cov_given_next matrix.

    Returns:
        `Smoother` implementing the exact Kalman smoother, flagged as
        associative so it is suitable for an associative scan.
    """
    # Both the conversion and preparation steps honour the same storage flags.
    storage_flags = {
        "store_gain": store_gain,
        "store_chol_cov_given_next": store_chol_cov_given_next,
    }
    convert = partial(convert_filter_to_smoother_state, **storage_flags)
    prepare_step = partial(
        smoother_prepare, get_dynamics_params=get_dynamics_params, **storage_flags
    )
    return Smoother(
        convert_filter_to_smoother_state=convert,
        smoother_prepare=prepare_step,
        smoother_combine=smoother_combine,
        associative=True,
    )
133
+
134
+
135
def init_prepare(
    model_inputs: ArrayTreeLike,
    get_init_params: GetInitParams,
    get_observation_params: GetObservationParams,
    key: KeyArray | None = None,
) -> KalmanFilterState:
    """Build the initial Kalman filter state by conditioning the prior on the first observation.

    Args:
        model_inputs: Model inputs for time 0. Leaves are converted to JAX arrays.
        get_init_params: Returns (m0, chol_P0) from model inputs.
        get_observation_params: Returns observation parameters (H, d, chol_R, y).
        key: JAX random key - not used.

    Returns:
        Initial `KalmanFilterState`. Its mean and chol_cov (generalised
        Cholesky factor of the covariance) are the posterior after the update,
        and its log normalizing constant is the update's log likelihood.
    """
    inputs = tree.map(jnp.asarray, model_inputs)
    prior_mean, prior_chol_cov = get_init_params(inputs)
    H, d, chol_R, y = get_observation_params(inputs)

    (post_mean, post_chol_cov), log_lik = filtering.update(
        prior_mean, prior_chol_cov, H, d, chol_R, y
    )
    # A, eta and Z are zero so this element does not depend on any earlier
    # element (presumably the standard first element of the parallel scan).
    zero_matrix = jnp.zeros_like(post_chol_cov)
    element = filtering.FilterScanElement(
        A=zero_matrix,
        b=post_mean,
        U=post_chol_cov,
        eta=jnp.zeros_like(post_mean),
        Z=zero_matrix,
        ell=log_lik,
    )
    return KalmanFilterState(elem=element, model_inputs=inputs)
167
+
168
+
169
def filter_prepare(
    model_inputs: ArrayTreeLike,
    get_dynamics_params: GetDynamicsParams,
    get_observation_params: GetObservationParams,
    key: KeyArray | None = None,
) -> KalmanFilterState:
    """Build the scan element for one exact Kalman filter step.

    Args:
        model_inputs: Model inputs for this step. Leaves are converted to JAX arrays.
        get_dynamics_params: Returns dynamics parameters (F, c, chol_Q).
        get_observation_params: Returns observation parameters (H, d, chol_R, y).
        key: JAX random key - not used.

    Returns:
        `KalmanFilterState` holding the prepared scan element.
    """
    inputs = tree.map(jnp.asarray, model_inputs)
    dynamics = get_dynamics_params(inputs)
    observation = get_observation_params(inputs)
    element = filtering.associative_params_single(*dynamics, *observation)
    return KalmanFilterState(elem=element, model_inputs=inputs)
191
+
192
+
193
def filter_combine(
    state_1: KalmanFilterState,
    state_2: KalmanFilterState,
) -> KalmanFilterState:
    """Merge an earlier filter state with one prepared from the latest model inputs.

    Applies the exact Kalman predict and update in covariance square-root
    form through the associative filtering operator. Works in an associative
    scan as well as a sequential one.

    Args:
        state_1: State from the previous time step(s).
        state_2: State prepared with the latest model inputs.

    Returns:
        Combined `KalmanFilterState` carrying `state_2.model_inputs`. It
        exposes mean, chol_cov (generalised Cholesky factor of the covariance)
        and log_normalizing_constant.
    """
    earlier, latest = state_1, state_2
    merged = filtering.filtering_operator(earlier.elem, latest.elem)
    return KalmanFilterState(elem=merged, model_inputs=latest.model_inputs)
216
+
217
+
218
def smoother_prepare(
    filter_state: KalmanFilterState,
    get_dynamics_params: GetDynamicsParams,
    model_inputs: ArrayTreeLike,
    store_gain: bool = False,
    store_chol_cov_given_next: bool = False,
    key: KeyArray | None = None,
) -> KalmanSmootherState:
    """Build the smoother scan element for time t from the filter output at t.

    The model_inputs here describe the transition from t to t+1. They differ
    from filter_state.model_inputs, which describe the transition from t-1 to t.

    Args:
        filter_state: Kalman filter state at time t.
        get_dynamics_params: Returns dynamics parameters (F, c, chol_Q) from
            model inputs.
        model_inputs: Model inputs for the transition from t to t+1. Leaves
            are converted to JAX arrays.
        store_gain: If True, keep the element's `E` field as `gain`.
        store_chol_cov_given_next: If True, keep the element's `D` field as
            `chol_cov_given_next`.
        key: JAX random key - not used.

    Returns:
        Prepared `KalmanSmootherState`.
    """
    inputs = tree.map(jnp.asarray, model_inputs)
    F, c, chol_Q = get_dynamics_params(inputs)
    element = smoothing.associative_params_single(
        filter_state.mean, filter_state.chol_cov, F, c, chol_Q
    )
    gain = element.E if store_gain else None
    chol_cov_given_next = element.D if store_chol_cov_given_next else None
    return KalmanSmootherState(
        elem=element,
        model_inputs=inputs,
        gain=gain,
        chol_cov_given_next=chol_cov_given_next,
    )
258
+
259
+
260
def smoother_combine(
    state_1: KalmanSmootherState,
    state_2: KalmanSmootherState,
) -> KalmanSmootherState:
    """Merge the element prepared at time t with the smoother state from t + 1.

    Smoothing runs backwards in time. Applies the exact Kalman smoother update
    in covariance square-root form, and works in an associative scan as well
    as a sequential one.

    Args:
        state_1: State prepared with model inputs at time t.
        state_2: Smoother state at time t + 1.

    Returns:
        Combined `KalmanSmootherState` exposing mean and chol_cov (generalised
        Cholesky factor of the covariance). Gain, chol_cov_given_next and
        model_inputs are carried over from `state_1`.
    """
    current, later = state_1, state_2
    # Argument order matters: the element from time t + 1 is passed first.
    merged = smoothing.smoothing_operator(later.elem, current.elem)
    return KalmanSmootherState(
        elem=merged,
        model_inputs=current.model_inputs,
        gain=current.gain,
        chol_cov_given_next=current.chol_cov_given_next,
    )
290
+
291
+
292
def convert_filter_to_smoother_state(
    filter_state: KalmanFilterState,
    model_inputs: ArrayTreeLike | None = None,
    store_gain: bool = False,
    store_chol_cov_given_next: bool = False,
    key: KeyArray | None = None,
) -> KalmanSmootherState:
    """Convert the filter state to a smoother state.

    Useful for the final filter state, which is equivalent to the final
    smoother state.

    Args:
        filter_state: Kalman filter state. Its `model_inputs`, `mean` and
            `chol_cov` attributes are read, so it must be a
            `KalmanFilterState` rather than an arbitrary array tree.
        model_inputs: Only used to create an empty model_inputs tree (the
            values are ignored). This keeps the final smoother state's
            structure the same as the rest. Defaults to
            `filter_state.model_inputs`, so it is only needed if the smoother
            model_inputs have a different tree structure.
        store_gain: Whether to store a (dummy) gain matrix in the smoother state.
        store_chol_cov_given_next: Whether to store a (dummy)
            chol_cov_given_next matrix in the smoother state.
        key: JAX random key - not used.

    Returns:
        Smoother state holding the same mean and chol_cov as the filter state,
        with a zero `E` field. The model_inputs, and the optional
        gain / chol_cov_given_next, are set to dummy values.
    """
    if model_inputs is None:
        model_inputs = filter_state.model_inputs

    dummy_model_inputs = dummy_tree_like(model_inputs)

    elem = smoothing.SmootherScanElement(
        g=filter_state.mean,
        D=filter_state.chol_cov,
        E=jnp.zeros_like(filter_state.chol_cov),
    )
    # Optional fields get placeholders shaped like chol_cov, so the final
    # state has the same structure as the states produced by smoother_prepare.
    return KalmanSmootherState(
        elem=elem,
        gain=dummy_tree_like(filter_state.chol_cov) if store_gain else None,
        chol_cov_given_next=dummy_tree_like(filter_state.chol_cov)
        if store_chol_cov_given_next
        else None,
        model_inputs=dummy_model_inputs,
    )
@@ -0,0 +1,11 @@
1
+ from cuthbert.gaussian.moments import (
2
+ associative_filter,
3
+ non_associative_filter,
4
+ smoother,
5
+ )
6
+ from cuthbert.gaussian.moments.filter import build_filter
7
+ from cuthbert.gaussian.moments.smoother import build_smoother
8
+ from cuthbert.gaussian.moments.types import (
9
+ GetDynamicsMoments,
10
+ GetObservationMoments,
11
+ )