cuthbert 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuthbert/discrete/__init__.py +2 -0
- cuthbert/discrete/filter.py +140 -0
- cuthbert/discrete/smoother.py +123 -0
- cuthbert/discrete/types.py +53 -0
- cuthbert/gaussian/__init__.py +0 -0
- cuthbert/gaussian/kalman.py +337 -0
- cuthbert/gaussian/moments/__init__.py +11 -0
- cuthbert/gaussian/moments/associative_filter.py +180 -0
- cuthbert/gaussian/moments/filter.py +95 -0
- cuthbert/gaussian/moments/non_associative_filter.py +161 -0
- cuthbert/gaussian/moments/smoother.py +118 -0
- cuthbert/gaussian/moments/types.py +51 -0
- cuthbert/gaussian/taylor/__init__.py +14 -0
- cuthbert/gaussian/taylor/associative_filter.py +222 -0
- cuthbert/gaussian/taylor/filter.py +129 -0
- cuthbert/gaussian/taylor/non_associative_filter.py +246 -0
- cuthbert/gaussian/taylor/smoother.py +158 -0
- cuthbert/gaussian/taylor/types.py +86 -0
- cuthbert/gaussian/types.py +57 -0
- cuthbert/gaussian/utils.py +41 -0
- cuthbert/smc/__init__.py +0 -0
- cuthbert/smc/backward_sampler.py +193 -0
- cuthbert/smc/marginal_particle_filter.py +237 -0
- cuthbert/smc/particle_filter.py +234 -0
- cuthbert/smc/types.py +67 -0
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/METADATA +1 -1
- cuthbert-0.0.3.dist-info/RECORD +76 -0
- cuthbertlib/discrete/__init__.py +0 -0
- cuthbertlib/discrete/filtering.py +49 -0
- cuthbertlib/discrete/smoothing.py +35 -0
- cuthbertlib/kalman/__init__.py +4 -0
- cuthbertlib/kalman/filtering.py +213 -0
- cuthbertlib/kalman/generate.py +85 -0
- cuthbertlib/kalman/sampling.py +68 -0
- cuthbertlib/kalman/smoothing.py +121 -0
- cuthbertlib/linalg/__init__.py +7 -0
- cuthbertlib/linalg/collect_nans_chol.py +90 -0
- cuthbertlib/linalg/marginal_sqrt_cov.py +34 -0
- cuthbertlib/linalg/symmetric_inv_sqrt.py +126 -0
- cuthbertlib/linalg/tria.py +21 -0
- cuthbertlib/linearize/__init__.py +7 -0
- cuthbertlib/linearize/log_density.py +175 -0
- cuthbertlib/linearize/moments.py +94 -0
- cuthbertlib/linearize/taylor.py +83 -0
- cuthbertlib/quadrature/__init__.py +4 -0
- cuthbertlib/quadrature/common.py +102 -0
- cuthbertlib/quadrature/cubature.py +73 -0
- cuthbertlib/quadrature/gauss_hermite.py +62 -0
- cuthbertlib/quadrature/linearize.py +143 -0
- cuthbertlib/quadrature/unscented.py +79 -0
- cuthbertlib/quadrature/utils.py +109 -0
- cuthbertlib/resampling/__init__.py +3 -0
- cuthbertlib/resampling/killing.py +79 -0
- cuthbertlib/resampling/multinomial.py +53 -0
- cuthbertlib/resampling/protocols.py +92 -0
- cuthbertlib/resampling/systematic.py +78 -0
- cuthbertlib/resampling/utils.py +82 -0
- cuthbertlib/smc/__init__.py +0 -0
- cuthbertlib/smc/ess.py +24 -0
- cuthbertlib/smc/smoothing/__init__.py +0 -0
- cuthbertlib/smc/smoothing/exact_sampling.py +111 -0
- cuthbertlib/smc/smoothing/mcmc.py +76 -0
- cuthbertlib/smc/smoothing/protocols.py +44 -0
- cuthbertlib/smc/smoothing/tracing.py +45 -0
- cuthbertlib/stats/__init__.py +0 -0
- cuthbertlib/stats/multivariate_normal.py +102 -0
- cuthbert-0.0.2.dist-info/RECORD +0 -12
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/WHEEL +0 -0
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/licenses/LICENSE +0 -0
- {cuthbert-0.0.2.dist-info → cuthbert-0.0.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
r"""Linearized Taylor Kalman smoother.
|
|
2
|
+
|
|
3
|
+
Uses automatic differentiation to extract conditionally Gaussian parameters from log
|
|
4
|
+
densities of the dynamics and observation distributions.
|
|
5
|
+
|
|
6
|
+
This differs from `gaussian/moments`, which requires `mean` and `chol_cov`
|
|
7
|
+
functions as input rather than log densities.
|
|
8
|
+
|
|
9
|
+
I.e., we approximate conditional densities as
|
|
10
|
+
|
|
11
|
+
$$
|
|
12
|
+
p(y \mid x) \approx N(y \mid H x + d, L L^T),
|
|
13
|
+
$$
|
|
14
|
+
|
|
15
|
+
and potentials as
|
|
16
|
+
|
|
17
|
+
$$
|
|
18
|
+
G(x) \approx N(x \mid m, L L^T),
|
|
19
|
+
$$
|
|
20
|
+
|
|
21
|
+
where $L$ is the Cholesky factor of the covariance matrix.
|
|
22
|
+
|
|
23
|
+
See `cuthbertlib.linearize` for more details.
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from functools import partial
|
|
27
|
+
|
|
28
|
+
from jax import numpy as jnp
|
|
29
|
+
from jax import tree
|
|
30
|
+
|
|
31
|
+
from cuthbert.gaussian.kalman import (
|
|
32
|
+
KalmanSmootherState,
|
|
33
|
+
convert_filter_to_smoother_state,
|
|
34
|
+
smoother_combine,
|
|
35
|
+
)
|
|
36
|
+
from cuthbert.gaussian.taylor.types import (
|
|
37
|
+
GetDynamicsLogDensity,
|
|
38
|
+
)
|
|
39
|
+
from cuthbert.gaussian.types import (
|
|
40
|
+
LinearizedKalmanFilterState,
|
|
41
|
+
)
|
|
42
|
+
from cuthbert.inference import Smoother
|
|
43
|
+
from cuthbertlib.kalman import smoothing
|
|
44
|
+
from cuthbertlib.linearize import linearize_log_density
|
|
45
|
+
from cuthbertlib.types import (
|
|
46
|
+
ArrayTreeLike,
|
|
47
|
+
KeyArray,
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def build_smoother(
    get_dynamics_log_density: GetDynamicsLogDensity,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
    store_gain: bool = False,
    store_chol_cov_given_next: bool = False,
) -> Smoother:
    """Build linearized Taylor Kalman inference smoother.

    Args:
        get_dynamics_log_density: Function to get dynamics log density log p(x_t+1 | x_t)
            and linearization points (for the previous and current time points).
        rtol: The relative tolerance for the singular values of precision matrices
            when passed to `symmetric_inv_sqrt` during linearization.
            Cutoff for small singular values; singular values smaller than
            `rtol * largest_singular_value` are treated as zero.
            The default is determined based on the floating point precision of the dtype.
            See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: Whether to treat dimensions with NaN on the diagonal of the
            precision matrices (found via linearization) as missing and ignore all rows
            and columns associated with them.
        store_gain: Whether to store the gain matrix in the smoother state.
        store_chol_cov_given_next: Whether to store the chol_cov_given_next matrix
            in the smoother state.

    Returns:
        Linearized Taylor Kalman smoother object, suitable for associative scan.
    """
    # Bind the build-time configuration into the two per-step callables up front.
    prepare_fn = partial(
        smoother_prepare,
        get_dynamics_log_density=get_dynamics_log_density,
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
        store_gain=store_gain,
        store_chol_cov_given_next=store_chol_cov_given_next,
    )
    convert_fn = partial(
        convert_filter_to_smoother_state,
        store_gain=store_gain,
        store_chol_cov_given_next=store_chol_cov_given_next,
    )
    return Smoother(
        smoother_prepare=prepare_fn,
        smoother_combine=smoother_combine,
        convert_filter_to_smoother_state=convert_fn,
        associative=True,
    )
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def smoother_prepare(
    filter_state: LinearizedKalmanFilterState,
    get_dynamics_log_density: GetDynamicsLogDensity,
    model_inputs: ArrayTreeLike,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
    store_gain: bool = False,
    store_chol_cov_given_next: bool = False,
    key: KeyArray | None = None,
) -> KalmanSmootherState:
    """Prepare a state for a linearized Taylor Kalman smoother step.

    Args:
        filter_state: State generated by the linearized Taylor Kalman filter at the
            previous time point.
        get_dynamics_log_density: Function to get dynamics log density log p(x_t+1 | x_t)
            and linearization points (for the previous and current time points).
        model_inputs: Model inputs for the transition from t to t+1.
        rtol: The relative tolerance for the singular values of precision matrices
            when passed to `symmetric_inv_sqrt` during linearization.
            Cutoff for small singular values; singular values smaller than
            `rtol * largest_singular_value` are treated as zero.
            The default is determined based on the floating point precision of the dtype.
            See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: Whether to treat dimensions with NaN on the diagonal of the
            precision matrices (found via linearization) as missing and ignore all rows
            and columns associated with them.
        store_gain: Whether to store the gain matrix in the smoother state.
        store_chol_cov_given_next: Whether to store the chol_cov_given_next matrix
            in the smoother state.
        key: JAX random key - not used.

    Returns:
        Prepared state for the Kalman smoother.
    """
    # `tree.map` applies its function to each leaf, so `jnp.asarray` can be
    # passed directly — the `lambda x: jnp.asarray(x)` wrapper was redundant.
    model_inputs = tree.map(jnp.asarray, model_inputs)

    filter_mean = filter_state.mean
    filter_chol_cov = filter_state.chol_cov

    log_dynamics_density, linearization_point_prev, linearization_point_curr = (
        get_dynamics_log_density(filter_state, model_inputs)
    )

    # Linearize log p(x_t+1 | x_t) around the provided points to obtain
    # conditionally Gaussian dynamics parameters (F, c, chol_Q).
    F, c, chol_Q = linearize_log_density(
        log_dynamics_density,
        linearization_point_prev,
        linearization_point_curr,
        rtol=rtol,
        ignore_nan_dims=ignore_nan_dims,
    )

    state = smoothing.associative_params_single(
        filter_mean, filter_chol_cov, F, c, chol_Q
    )
    # E / D are the gain and conditional Cholesky covariance of the scan
    # element; only carried in the output when explicitly requested.
    return KalmanSmootherState(
        elem=state,
        gain=state.E if store_gain else None,
        chol_cov_given_next=state.D if store_chol_cov_given_next else None,
        model_inputs=model_inputs,
    )
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
"""Provides types for the Taylor-series linearization of Gaussian state-space models."""
|
|
2
|
+
|
|
3
|
+
from typing import Protocol, TypeAlias
|
|
4
|
+
|
|
5
|
+
from cuthbert.gaussian.types import (
|
|
6
|
+
LinearizedKalmanFilterState,
|
|
7
|
+
)
|
|
8
|
+
from cuthbertlib.types import (
|
|
9
|
+
Array,
|
|
10
|
+
ArrayTreeLike,
|
|
11
|
+
LogConditionalDensity,
|
|
12
|
+
LogDensity,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
# A log potential log G(x) has the same call signature as a log density.
LogPotential: TypeAlias = LogDensity
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class GetInitLogDensity(Protocol):
    """Protocol for extracting the initial specifications.

    Implementations return the initial log density together with the point
    around which it should be linearized.
    """

    def __call__(self, model_inputs: ArrayTreeLike) -> tuple[LogDensity, Array]:
        """Get the initial log density and initial linearization point.

        Args:
            model_inputs: Model inputs.

        Returns:
            Tuple with initial log density and initial linearization point.
        """
        ...
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class GetDynamicsLogDensity(Protocol):
    """Protocol for extracting the dynamics specifications."""

    def __call__(
        self,
        state: LinearizedKalmanFilterState,
        model_inputs: ArrayTreeLike,
    ) -> tuple[LogConditionalDensity, Array, Array]:
        """Get the dynamics log density and linearization points.

        Linearization points are required for both the previous and current
        time points.

        `associative_scan` only supported when `state` is ignored.

        Args:
            state: NamedTuple containing `mean` and `mean_prev` attributes.
            model_inputs: Model inputs.

        Returns:
            Tuple with dynamics log density and linearization points
            (previous, then current).
        """
        ...
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class GetObservationFunc(Protocol):
    """Protocol for extracting the required observation specifications."""

    def __call__(
        self,
        state: LinearizedKalmanFilterState,
        model_inputs: ArrayTreeLike,
    ) -> tuple[LogConditionalDensity, Array, Array] | tuple[LogPotential, Array]:
        """Extract observation function, linearization point and optional observation.

        State is the predicted state after applying the Kalman dynamics propagation.

        `associative_scan` only supported when `state` is ignored.

        Two types of output are supported:
        - Observation log density function log p(y | x) and points x and y
          to linearize around.
        - Log potential function log G(x) and a linearization point x.

        Args:
            state: NamedTuple containing `mean` and `mean_prev` attributes.
                Predicted state after applying the Kalman dynamics propagation.
            model_inputs: Model inputs.

        Returns:
            Either a tuple with observation function to linearize, linearization point
            and observation, or a tuple with log potential function and linearization
            point.
        """
        ...
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""Provides shared types for Gaussian representations in state-space models."""
|
|
2
|
+
|
|
3
|
+
from typing import NamedTuple, Protocol
|
|
4
|
+
|
|
5
|
+
from cuthbertlib.kalman import filtering
|
|
6
|
+
from cuthbertlib.types import Array, ArrayTree, ArrayTreeLike
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
### Kalman types
|
|
10
|
+
class GetInitParams(Protocol):
    """Protocol for defining the initial distribution of a linear Gaussian SSM."""

    def __call__(self, model_inputs: ArrayTreeLike) -> tuple[Array, Array]:
        """Get initial parameters (m0, chol_P0) from model inputs.

        Returns:
            Tuple of initial mean m0 and Cholesky factor chol_P0 of the
            initial covariance.
        """
        ...
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class GetDynamicsParams(Protocol):
    """Protocol for defining the dynamics model of a linear Gaussian SSM."""

    def __call__(self, model_inputs: ArrayTreeLike) -> tuple[Array, Array, Array]:
        """Get dynamics parameters (F, c, chol_Q) from model inputs.

        Returns:
            Tuple of transition matrix F, offset c, and Cholesky factor
            chol_Q of the dynamics noise covariance.
        """
        ...
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class GetObservationParams(Protocol):
    """Protocol for defining the observation model of a linear Gaussian SSM."""

    def __call__(
        self, model_inputs: ArrayTreeLike
    ) -> tuple[Array, Array, Array, Array]:
        """Get observation parameters (H, d, chol_R, y) from model inputs.

        Returns:
            Tuple of observation matrix H, offset d, Cholesky factor chol_R
            of the observation noise covariance, and observation y.
        """
        ...
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
### Shared state type for linearized Kalman filters
|
|
37
|
+
class LinearizedKalmanFilterState(NamedTuple):
    """Linearized Kalman filter state.

    Wraps the square-root Kalman filter scan element and exposes its mean,
    Cholesky covariance and log normalizing constant as properties.
    """

    # Square-root Kalman filter scan element (fields A, b, U, eta, Z, ell).
    elem: filtering.FilterScanElement
    # Model inputs associated with the transition that produced this state.
    model_inputs: ArrayTree
    # Filtering mean at the previous time point (used as a linearization point).
    mean_prev: Array

    @property
    def mean(self) -> Array:
        """Filtering mean."""
        return self.elem.b

    @property
    def chol_cov(self) -> Array:
        """Filtering generalised Cholesky covariance."""
        return self.elem.U

    @property
    def log_normalizing_constant(self) -> Array:
        """Log normalizing constant (cumulative)."""
        return self.elem.ell
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""Utility functions (dummy state generation) for the Gaussian inference."""
|
|
2
|
+
|
|
3
|
+
from cuthbert.gaussian.types import LinearizedKalmanFilterState
|
|
4
|
+
from cuthbert.utils import dummy_tree_like
|
|
5
|
+
from cuthbertlib.kalman import filtering
|
|
6
|
+
from cuthbertlib.types import Array, ArrayTree
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def linearized_kalman_filter_state_dummy_elem(
    mean: Array,
    chol_cov: Array,
    log_normalizing_constant: Array,
    model_inputs: ArrayTree,
    mean_prev: Array,
) -> LinearizedKalmanFilterState:
    """Create a LinearizedKalmanFilterState with a dummy element.

    I.e. when associative scan is not used: only the mean, Cholesky covariance
    and log normalizing constant carry real values; the remaining scan-element
    fields are filled with dummies of matching structure.

    Args:
        mean: Mean of the state.
        chol_cov: Cholesky covariance of the state.
        log_normalizing_constant: Log normalizing constant of the state.
        model_inputs: Model inputs.
        mean_prev: Mean of the previous state.

    Returns:
        LinearizedKalmanFilterState with a dummy elem attribute.
    """
    # Matrix-shaped and vector-shaped placeholders for the unused fields.
    dummy_mat = dummy_tree_like(chol_cov)
    dummy_vec = dummy_tree_like(mean)
    elem = filtering.FilterScanElement(
        A=dummy_mat,
        b=mean,
        U=chol_cov,
        eta=dummy_vec,
        Z=dummy_tree_like(chol_cov),
        ell=log_normalizing_constant,
    )
    return LinearizedKalmanFilterState(
        elem=elem,
        model_inputs=model_inputs,
        mean_prev=mean_prev,
    )
|
cuthbert/smc/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
"""Implements backward sampling for particle filters.
|
|
2
|
+
|
|
3
|
+
Supports 3 different algorithms for backward sampling:
|
|
4
|
+
|
|
5
|
+
- [`cuthbertlib.smc.smoothing.tracing.simulate`][].
|
|
6
|
+
- [`cuthbertlib.smc.smoothing.exact_sampling.simulate`][].
|
|
7
|
+
- [`cuthbertlib.smc.smoothing.mcmc.simulate`][].
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from functools import partial
|
|
11
|
+
from typing import NamedTuple, cast
|
|
12
|
+
|
|
13
|
+
import jax
|
|
14
|
+
import jax.numpy as jnp
|
|
15
|
+
from jax import Array, random
|
|
16
|
+
|
|
17
|
+
from cuthbert.inference import Smoother
|
|
18
|
+
from cuthbert.smc.particle_filter import LogPotential, ParticleFilterState
|
|
19
|
+
from cuthbert.utils import dummy_tree_like
|
|
20
|
+
from cuthbertlib.resampling import Resampling
|
|
21
|
+
from cuthbertlib.smc.smoothing.protocols import BackwardSampling
|
|
22
|
+
from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ParticleSmootherState(NamedTuple):
    """Particle smoother state."""

    # JAX PRNG key consumed by the backward-sampling step at this time point.
    key: KeyArray
    # Smoother particles at this time point.
    particles: ArrayTree
    # Indices of each particle's ancestor in the filtering particle set.
    ancestor_indices: Array
    # Model inputs for the transition from this time point to the next.
    model_inputs: ArrayTree
    # Log weights of the particles (uniform after backward sampling).
    log_weights: Array

    @property
    def n_particles(self) -> int:
        """Number of particles in the smoother state."""
        return self.ancestor_indices.shape[-1]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def build_smoother(
    log_potential: LogPotential,
    backward_sampling_fn: BackwardSampling,
    resampling_fn: Resampling,
    n_smoother_particles: int,
) -> Smoother:
    r"""Build a particle smoother object.

    Args:
        log_potential: Function to compute the JOINT log potential $\log G_t(x_{t-1}, x_t) + \log M_t(x_t \mid x_{t-1})$.
        backward_sampling_fn: Backward sampling algorithm to use (e.g., genealogy tracing, exact backward sampling).
            This choice specifies how to sample $x_{t-1} \sim p(x_{t-1} \mid x_t, y_{0:t-1})$ given
            samples $x_{t} \sim p(x_t \mid y_{0:T})$. See `cuthbertlib/smc/smoothing/` for possible choices.
        resampling_fn: Resampling algorithm to use (e.g., multinomial, systematic).
        n_smoother_particles: Number of samples to draw from the backward sampling algorithm.

    Returns:
        Particle smoother object.
    """
    # Pre-bind the configuration into the conversion and combination steps;
    # the particle smoother runs as a plain (non-associative) backward pass.
    convert_fn = partial(
        convert_filter_to_smoother_state,
        resampling=resampling_fn,
        n_smoother_particles=n_smoother_particles,
    )
    combine_fn = partial(
        smoother_combine,
        backward_sampling_fn=backward_sampling_fn,
        log_potential=log_potential,
    )
    return Smoother(
        convert_filter_to_smoother_state=convert_fn,
        smoother_prepare=smoother_prepare,
        smoother_combine=combine_fn,
        associative=False,
    )
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def convert_filter_to_smoother_state(
    filter_state: ParticleFilterState,
    resampling: Resampling,
    n_smoother_particles: int,
    model_inputs: ArrayTreeLike | None = None,
    key: KeyArray | None = None,
) -> ParticleSmootherState:
    """Convert a particle filter state to a particle smoother state.

    Args:
        filter_state: Particle filter state.
        resampling: Resampling algorithm to use (e.g., multinomial, systematic).
        n_smoother_particles: Number of smoother samples to draw.
        model_inputs: Only used to create an empty model_inputs tree
            (the values are ignored).
            Useful so that the final smoother state has the same structure as the rest.
            By default, filter_state.model_inputs is used. So this
            is only needed if the smoother model_inputs have a different tree
            structure to filter_state.model_inputs.
        key: JAX random key.

    Returns:
        Particle smoother state. Note that the model_inputs are set to dummy values.

    Raises:
        ValueError: If key is None.
    """
    if key is None:
        raise ValueError("A JAX PRNG key must be provided.")

    if model_inputs is None:
        model_inputs = filter_state.model_inputs
    dummy_model_inputs = dummy_tree_like(model_inputs)

    # Draw the smoother's starting particles from the final filtering distribution.
    key, resampling_key = random.split(key)
    indices = resampling(resampling_key, filter_state.log_weights, n_smoother_particles)

    selected_particles = jax.tree.map(lambda z: z[indices], filter_state.particles)
    # After resampling the particles are equally weighted: each gets log(1/N).
    uniform_log_weights = -jnp.log(n_smoother_particles) * jnp.ones(n_smoother_particles)
    return ParticleSmootherState(
        key=cast(KeyArray, key),
        particles=selected_particles,
        ancestor_indices=filter_state.ancestor_indices[indices],
        model_inputs=dummy_model_inputs,
        log_weights=uniform_log_weights,
    )
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def smoother_prepare(
    filter_state: ParticleFilterState,
    model_inputs: ArrayTreeLike,
    key: KeyArray | None = None,
) -> ParticleSmootherState:
    """Prepare a state for a particle smoother step.

    Note that the model_inputs here are different to filter_state.model_inputs.
    The model_inputs required here are for the transition from t to t+1.
    filter_state.model_inputs represents the transition from t-1 to t.

    Args:
        filter_state: Particle filter state from time t.
        model_inputs: Model inputs for the transition from t to t+1.
        key: JAX random key.

    Returns:
        Prepared state for the particle smoother.

    Raises:
        ValueError: If key is None.
    """
    if key is None:
        raise ValueError("A JAX PRNG key must be provided.")

    # `jax.tree.map` applies its function to each leaf, so `jnp.asarray` can
    # be passed directly — the `lambda x: jnp.asarray(x)` wrapper was redundant.
    model_inputs = jax.tree.map(jnp.asarray, model_inputs)

    return ParticleSmootherState(
        key,
        filter_state.particles,
        filter_state.ancestor_indices,
        model_inputs,
        filter_state.log_weights,
    )
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def smoother_combine(
    state_1: ParticleSmootherState,
    state_2: ParticleSmootherState,
    backward_sampling_fn: BackwardSampling,
    log_potential: LogPotential,
) -> ParticleSmootherState:
    """Combine next smoother state with state prepared with latest model inputs.

    Remember smoothing iterates backwards in time.

    Args:
        state_1: State prepared with model inputs at time t.
        state_2: Smoother state at time t + 1.
        backward_sampling_fn: Function to perform backward sampling from the joint distribution.
        log_potential: Function to compute log potential.

    Returns:
        Combined particle smoother state.
        Contains particles, the original ancestor indices of the particles, and model inputs.
    """
    # The joint density is evaluated with the model inputs of the t -> t+1
    # transition, which are carried on the time-(t+1) state.
    def joint_log_density(s1, s2):
        return log_potential(s1, s2, state_2.model_inputs)

    sampled_particles, sampled_ancestors = backward_sampling_fn(
        state_1.key,
        x0_all=state_1.particles,
        x1_all=state_2.particles,
        log_weight_x0_all=state_1.log_weights,
        log_density=joint_log_density,
        x1_ancestor_indices=state_2.ancestor_indices,
    )

    # Backward sampling produces an equally weighted particle set: log(1/N) each.
    n = len(sampled_ancestors)
    uniform_log_weights = -jnp.log(n) * jnp.ones(n)
    return ParticleSmootherState(
        key=state_1.key,
        particles=sampled_particles,
        ancestor_indices=state_1.ancestor_indices[sampled_ancestors],
        model_inputs=state_1.model_inputs,
        log_weights=uniform_log_weights,
    )
|