cuthbert 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. cuthbert/discrete/__init__.py +2 -0
  2. cuthbert/discrete/filter.py +140 -0
  3. cuthbert/discrete/smoother.py +123 -0
  4. cuthbert/discrete/types.py +53 -0
  5. cuthbert/gaussian/__init__.py +0 -0
  6. cuthbert/gaussian/kalman.py +337 -0
  7. cuthbert/gaussian/moments/__init__.py +11 -0
  8. cuthbert/gaussian/moments/associative_filter.py +180 -0
  9. cuthbert/gaussian/moments/filter.py +95 -0
  10. cuthbert/gaussian/moments/non_associative_filter.py +161 -0
  11. cuthbert/gaussian/moments/smoother.py +118 -0
  12. cuthbert/gaussian/moments/types.py +51 -0
  13. cuthbert/gaussian/taylor/__init__.py +14 -0
  14. cuthbert/gaussian/taylor/associative_filter.py +222 -0
  15. cuthbert/gaussian/taylor/filter.py +129 -0
  16. cuthbert/gaussian/taylor/non_associative_filter.py +246 -0
  17. cuthbert/gaussian/taylor/smoother.py +158 -0
  18. cuthbert/gaussian/taylor/types.py +86 -0
  19. cuthbert/gaussian/types.py +57 -0
  20. cuthbert/gaussian/utils.py +41 -0
  21. cuthbert/smc/__init__.py +0 -0
  22. cuthbert/smc/backward_sampler.py +193 -0
  23. cuthbert/smc/marginal_particle_filter.py +237 -0
  24. cuthbert/smc/particle_filter.py +234 -0
  25. cuthbert/smc/types.py +67 -0
  26. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/METADATA +1 -1
  27. cuthbert-0.0.3.dist-info/RECORD +76 -0
  28. cuthbertlib/discrete/__init__.py +0 -0
  29. cuthbertlib/discrete/filtering.py +49 -0
  30. cuthbertlib/discrete/smoothing.py +35 -0
  31. cuthbertlib/kalman/__init__.py +4 -0
  32. cuthbertlib/kalman/filtering.py +213 -0
  33. cuthbertlib/kalman/generate.py +85 -0
  34. cuthbertlib/kalman/sampling.py +68 -0
  35. cuthbertlib/kalman/smoothing.py +121 -0
  36. cuthbertlib/linalg/__init__.py +7 -0
  37. cuthbertlib/linalg/collect_nans_chol.py +90 -0
  38. cuthbertlib/linalg/marginal_sqrt_cov.py +34 -0
  39. cuthbertlib/linalg/symmetric_inv_sqrt.py +126 -0
  40. cuthbertlib/linalg/tria.py +21 -0
  41. cuthbertlib/linearize/__init__.py +7 -0
  42. cuthbertlib/linearize/log_density.py +175 -0
  43. cuthbertlib/linearize/moments.py +94 -0
  44. cuthbertlib/linearize/taylor.py +83 -0
  45. cuthbertlib/quadrature/__init__.py +4 -0
  46. cuthbertlib/quadrature/common.py +102 -0
  47. cuthbertlib/quadrature/cubature.py +73 -0
  48. cuthbertlib/quadrature/gauss_hermite.py +62 -0
  49. cuthbertlib/quadrature/linearize.py +143 -0
  50. cuthbertlib/quadrature/unscented.py +79 -0
  51. cuthbertlib/quadrature/utils.py +109 -0
  52. cuthbertlib/resampling/__init__.py +3 -0
  53. cuthbertlib/resampling/killing.py +79 -0
  54. cuthbertlib/resampling/multinomial.py +53 -0
  55. cuthbertlib/resampling/protocols.py +92 -0
  56. cuthbertlib/resampling/systematic.py +78 -0
  57. cuthbertlib/resampling/utils.py +82 -0
  58. cuthbertlib/smc/__init__.py +0 -0
  59. cuthbertlib/smc/ess.py +24 -0
  60. cuthbertlib/smc/smoothing/__init__.py +0 -0
  61. cuthbertlib/smc/smoothing/exact_sampling.py +111 -0
  62. cuthbertlib/smc/smoothing/mcmc.py +76 -0
  63. cuthbertlib/smc/smoothing/protocols.py +44 -0
  64. cuthbertlib/smc/smoothing/tracing.py +45 -0
  65. cuthbertlib/stats/__init__.py +0 -0
  66. cuthbertlib/stats/multivariate_normal.py +102 -0
  67. cuthbert-0.0.2.dist-info/RECORD +0 -12
  68. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/WHEEL +0 -0
  69. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/licenses/LICENSE +0 -0
  70. {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,158 @@
1
+ r"""Linearized Taylor Kalman smoother.
2
+
3
+ Uses automatic differentiation to extract conditionally Gaussian parameters from log
4
+ densities of the dynamics and observation distributions.
5
+
6
+ This differs from `gaussian/moments`, which requires `mean` and `chol_cov`
7
+ functions as input rather than log densities.
8
+
9
+ I.e., we approximate conditional densities as
10
+
11
+ $$
12
+ p(y \mid x) \approx N(y \mid H x + d, L L^T),
13
+ $$
14
+
15
+ and potentials as
16
+
17
+ $$
18
+ G(x) \approx N(x \mid m, L L^T),
19
+ $$
20
+
21
+ where $L$ is the Cholesky factor of the covariance matrix.
22
+
23
+ See `cuthbertlib.linearize` for more details.
24
+ """
25
+
26
+ from functools import partial
27
+
28
+ from jax import numpy as jnp
29
+ from jax import tree
30
+
31
+ from cuthbert.gaussian.kalman import (
32
+ KalmanSmootherState,
33
+ convert_filter_to_smoother_state,
34
+ smoother_combine,
35
+ )
36
+ from cuthbert.gaussian.taylor.types import (
37
+ GetDynamicsLogDensity,
38
+ )
39
+ from cuthbert.gaussian.types import (
40
+ LinearizedKalmanFilterState,
41
+ )
42
+ from cuthbert.inference import Smoother
43
+ from cuthbertlib.kalman import smoothing
44
+ from cuthbertlib.linearize import linearize_log_density
45
+ from cuthbertlib.types import (
46
+ ArrayTreeLike,
47
+ KeyArray,
48
+ )
49
+
50
+
51
def build_smoother(
    get_dynamics_log_density: GetDynamicsLogDensity,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
    store_gain: bool = False,
    store_chol_cov_given_next: bool = False,
) -> Smoother:
    """Build a linearized Taylor Kalman smoother.

    Each time step is prepared by linearizing the dynamics log density (see
    `smoother_prepare`). Combining steps and converting the final filter state
    reuse the generic Kalman implementations from `cuthbert.gaussian.kalman`.

    Args:
        get_dynamics_log_density: Callable returning the dynamics log density
            log p(x_t+1 | x_t) together with the linearization points for the
            previous and current time points.
        rtol: Relative tolerance for the singular values of precision matrices,
            forwarded to `symmetric_inv_sqrt` during linearization. Singular values
            smaller than `rtol * largest_singular_value` are treated as zero.
            The default is determined based on the floating point precision of the
            dtype. See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: If True, dimensions with NaN on the diagonal of the
            precision matrices found via linearization are treated as missing, and
            all rows and columns associated with them are ignored.
        store_gain: If True, the gain matrix is kept in the smoother state.
        store_chol_cov_given_next: If True, the `chol_cov_given_next` matrix is
            kept in the smoother state.

    Returns:
        Linearized Taylor Kalman smoother object, suitable for associative scan.
    """
    # Both the prepare step and the final-state conversion need the same flags.
    storage_options = dict(
        store_gain=store_gain,
        store_chol_cov_given_next=store_chol_cov_given_next,
    )
    prepare = partial(
        smoother_prepare,
        get_dynamics_log_density=get_dynamics_log_density,
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
        **storage_options,
    )
    convert = partial(convert_filter_to_smoother_state, **storage_options)
    return Smoother(
        smoother_prepare=prepare,
        smoother_combine=smoother_combine,
        convert_filter_to_smoother_state=convert,
        associative=True,
    )
96
+
97
+
98
def smoother_prepare(
    filter_state: LinearizedKalmanFilterState,
    get_dynamics_log_density: GetDynamicsLogDensity,
    model_inputs: ArrayTreeLike,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
    store_gain: bool = False,
    store_chol_cov_given_next: bool = False,
    key: KeyArray | None = None,
) -> KalmanSmootherState:
    """Prepare a linearized Taylor Kalman smoother element for one time step.

    The procedure has three steps:
    1. Obtain the dynamics log density from `get_dynamics_log_density`.
    2. Linearize it into `(F, c, chol_Q)` with `linearize_log_density`.
    3. Combine the result with the filtering mean and Cholesky covariance of
       `filter_state` via `smoothing.associative_params_single`.

    Args:
        filter_state: State generated by the linearized Taylor Kalman filter at
            time t, i.e. the earlier time point of the transition described by
            `model_inputs`.
        get_dynamics_log_density: Callable returning the dynamics log density
            log p(x_t+1 | x_t) and the linearization points for the previous and
            current time points.
        model_inputs: Model inputs for the transition from t to t+1. Every leaf
            is converted to a JAX array.
        rtol: Relative tolerance for the singular values of precision matrices,
            forwarded to `symmetric_inv_sqrt` during linearization. Singular values
            smaller than `rtol * largest_singular_value` are treated as zero.
            The default is determined based on the floating point precision of the
            dtype. See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: If True, dimensions with NaN on the diagonal of the
            precision matrices found via linearization are treated as missing, and
            all rows and columns associated with them are ignored.
        store_gain: If True, the element's `E` matrix is also stored as `gain`.
        store_chol_cov_given_next: If True, the element's `D` matrix is also
            stored as `chol_cov_given_next`.
        key: JAX random key - not used.

    Returns:
        Prepared state for the Kalman smoother.
    """
    inputs = tree.map(jnp.asarray, model_inputs)

    dynamics_log_density, point_prev, point_curr = get_dynamics_log_density(
        filter_state, inputs
    )
    F, c, chol_Q = linearize_log_density(
        dynamics_log_density,
        point_prev,
        point_curr,
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
    )

    elem = smoothing.associative_params_single(
        filter_state.mean, filter_state.chol_cov, F, c, chol_Q
    )
    gain = elem.E if store_gain else None
    chol_cov_given_next = elem.D if store_chol_cov_given_next else None
    return KalmanSmootherState(
        elem=elem,
        gain=gain,
        chol_cov_given_next=chol_cov_given_next,
        model_inputs=inputs,
    )
@@ -0,0 +1,86 @@
1
+ """Provides types for the Taylor-series linearization of Gaussian state-space models."""
2
+
3
+ from typing import Protocol, TypeAlias
4
+
5
+ from cuthbert.gaussian.types import (
6
+ LinearizedKalmanFilterState,
7
+ )
8
+ from cuthbertlib.types import (
9
+ Array,
10
+ ArrayTreeLike,
11
+ LogConditionalDensity,
12
+ LogDensity,
13
+ )
14
+
15
+ LogPotential: TypeAlias = LogDensity  # log G(x): same callable type as LogDensity
16
+
17
+
18
class GetInitLogDensity(Protocol):
    """Protocol for callables that supply the initial distribution specification."""

    def __call__(
        self,
        model_inputs: ArrayTreeLike,
    ) -> tuple[LogDensity, Array]:
        """Return the initial log density and the point to linearize it around.

        Args:
            model_inputs: Model inputs.

        Returns:
            Pair `(initial_log_density, initial_linearization_point)`.
        """
        ...
31
+
32
+
33
class GetDynamicsLogDensity(Protocol):
    """Protocol for callables that supply the dynamics specification."""

    def __call__(
        self,
        state: LinearizedKalmanFilterState,
        model_inputs: ArrayTreeLike,
    ) -> tuple[LogConditionalDensity, Array, Array]:
        """Return the dynamics log density and its two linearization points.

        Linearization points are required for both the previous and the current
        time points. `associative_scan` is only supported when `state` is ignored.

        Args:
            state: NamedTuple exposing `mean` and `mean_prev` attributes.
            model_inputs: Model inputs.

        Returns:
            Triple `(log_density, linearization_point_prev, linearization_point_curr)`.
        """
        ...
55
+
56
+
57
class GetObservationFunc(Protocol):
    """Protocol for callables that supply the observation specification."""

    def __call__(
        self,
        state: LinearizedKalmanFilterState,
        model_inputs: ArrayTreeLike,
    ) -> tuple[LogConditionalDensity, Array, Array] | tuple[LogPotential, Array]:
        """Return the observation function, a linearization point and an optional observation.

        `state` is the predicted state obtained after the Kalman dynamics
        propagation. `associative_scan` is only supported when `state` is ignored.

        Two output forms are accepted:

        - `(log_p_y_given_x, x, y)`: an observation log density log p(y | x)
          together with the points x and y to linearize around.
        - `(log_G, x)`: a log potential log G(x) and a linearization point x.

        Args:
            state: NamedTuple exposing `mean` and `mean_prev` attributes; the
                predicted state after the Kalman dynamics propagation.
            model_inputs: Model inputs.

        Returns:
            Either the three-element observation form or the two-element
            log-potential form described above.
        """
        ...
@@ -0,0 +1,57 @@
1
+ """Provides shared types for Gaussian representations in state-space models."""
2
+
3
+ from typing import NamedTuple, Protocol
4
+
5
+ from cuthbertlib.kalman import filtering
6
+ from cuthbertlib.types import Array, ArrayTree, ArrayTreeLike
7
+
8
+
9
+ ### Kalman types
10
class GetInitParams(Protocol):
    """Protocol describing the initial distribution of a linear Gaussian SSM."""

    def __call__(
        self,
        model_inputs: ArrayTreeLike,
    ) -> tuple[Array, Array]:
        """Map model inputs to the initial parameters `(m0, chol_P0)`."""
        ...
16
+
17
+
18
class GetDynamicsParams(Protocol):
    """Protocol describing the dynamics model of a linear Gaussian SSM."""

    def __call__(
        self,
        model_inputs: ArrayTreeLike,
    ) -> tuple[Array, Array, Array]:
        """Map model inputs to the dynamics parameters `(F, c, chol_Q)`."""
        ...
24
+
25
+
26
class GetObservationParams(Protocol):
    """Protocol describing the observation model of a linear Gaussian SSM."""

    def __call__(
        self,
        model_inputs: ArrayTreeLike,
    ) -> tuple[Array, Array, Array, Array]:
        """Map model inputs to the observation parameters `(H, d, chol_R, y)`."""
        ...
34
+
35
+
36
+ ### Shared state type for linearized Kalman filters
37
class LinearizedKalmanFilterState(NamedTuple):
    """Linearized Kalman filter state.

    Wraps a `filtering.FilterScanElement` and exposes its filtering quantities
    through read-only properties.
    """

    # Scan element: `b`, `U` and `ell` hold the filtering mean, the generalised
    # Cholesky covariance and the cumulative log normalizing constant.
    elem: filtering.FilterScanElement
    model_inputs: ArrayTree
    mean_prev: Array

    @property
    def log_normalizing_constant(self) -> Array:
        """Cumulative log normalizing constant (`elem.ell`)."""
        return self.elem.ell

    @property
    def chol_cov(self) -> Array:
        """Generalised Cholesky factor of the filtering covariance (`elem.U`)."""
        return self.elem.U

    @property
    def mean(self) -> Array:
        """Filtering mean (`elem.b`)."""
        return self.elem.b
@@ -0,0 +1,41 @@
1
+ """Utility functions (dummy state generation) for the Gaussian inference."""
2
+
3
+ from cuthbert.gaussian.types import LinearizedKalmanFilterState
4
+ from cuthbert.utils import dummy_tree_like
5
+ from cuthbertlib.kalman import filtering
6
+ from cuthbertlib.types import Array, ArrayTree
7
+
8
+
9
def linearized_kalman_filter_state_dummy_elem(
    mean: Array,
    chol_cov: Array,
    log_normalizing_constant: Array,
    model_inputs: ArrayTree,
    mean_prev: Array,
) -> LinearizedKalmanFilterState:
    """Create a LinearizedKalmanFilterState with a dummy element.

    Intended for the case where the associative scan is not used. Only three
    fields of the element carry real values:

    - `b`: the mean.
    - `U`: the Cholesky covariance.
    - `ell`: the log normalizing constant.

    The scan-only fields `A`, `eta` and `Z` are placeholders built with
    `dummy_tree_like` from `chol_cov`, `mean` and `chol_cov` respectively.

    Args:
        mean: Mean of the state.
        chol_cov: Cholesky covariance of the state.
        log_normalizing_constant: Log normalizing constant of the state.
        model_inputs: Model inputs.
        mean_prev: Mean of the previous state.

    Returns:
        LinearizedKalmanFilterState whose `elem` holds the given mean, Cholesky
        covariance and log normalizing constant plus dummy scan-only fields.
    """
    return LinearizedKalmanFilterState(
        elem=filtering.FilterScanElement(
            A=dummy_tree_like(chol_cov),
            b=mean,
            U=chol_cov,
            eta=dummy_tree_like(mean),
            Z=dummy_tree_like(chol_cov),
            ell=log_normalizing_constant,
        ),
        model_inputs=model_inputs,
        mean_prev=mean_prev,
    )
File without changes
@@ -0,0 +1,193 @@
1
+ """Implements backward sampling for particle filters.
2
+
3
+ Supports 3 different algorithms for backward sampling:
4
+
5
+ - [`cuthbertlib.smc.smoothing.tracing.simulate`][].
6
+ - [`cuthbertlib.smc.smoothing.exact_sampling.simulate`][].
7
+ - [`cuthbertlib.smc.smoothing.mcmc.simulate`][].
8
+ """
9
+
10
+ from functools import partial
11
+ from typing import NamedTuple, cast
12
+
13
+ import jax
14
+ import jax.numpy as jnp
15
+ from jax import Array, random
16
+
17
+ from cuthbert.inference import Smoother
18
+ from cuthbert.smc.particle_filter import LogPotential, ParticleFilterState
19
+ from cuthbert.utils import dummy_tree_like
20
+ from cuthbertlib.resampling import Resampling
21
+ from cuthbertlib.smc.smoothing.protocols import BackwardSampling
22
+ from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray
23
+
24
+
25
class ParticleSmootherState(NamedTuple):
    """State carried by the particle smoother between backward steps."""

    # PRNG key attached to this state.
    key: KeyArray
    # Particle values (a pytree).
    particles: ArrayTree
    # Ancestor indices associated with each particle.
    ancestor_indices: Array
    # Model inputs attached to this state.
    model_inputs: ArrayTree
    # Per-particle log weights.
    log_weights: Array

    @property
    def n_particles(self) -> int:
        """Number of particles, read from the last axis of `ancestor_indices`."""
        index_shape = self.ancestor_indices.shape
        return index_shape[-1]
38
+
39
+
40
def build_smoother(
    log_potential: LogPotential,
    backward_sampling_fn: BackwardSampling,
    resampling_fn: Resampling,
    n_smoother_particles: int,
) -> Smoother:
    r"""Build a particle smoother object.

    Args:
        log_potential: Function computing the JOINT log potential
            $\log G_t(x_{t-1}, x_t) + \log M_t(x_t \mid x_{t-1})$.
        backward_sampling_fn: Backward sampling algorithm (e.g. genealogy tracing,
            exact backward sampling). It determines how
            $x_{t-1} \sim p(x_{t-1} \mid x_t, y_{0:t-1})$ is sampled given samples
            $x_{t} \sim p(x_t \mid y_{0:T})$. See `cuthbertlib/smc/smoothing/` for
            the available choices.
        resampling_fn: Resampling algorithm (e.g. multinomial, systematic), used
            when converting the final filter state.
        n_smoother_particles: Number of samples drawn by the backward sampler.

    Returns:
        Particle smoother object (non-associative).
    """
    convert = partial(
        convert_filter_to_smoother_state,
        resampling=resampling_fn,
        n_smoother_particles=n_smoother_particles,
    )
    combine = partial(
        smoother_combine,
        backward_sampling_fn=backward_sampling_fn,
        log_potential=log_potential,
    )
    return Smoother(
        convert_filter_to_smoother_state=convert,
        smoother_prepare=smoother_prepare,
        smoother_combine=combine,
        associative=False,
    )
73
+
74
+
75
def convert_filter_to_smoother_state(
    filter_state: ParticleFilterState,
    resampling: Resampling,
    n_smoother_particles: int,
    model_inputs: ArrayTreeLike | None = None,
    key: KeyArray | None = None,
) -> ParticleSmootherState:
    """Turn a particle filter state into the starting particle smoother state.

    The filter particles are resampled down to `n_smoother_particles` according
    to the filter log weights. The resulting state carries uniform log weights.

    Args:
        filter_state: Particle filter state.
        resampling: Resampling algorithm (e.g. multinomial, systematic).
        n_smoother_particles: Number of smoother samples to draw.
        model_inputs: Only used as a template for an empty model_inputs tree; the
            values are ignored. This keeps the final smoother state structurally
            consistent with the other smoother states. Defaults to
            `filter_state.model_inputs`, so it is only needed when the smoother
            model_inputs have a different tree structure.
        key: JAX random key.

    Returns:
        Particle smoother state. Note that the model_inputs are set to dummy values.

    Raises:
        ValueError: If key is None.
    """
    if key is None:
        raise ValueError("A JAX PRNG key must be provided.")

    template_inputs = filter_state.model_inputs if model_inputs is None else model_inputs

    # One half of the split drives resampling; the other is carried on the state.
    carry_key, resampling_key = random.split(key)
    chosen = resampling(resampling_key, filter_state.log_weights, n_smoother_particles)

    resampled_particles = jax.tree.map(lambda leaf: leaf[chosen], filter_state.particles)
    uniform_log_weights = -jnp.log(n_smoother_particles) * jnp.ones(n_smoother_particles)

    return ParticleSmootherState(
        key=cast(KeyArray, carry_key),
        particles=resampled_particles,
        ancestor_indices=filter_state.ancestor_indices[chosen],
        model_inputs=dummy_tree_like(template_inputs),
        log_weights=uniform_log_weights,
    )
120
+
121
+
122
def smoother_prepare(
    filter_state: ParticleFilterState,
    model_inputs: ArrayTreeLike,
    key: KeyArray | None = None,
) -> ParticleSmootherState:
    """Prepare a state for a particle smoother step.

    Note that the model_inputs here are different to filter_state.model_inputs.
    The model_inputs required here are for the transition from t to t+1, whereas
    filter_state.model_inputs represents the transition from t-1 to t.

    Args:
        filter_state: Particle filter state from time t.
        model_inputs: Model inputs for the transition from t to t+1.
        key: JAX random key.

    Returns:
        Prepared state for the particle smoother. Particles, ancestor indices and
        log weights are taken unchanged from `filter_state`. The given `key` and
        the array-converted `model_inputs` are attached.

    Raises:
        ValueError: If key is None.
    """
    if key is None:
        raise ValueError("A JAX PRNG key must be provided.")

    # Convert every leaf of the model inputs to a JAX array.
    model_inputs = jax.tree.map(jnp.asarray, model_inputs)

    # Keyword construction keeps this correct even if the field order changes.
    return ParticleSmootherState(
        key=key,
        particles=filter_state.particles,
        ancestor_indices=filter_state.ancestor_indices,
        model_inputs=model_inputs,
        log_weights=filter_state.log_weights,
    )
153
+
154
+
155
def smoother_combine(
    state_1: ParticleSmootherState,
    state_2: ParticleSmootherState,
    backward_sampling_fn: BackwardSampling,
    log_potential: LogPotential,
) -> ParticleSmootherState:
    """Combine the next smoother state with a state prepared at the current time.

    Smoothing iterates backwards in time. Particles at time t are drawn with
    `backward_sampling_fn`, conditioned on the smoother particles at t + 1.

    Args:
        state_1: State prepared with model inputs at time t.
        state_2: Smoother state at time t + 1.
        backward_sampling_fn: Function performing backward sampling from the
            joint distribution.
        log_potential: Function computing the log potential. It receives the
            model inputs of `state_2`.

    Returns:
        Combined particle smoother state with uniform log weights. It contains:

        - the sampled particles;
        - the original ancestor indices of those particles;
        - the model inputs of `state_1`.
    """
    next_inputs = state_2.model_inputs

    def joint_log_density(s1, s2):
        return log_potential(s1, s2, next_inputs)

    sampled_particles, sampled_indices = backward_sampling_fn(
        state_1.key,
        x0_all=state_1.particles,
        x1_all=state_2.particles,
        log_weight_x0_all=state_1.log_weights,
        log_density=joint_log_density,
        x1_ancestor_indices=state_2.ancestor_indices,
    )

    n_samples = len(sampled_indices)
    # NOTE(review): state_1.key has already been consumed by the backward sampler
    # above and is stored on the output unchanged; consumers should not reuse it
    # for further sampling without splitting first.
    return ParticleSmootherState(
        key=state_1.key,
        particles=sampled_particles,
        ancestor_indices=state_1.ancestor_indices[sampled_indices],
        model_inputs=state_1.model_inputs,
        log_weights=-jnp.log(n_samples) * jnp.ones(n_samples),
    )