cuthbert 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuthbert/discrete/__init__.py +2 -0
- cuthbert/discrete/filter.py +140 -0
- cuthbert/discrete/smoother.py +123 -0
- cuthbert/discrete/types.py +53 -0
- cuthbert/gaussian/__init__.py +0 -0
- cuthbert/gaussian/kalman.py +337 -0
- cuthbert/gaussian/moments/__init__.py +11 -0
- cuthbert/gaussian/moments/associative_filter.py +180 -0
- cuthbert/gaussian/moments/filter.py +95 -0
- cuthbert/gaussian/moments/non_associative_filter.py +161 -0
- cuthbert/gaussian/moments/smoother.py +118 -0
- cuthbert/gaussian/moments/types.py +51 -0
- cuthbert/gaussian/taylor/__init__.py +14 -0
- cuthbert/gaussian/taylor/associative_filter.py +222 -0
- cuthbert/gaussian/taylor/filter.py +129 -0
- cuthbert/gaussian/taylor/non_associative_filter.py +246 -0
- cuthbert/gaussian/taylor/smoother.py +158 -0
- cuthbert/gaussian/taylor/types.py +86 -0
- cuthbert/gaussian/types.py +57 -0
- cuthbert/gaussian/utils.py +41 -0
- cuthbert/smc/__init__.py +0 -0
- cuthbert/smc/backward_sampler.py +193 -0
- cuthbert/smc/marginal_particle_filter.py +237 -0
- cuthbert/smc/particle_filter.py +234 -0
- cuthbert/smc/types.py +67 -0
- {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/METADATA +2 -2
- cuthbert-0.0.3.dist-info/RECORD +76 -0
- {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/licenses/LICENSE +1 -1
- cuthbertlib/discrete/__init__.py +0 -0
- cuthbertlib/discrete/filtering.py +49 -0
- cuthbertlib/discrete/smoothing.py +35 -0
- cuthbertlib/kalman/__init__.py +4 -0
- cuthbertlib/kalman/filtering.py +213 -0
- cuthbertlib/kalman/generate.py +85 -0
- cuthbertlib/kalman/sampling.py +68 -0
- cuthbertlib/kalman/smoothing.py +121 -0
- cuthbertlib/linalg/__init__.py +7 -0
- cuthbertlib/linalg/collect_nans_chol.py +90 -0
- cuthbertlib/linalg/marginal_sqrt_cov.py +34 -0
- cuthbertlib/linalg/symmetric_inv_sqrt.py +126 -0
- cuthbertlib/linalg/tria.py +21 -0
- cuthbertlib/linearize/__init__.py +7 -0
- cuthbertlib/linearize/log_density.py +175 -0
- cuthbertlib/linearize/moments.py +94 -0
- cuthbertlib/linearize/taylor.py +83 -0
- cuthbertlib/quadrature/__init__.py +4 -0
- cuthbertlib/quadrature/common.py +102 -0
- cuthbertlib/quadrature/cubature.py +73 -0
- cuthbertlib/quadrature/gauss_hermite.py +62 -0
- cuthbertlib/quadrature/linearize.py +143 -0
- cuthbertlib/quadrature/unscented.py +79 -0
- cuthbertlib/quadrature/utils.py +109 -0
- cuthbertlib/resampling/__init__.py +3 -0
- cuthbertlib/resampling/killing.py +79 -0
- cuthbertlib/resampling/multinomial.py +53 -0
- cuthbertlib/resampling/protocols.py +92 -0
- cuthbertlib/resampling/systematic.py +78 -0
- cuthbertlib/resampling/utils.py +82 -0
- cuthbertlib/smc/__init__.py +0 -0
- cuthbertlib/smc/ess.py +24 -0
- cuthbertlib/smc/smoothing/__init__.py +0 -0
- cuthbertlib/smc/smoothing/exact_sampling.py +111 -0
- cuthbertlib/smc/smoothing/mcmc.py +76 -0
- cuthbertlib/smc/smoothing/protocols.py +44 -0
- cuthbertlib/smc/smoothing/tracing.py +45 -0
- cuthbertlib/stats/__init__.py +0 -0
- cuthbertlib/stats/multivariate_normal.py +102 -0
- cuthbert-0.0.1.dist-info/RECORD +0 -12
- {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/WHEEL +0 -0
- {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"""Implements systematic resampling."""
|
|
2
|
+
|
|
3
|
+
from functools import partial
|
|
4
|
+
|
|
5
|
+
from jax import numpy as jnp
|
|
6
|
+
from jax import random
|
|
7
|
+
from jax.lax import cond, select
|
|
8
|
+
from jax.scipy.special import logsumexp
|
|
9
|
+
|
|
10
|
+
from cuthbertlib.resampling.protocols import (
|
|
11
|
+
conditional_resampling_decorator,
|
|
12
|
+
resampling_decorator,
|
|
13
|
+
)
|
|
14
|
+
from cuthbertlib.resampling.utils import inverse_cdf
|
|
15
|
+
from cuthbertlib.types import Array, ArrayLike
|
|
16
|
+
|
|
17
|
+
# Human-readable summary forwarded as ``desc`` to the (conditional) resampling
# decorators below. The leading and trailing newlines are part of the value.
_DESCRIPTION = (
    "\n"
    "The Systematic resampling is a variance reduction which places marginally\n"
    "uniform samples into the [0, 1] interval but only requires one uniform random.\n"
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@partial(resampling_decorator, name="Systematic", desc=_DESCRIPTION)
def resampling(key: Array, logits: ArrayLike, n: int) -> Array:
    # Systematic resampling: one shared uniform offset U places the n points
    # u_k = (k + U) / n, k = 0..n-1, which are already sorted in [0, 1).
    offset = random.uniform(key, ())
    strata = jnp.arange(n)
    sorted_uniforms = (offset + strata) / n
    # Map each point to a particle index via the inverse CDF of the weights.
    return inverse_cdf(sorted_uniforms, logits)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@partial(conditional_resampling_decorator, name="Systematic", desc=_DESCRIPTION)
def conditional_resampling(
    key: Array, logits: ArrayLike, n: int, pivot_in: int, pivot_out: int
) -> Array:
    # Strategy: rotate the particle system so that particle `pivot_out` sits at
    # position 0, run the "0 -> 0" conditional sampler (which keeps index 0 at
    # output slot 0), translate the indices back to the original labels, then
    # rotate the output so that slot 0 moves to slot `pivot_in`.
    log_w = jnp.asarray(logits)
    n_particles = log_w.shape[0]
    # FIXME: no need for normalizing in theory (the inner sampler normalizes
    # again); kept so results are unchanged.
    log_w = log_w - logsumexp(log_w)

    # FIXME: this rolling should be done in a single function.
    shifted_log_w = jnp.roll(log_w, -pivot_out)
    shifted_labels = jnp.roll(jnp.arange(n_particles), -pivot_out)

    local_idx = conditional_resampling_0_to_0(key, shifted_log_w, n)
    original_idx = shifted_labels[local_idx]
    return jnp.roll(original_idx, pivot_in)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def conditional_resampling_0_to_0(
    key: Array,
    logits: ArrayLike,
    n: int,
) -> Array:
    """Conditional systematic resampling that pins particle 0 to output slot 0.

    Draws ``n`` systematic-resampling indices from the categorical distribution
    defined by ``logits``, conditioned on output position 0 holding index 0.
    The general-pivot variant ``conditional_resampling`` reduces to this case
    by rolling the arrays.

    Args:
        key: JAX PRNG key (consumed once, for three uniforms).
        logits: Log-weights, possibly un-normalized.
        n: Number of indices to draw.

    Returns:
        Array of ``n`` indices, clipped to ``[0, len(logits) - 1]``, constructed
        so that position 0 holds index 0.
    """
    logits = jnp.asarray(logits)

    N = logits.shape[0]
    weights = jnp.exp(logits - logsumexp(logits))
    # Expected number of copies of particle 0 among the n systematic points.
    tmp = n * weights[0]
    tmp_floor = jnp.floor(tmp)

    # U: position of the systematic grid; V: chooses floor(tmp)+1 vs floor(tmp)
    # copies of particle 0; W: chooses which copy of particle 0 ends up at slot 0.
    U, V, W = random.uniform(key, (3,))

    def _otherwise():
        # tmp > 1: offsets in [0, rem) yield floor(tmp)+1 copies of particle 0,
        # offsets in [rem, 1) yield floor(tmp) copies. Pick the branch with
        # probability p_cond, then draw the offset uniformly inside that branch.
        rem = tmp - tmp_floor
        p_cond = rem * (tmp_floor + 1) / tmp
        return select(V < p_cond, rem * U, rem + (1.0 - rem) * U)

    # tmp <= 1: restrict the offset to [0, tmp) so the first point falls on particle 0.
    uniform = cond(tmp <= 1, lambda: tmp * U, _otherwise)

    linspace = (jnp.arange(n) + uniform) / n
    idx = inverse_cdf(linspace, logits)

    # Rotate a uniformly chosen occurrence of index 0 to slot 0. `zero_loc` is
    # padded with -1, but roll_idx = floor(n_zero * W) < n_zero never hits the padding.
    n_zero = jnp.sum(idx == 0)
    zero_loc = jnp.flatnonzero(idx == 0, size=n, fill_value=-1)
    roll_idx = jnp.floor(n_zero * W).astype(int)

    # A single occurrence of index 0 needs no rotation.
    idx = select(n_zero == 1, idx, jnp.roll(idx, -zero_loc[roll_idx]))
    # Defensive guard keeping indices in range.
    return jnp.clip(idx, 0, N - 1)
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""Utility functions (inverse CDF sampling) for resampling algorithms."""
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
import numba as nb
|
|
6
|
+
import numpy as np
|
|
7
|
+
from jax.lax import platform_dependent
|
|
8
|
+
from jax.scipy.special import logsumexp
|
|
9
|
+
|
|
10
|
+
from cuthbertlib.types import Array, ArrayLike
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@jax.jit
def inverse_cdf(sorted_uniforms: ArrayLike, logits: ArrayLike) -> Array:
    """Map sorted uniforms to particle indices by inverting the weights' CDF.

    The normalized weights ``exp(logits - logsumexp(logits))`` define a
    categorical distribution. Each uniform ``u`` is mapped to the smallest index
    ``i`` whose cumulative weight is at least ``u``.

    The lookup is dispatched on the backend with ``jax.lax.platform_dependent``:

    1. CPU: a numba-compiled, merge-style scan over the *sorted* uniforms, run
       through ``jax.pure_callback``. Both sequences are traversed once, so the
       cost is O(N + M), where N is the number of logits and M the number of
       uniforms. ``np.searchsorted`` (not ``jnp.searchsorted``) could replace
       it, but it does not guarantee this linear cost.
    2. Other backends (GPU, ...): ``jnp.searchsorted`` with ``method="sort"``.
       This merges both arrays with a sort, roughly O((N + M) log(N + M)) work
       that parallelizes well. It suits large arrays of sorted uniforms, which
       is the typical setting here.

    Args:
        sorted_uniforms: Sorted uniforms.
        logits: Log-weights, possibly un-normalized.

    Returns:
        Indices of the particles to be resampled.
    """
    normalized_weights = jnp.exp(logits - logsumexp(logits))
    backend_impls = dict(cpu=_inverse_cdf_cpu, default=_inverse_cdf_default)
    return platform_dependent(sorted_uniforms, normalized_weights, **backend_impls)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@jax.jit
def _inverse_cdf_default(sorted_uniforms: ArrayLike, weights: ArrayLike) -> Array:
    """Backend-agnostic inverse-CDF lookup using a sort-based ``searchsorted``.

    Args:
        sorted_uniforms: Uniforms, sorted in non-decreasing order.
        weights: Normalized particle weights.

    Returns:
        For each uniform, the first index whose cumulative weight is at least
        that uniform, clipped to ``[0, len(weights) - 1]`` and cast to the
        default JAX integer dtype.
    """
    w = jnp.asarray(weights)
    cdf = jnp.cumsum(w)
    positions = jnp.searchsorted(cdf, sorted_uniforms, method="sort")
    # A uniform above the final cumulative sum (e.g. from round-off) would give
    # len(w); clamp it onto the last particle.
    last_index = w.shape[0] - 1
    clamped = jnp.clip(positions, 0, last_index)
    # Cast to the default JAX int dtype, matching the CPU implementation.
    return clamped.astype(int)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@jax.jit
def _inverse_cdf_cpu(sorted_uniforms: ArrayLike, weights: ArrayLike) -> Array:
    """CPU inverse-CDF lookup that delegates to a numba kernel on the host.

    The numba kernel runs through ``jax.pure_callback``. Under ``vmap`` the
    callback is invoked once per batch element (``vmap_method="sequential"``).

    Args:
        sorted_uniforms: Uniforms, sorted in non-decreasing order.
        weights: Normalized particle weights.

    Returns:
        Integer indices (default JAX int dtype), one per uniform, clipped to
        ``[0, len(weights) - 1]``.
    """
    sorted_uniforms = jnp.asarray(sorted_uniforms)
    weights = jnp.asarray(weights)
    M = weights.shape[0]
    N = sorted_uniforms.shape[0]
    # Fix the output shape/dtype at trace time. The host callback must not read
    # attributes of traced values, and no placeholder output array needs to be
    # shipped to the host.
    out_spec = jax.ShapeDtypeStruct((N,), jnp.zeros((), dtype=int).dtype)

    def callback(su, w):
        # Inputs handed over by JAX may be read-only, so allocate a fresh,
        # writable output buffer for the numba kernel to fill in place.
        out = np.zeros(np.shape(su)[0], dtype=out_spec.dtype)
        _inverse_cdf_numba(np.asarray(su), np.asarray(w), out)
        return out

    idx = jax.pure_callback(
        callback, out_spec, sorted_uniforms, weights, vmap_method="sequential"
    )
    # The kernel already stops at the last particle; the clip is a defensive guard.
    return jnp.clip(idx, 0, M - 1)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@nb.njit
def _inverse_cdf_numba(su, ws, idx):
    """Fill ``idx`` in place with inverse-CDF indices for sorted uniforms.

    Makes a single forward pass: a cursor walks the particles while their
    cumulative weight is accumulated. For each uniform ``u``, the cursor stops
    at the first particle whose cumulative weight reaches ``u``, or at the last
    particle. The cost is linear in ``len(su) + len(ws)``.

    Args:
        su: Uniforms, assumed sorted in non-decreasing order.
        ws: Particle weights.
        idx: Writable integer output buffer of length ``len(su)``.
    """
    last = ws.shape[0] - 1
    cursor = 0
    running = ws[0]
    for k in range(su.shape[0]):
        u = su[k]
        # Advance until the running CDF reaches u (never past the last particle).
        while u > running and cursor < last:
            cursor += 1
            running += ws[cursor]
        idx[k] = cursor
|
|
File without changes
|
cuthbertlib/smc/ess.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Importance sampling effective sample size (ESS) computation."""
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
from jax import Array
|
|
6
|
+
from jax.typing import ArrayLike
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def log_ess(log_weights: ArrayLike) -> Array:
    """Compute the logarithm of the effective sample size (ESS).

    For (possibly un-normalized) weights ``w_i = exp(log_weights_i)``, the ESS
    is ``(sum_i w_i)**2 / sum_i w_i**2``. Its logarithm is evaluated entirely
    in log space for numerical stability.

    Args:
        log_weights: Array of log weights for the particles.

    Returns:
        The log-ESS, reduced over all entries of ``log_weights``.
    """
    log_total = jax.nn.logsumexp(log_weights)
    log_total_of_squares = jax.nn.logsumexp(2 * log_weights)
    return 2 * log_total - log_total_of_squares
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def ess(log_weights: ArrayLike) -> Array:
    """Compute the effective sample size (ESS).

    Exponentiates the log-space computation in ``log_ess``.

    Args:
        log_weights: Array of log weights for the particles.

    Returns:
        The ESS, a value between 1 and the number of particles.
    """
    log_value = log_ess(log_weights)
    return jnp.exp(log_value)
|
|
File without changes
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""Implements exact backward sampling for smoothing in SMC."""
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
from jax import numpy as jnp
|
|
5
|
+
from jax import random, vmap
|
|
6
|
+
from jax.scipy.special import logsumexp
|
|
7
|
+
|
|
8
|
+
from cuthbertlib.types import (
|
|
9
|
+
Array,
|
|
10
|
+
ArrayLike,
|
|
11
|
+
ArrayTree,
|
|
12
|
+
ArrayTreeLike,
|
|
13
|
+
KeyArray,
|
|
14
|
+
LogConditionalDensity,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def log_weights_single(
    x0: ArrayTreeLike,
    x1: ArrayTreeLike,
    log_weight_x0: ArrayLike,
    log_density: LogConditionalDensity,
) -> Array:
    """Un-normalized log backward-smoothing weight of one x0 given one x1.

    The weight is the filter log-weight of ``x0`` plus the transition
    log-density ``log_density(x0, x1)``.

    Args:
        x0: The previous state.
        x1: The current state.
        log_weight_x0: The log weight of the previous state.
        log_density: The log density function of x1 given x0, called as
            ``log_density(x0, x1)``.

    Returns:
        The (un-normalized) log smoothing weight of ``x0`` given ``x1``.
    """
    prior_term = jnp.asarray(log_weight_x0)
    transition_term = log_density(x0, x1)
    return prior_term + transition_term
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def log_weights(x0_all, x1, log_weight_x0_all, log_density) -> Array:
    """Log-normalized backward smoothing weights of every x0 given one x1.

    Args:
        x0_all: Collection of previous states (pytree, leading axis = particle).
        x1: The current state.
        log_weight_x0_all: Collection of log weights of the previous states.
        log_density: The log density function of x1 given x0.

    Returns:
        Log smoothing weights for each sample x0 given the single sample x1,
        normalized so that they logsumexp to zero along axis 0.
    """

    def _weight_for(x0, log_w0):
        return log_weights_single(x0, x1, log_w0, log_density)

    unnormalized = vmap(_weight_for)(x0_all, log_weight_x0_all)
    log_normalizer = logsumexp(unnormalized, axis=0)
    return unnormalized - log_normalizer
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def simulate_single(
    key, x0_all, x1, log_weight_x0_all, log_density
) -> tuple[ArrayTree, Array]:
    """Draw one x0 from the exact backward kernel given a single x1.

    Args:
        key: A JAX random key.
        x0_all: Collection of previous states.
        x1: The current state.
        log_weight_x0_all: Collection of log weights of the previous states.
        log_density: The log density function of x1 given x0.

    Returns:
        The sampled state x0 (one slice of every leaf of ``x0_all``) together
        with its index.
    """
    backward_logits = log_weights(x0_all, x1, log_weight_x0_all, log_density)
    chosen = random.categorical(key, backward_logits)

    def _take(leaf):
        return leaf[chosen]

    return jax.tree.map(_take, x0_all), chosen
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def simulate(
    key: KeyArray,
    x0_all: ArrayTreeLike,
    x1_all: ArrayTreeLike,
    log_weight_x0_all: ArrayLike,
    log_density: LogConditionalDensity,
    x1_ancestor_indices: ArrayLike,
) -> tuple[ArrayTree, Array]:
    """Implements the exact backward sampling algorithm for smoothing in SMC.

    Each current state gets its own key and an independent draw from the full
    backward kernel over all of ``x0_all``. The cost is therefore proportional
    to (number of x1) x (number of x0) transition-density evaluations.

    Some arguments are only included for protocol compatibility and not used in
    this implementation.

    Args:
        key: JAX PRNG key.
        x0_all: A collection of previous states $x_0$.
        x1_all: A collection of current states $x_1$.
        log_weight_x0_all: The log weights of $x_0$.
        log_density: The log density function of $x_1$ given $x_0$.
        x1_ancestor_indices: The ancestor indices of $x_1$. Not used.

    Returns:
        A collection of samples $x_0$ and their sampled indices.
    """
    log_weight_x0_all = jnp.asarray(log_weight_x0_all)
    first_leaf = jax.tree.leaves(x1_all)[0]
    per_sample_keys = random.split(key, first_leaf.shape[0])

    def _draw(k, x1):
        return simulate_single(k, x0_all, x1, log_weight_x0_all, log_density)

    return vmap(_draw)(per_sample_keys, x1_all)
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
"""Implements MCMC backward smoothing in SMC."""
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
from jax import numpy as jnp
|
|
5
|
+
from jax import random
|
|
6
|
+
|
|
7
|
+
from cuthbertlib.resampling import multinomial
|
|
8
|
+
from cuthbertlib.smc.smoothing.tracing import simulate as ancestor_tracing_simulate
|
|
9
|
+
from cuthbertlib.types import (
|
|
10
|
+
Array,
|
|
11
|
+
ArrayLike,
|
|
12
|
+
ArrayTree,
|
|
13
|
+
ArrayTreeLike,
|
|
14
|
+
KeyArray,
|
|
15
|
+
LogConditionalDensity,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def simulate(
    key: KeyArray,
    x0_all: ArrayTreeLike,
    x1_all: ArrayTreeLike,
    log_weight_x0_all: ArrayLike,
    log_density: LogConditionalDensity,
    x1_ancestor_indices: ArrayLike,
    n_steps: int,
) -> tuple[ArrayTree, Array]:
    """Implements the IMH algorithm for smoothing in SMC.

    Each smoother sample is initialised at its filter ancestor (ancestor
    tracing) and then refined by ``n_steps`` independent Metropolis-Hastings
    (IMH) moves. Candidates are drawn from the filter weights of $x_0$. Because
    the target is proportional to $w(x_0) p(x_1 | x_0)$, the filter weights
    cancel, and a candidate $x_0'$ is accepted with probability
    $\\min(1, p(x_1 | x_0') / p(x_1 | x_0))$.

    Args:
        key: JAX PRNG key.
        x0_all: A collection of previous states $x_0$.
        x1_all: A collection of current states $x_1$.
        log_weight_x0_all: The log weights of $x_0$.
        log_density: The log density function of $x_1$ given $x_0$, called as
            ``log_density(x0, x1)``.
        x1_ancestor_indices: The ancestor indices of $x_1$.
        n_steps: Number of MCMC steps to perform.

    Returns:
        A collection of samples $x_0$ and their sampled indices.

    References:
        https://arxiv.org/abs/2207.00976
    """
    key, subkey = random.split(key)
    x0_init, x1_ancestor_indices = ancestor_tracing_simulate(
        subkey, x0_all, x1_all, log_weight_x0_all, log_density, x1_ancestor_indices
    )
    n_samples = x1_ancestor_indices.shape[0]

    # One (proposal, acceptance) key pair per step. Asking `random.split` for a
    # key *shape* works for both typed keys and legacy uint32 keys. Splitting
    # into 2 * n_steps keys and reshaping breaks for legacy keys, which carry a
    # trailing axis.
    keys = random.split(key, (n_steps, 2))

    def body(carry, keys_t):
        idx, _, idx_log_p = carry
        key_prop, key_acc = keys_t

        # IMH proposal: draw candidate ancestors from the filter weights.
        prop_idx = multinomial.resampling(key_prop, log_weight_x0_all, n_samples)
        x0_prop = jax.tree.map(lambda z: z[prop_idx], x0_all)
        prop_log_p = jax.vmap(log_density)(x0_prop, x1_all)

        # Filter weights cancel between target and proposal; only the ratio of
        # transition densities remains.
        log_alpha = prop_log_p - idx_log_p

        lu = jnp.log(random.uniform(key_acc, (n_samples,)))
        acc = lu < log_alpha

        idx = jnp.where(acc, prop_idx, idx)
        x0_res = jax.tree.map(lambda z: z[idx], x0_all)
        idx_log_p = jnp.where(acc, prop_log_p, idx_log_p)
        return (idx, x0_res, idx_log_p), None

    # Evaluate the initial state with the same (x0, x1) argument order used for
    # the proposals; otherwise the first acceptance ratios are wrong.
    init_log_p = jax.vmap(log_density)(x0_init, x1_all)
    init = (x1_ancestor_indices, x0_init, init_log_p)
    (out_index, out_samples, _), _ = jax.lax.scan(body, init, keys)
    return out_samples, out_index
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""Shared protocols for backward smoothing functions in SMC."""
|
|
2
|
+
|
|
3
|
+
from typing import Protocol, runtime_checkable
|
|
4
|
+
|
|
5
|
+
from cuthbertlib.types import (
|
|
6
|
+
Array,
|
|
7
|
+
ArrayLike,
|
|
8
|
+
ArrayTree,
|
|
9
|
+
ArrayTreeLike,
|
|
10
|
+
KeyArray,
|
|
11
|
+
LogConditionalDensity,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@runtime_checkable
class BackwardSampling(Protocol):
    """Structural interface shared by SMC backward-sampling functions.

    Any callable whose signature matches ``__call__`` below satisfies this
    protocol. Because it is ``runtime_checkable``, ``isinstance`` only checks
    that a ``__call__`` attribute exists, not the signature.
    """

    def __call__(
        self,
        key: KeyArray,
        x0_all: ArrayTreeLike,
        x1_all: ArrayTreeLike,
        log_weight_x0_all: ArrayLike,
        log_density: LogConditionalDensity,
        x1_ancestor_indices: ArrayLike,
    ) -> tuple[ArrayTree, Array]:
        """Run one backward-sampling step.

        Given a collection of $x_1$, draw a matching collection of $x_0$ so
        that each pair $(x_0, x_1)$ is a sample from the smoothing distribution.

        Args:
            key: JAX PRNG key.
            x0_all: A collection of previous states $x_0$.
            x1_all: A collection of current states $x_1$.
            log_weight_x0_all: The log weights of $x_0$.
            log_density: The log density function of $x_1$ given $x_0$.
            x1_ancestor_indices: The ancestor indices of $x_1$.

        Returns:
            A collection of samples $x_0$ and their sampled indices.
        """
        ...
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Implements the ancestor/genealogy tracing algorithm for smoothing in SMC."""
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
from jax import numpy as jnp
|
|
5
|
+
|
|
6
|
+
from cuthbertlib.types import (
|
|
7
|
+
Array,
|
|
8
|
+
ArrayLike,
|
|
9
|
+
ArrayTree,
|
|
10
|
+
ArrayTreeLike,
|
|
11
|
+
KeyArray,
|
|
12
|
+
LogConditionalDensity,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def simulate(
    key: KeyArray,
    x0_all: ArrayTreeLike,
    x1_all: ArrayTreeLike,
    log_weight_x0_all: ArrayLike,
    log_density: LogConditionalDensity,
    x1_ancestor_indices: ArrayLike,
) -> tuple[ArrayTree, Array]:
    """Implements the ancestor/genealogy tracing algorithm for smoothing in SMC.

    Each $x_1$ is paired with the $x_0$ it was resampled from during
    filtering; no randomness or density evaluation is involved.

    Some arguments are only included for protocol compatibility and not used in
    this implementation.

    Args:
        key: JAX PRNG key. Not used.
        x0_all: A collection of previous states $x_0$.
        x1_all: A collection of current states $x_1$. Not used.
        log_weight_x0_all: The log weights of $x_0$. Not used.
        log_density: The log density function of $x_1$ given $x_0$. Not used.
        x1_ancestor_indices: The ancestor indices of $x_1$.

    Returns:
        A collection of samples $x_0$ and their sampled indices.

    References:
        https://arxiv.org/abs/2207.00976
    """
    ancestors = jnp.asarray(x1_ancestor_indices)

    def _gather(leaf):
        return leaf[ancestors]

    traced_x0 = jax.tree.map(_gather, x0_all)
    return traced_x0, ancestors
|
|
File without changes
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""Multivariate normal distribution functions with chol_cov input."""
|
|
2
|
+
|
|
3
|
+
from functools import partial
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from jax import lax
|
|
7
|
+
from jax import numpy as jnp
|
|
8
|
+
from jax._src.numpy.util import promote_dtypes_inexact
|
|
9
|
+
from jax._src.typing import Array, ArrayLike
|
|
10
|
+
|
|
11
|
+
from cuthbertlib.linalg import collect_nans_chol
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def logpdf(
    x: ArrayLike, mean: ArrayLike, chol_cov: ArrayLike, nan_support: bool = True
) -> Array:
    """Multivariate normal log probability distribution function with chol_cov input.

    Here `chol_cov` is the (generalized) Cholesky factor of the covariance matrix.
    This is a modified version of `jax.scipy.stats.multivariate_normal.logpdf`,
    which takes the full covariance matrix as input.

    `chol_cov` may be given as:

    - a scalar: isotropic factor (covariance `chol_cov**2 * I`);
    - a 1-D array: per-dimension factors (diagonal covariance);
    - an `(n, n)` lower-triangular matrix `L` with covariance `L @ L.T`
      (leading batch dimensions are broadcast).

    In every form only the magnitude of the factor enters the normalizing
    constant, so negative entries are handled consistently.

    Args:
        x: Value at which to evaluate the PDF.
        mean: Mean of the distribution.
        chol_cov: Generalized Cholesky factor of the covariance matrix of the distribution.
        nan_support: If `True`, ignores NaNs in `x` by projecting the distribution onto the
            lower-dimensional subspace spanned by the non-NaN entries of `x`. Note that
            `nan_support=True` uses the [tria][cuthbertlib.linalg.tria] operation (QR
            decomposition), and therefore increases the internal complexity of the function
            from $O(n^2)$ to $O(n^3)$, where $n$ is the dimension of `x`.

    Returns:
        Array of logpdf values.

    Raises:
        ValueError: If a matrix-valued `chol_cov` does not have trailing shape `(n, n)`.
    """
    x, mean, chol_cov = promote_dtypes_inexact(x, mean, chol_cov)

    # If nan_support is True, we need to collect the NaNs at the top of the covariance
    # matrix; this uses a QR decomposition, so it is more expensive.
    if nan_support:
        flag = jnp.isnan(x)
        flag, chol_cov, x, mean = collect_nans_chol(flag, chol_cov, x, mean)
    mean = jnp.asarray(mean)
    x = jnp.asarray(x)
    chol_cov = jnp.asarray(chol_cov)

    log_2pi = jnp.log(2 * np.pi)

    if not mean.shape and not np.shape(x):
        # Both mean and x are scalars: univariate normal with std |chol_cov|.
        # abs() keeps negative (but valid) factors from producing NaN.
        return -1 / 2 * jnp.square(x - mean) / chol_cov**2 - 1 / 2 * (
            log_2pi + 2 * jnp.log(jnp.abs(chol_cov))
        )

    n = mean.shape[-1] if mean.shape else x.shape[-1]

    if not np.shape(chol_cov):
        # Scalar factor: isotropic covariance chol_cov**2 * I.
        y = x - mean
        return -1 / 2 * jnp.einsum("...i,...i->...", y, y) / chol_cov**2 - n / 2 * (
            log_2pi + 2 * jnp.log(jnp.abs(chol_cov))
        )

    if chol_cov.ndim == 1:
        # Diagonal factor: standardize each coordinate independently.
        y = (x - mean) / chol_cov
        return (
            -1 / 2 * jnp.einsum("...i,...i->...", y, y)
            - n / 2 * log_2pi
            - jnp.log(jnp.abs(chol_cov)).sum(-1)
        )

    if chol_cov.ndim < 2 or chol_cov.shape[-2:] != (n, n):
        raise ValueError("multivariate_normal.logpdf got incompatible shapes")
    # Whiten the residual by solving L y = x - mean with the lower-triangular factor.
    y = jnp.vectorize(
        partial(lax.linalg.triangular_solve, lower=True, transpose_a=True),
        signature="(n,n),(n)->(n)",
    )(chol_cov, x - mean)
    return (
        -1 / 2 * jnp.einsum("...i,...i->...", y, y)
        - n / 2 * log_2pi
        - jnp.log(jnp.abs(chol_cov.diagonal(axis1=-1, axis2=-2))).sum(-1)
    )
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def pdf(
    x: ArrayLike, mean: ArrayLike, chol_cov: ArrayLike, nan_support: bool = True
) -> Array:
    """Multivariate normal probability density function with chol_cov input.

    Here `chol_cov` is the (generalized) Cholesky factor of the covariance matrix.
    This is a modified version of `jax.scipy.stats.multivariate_normal.pdf`,
    which takes the full covariance matrix as input. It exponentiates `logpdf`.

    Args:
        x: Value at which to evaluate the PDF.
        mean: Mean of the distribution.
        chol_cov: Generalized Cholesky factor of the covariance matrix of the distribution.
        nan_support: If `True`, ignores NaNs in `x` by projecting the distribution onto the
            lower-dimensional subspace spanned by the non-NaN entries of `x`. Note that
            `nan_support=True` uses the [tria][cuthbertlib.linalg.tria] operation (QR
            decomposition), and therefore increases the internal complexity of the function
            from $O(n^2)$ to $O(n^3)$, where $n$ is the dimension of `x`.

    Returns:
        Array of pdf values.
    """
    log_density = logpdf(x, mean, chol_cov, nan_support)
    return lax.exp(log_density)
|
cuthbert-0.0.1.dist-info/RECORD
DELETED
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
cuthbert/__init__.py,sha256=60_FB1nfduZLPphbjEc7WRpebCnVYiwMyKuD-7yZBvw,281
|
|
2
|
-
cuthbert/filtering.py,sha256=HaUPJWhBO8P5IWlRYhLwCeOEtlYenAvWjYsQaF_cPX0,2700
|
|
3
|
-
cuthbert/inference.py,sha256=u02wVKGu7mIdqn2XhcSZ93xXyPIQkxzSvxJXMH7Zo5k,7600
|
|
4
|
-
cuthbert/smoothing.py,sha256=qvNAWYTGGaEUEBp7HAtYLtF1UOQK9E7u3V_l4XgxeaI,4960
|
|
5
|
-
cuthbert/utils.py,sha256=0JQgRiyVs4SXZ0ullh4OmAMtyoOtuCzvYuDmpu9jXOE,1023
|
|
6
|
-
cuthbert-0.0.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
7
|
-
cuthbertlib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
|
-
cuthbertlib/types.py,sha256=sgHC0J7W_0wOzMm5dlLZXwJB3W1xvzq2J6cxl6U831w,1020
|
|
9
|
-
cuthbert-0.0.1.dist-info/METADATA,sha256=Q-MikVBB27Lne_4q7HvrBVNgln6zY23RdY5PdnIeE4M,7040
|
|
10
|
-
cuthbert-0.0.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
11
|
-
cuthbert-0.0.1.dist-info/top_level.txt,sha256=R7-G6fUQZSMNMcM4-IcpHKtY9CZOQeejEpVW9q-Sarw,21
|
|
12
|
-
cuthbert-0.0.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|