cuthbert 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. cuthbert/discrete/__init__.py +2 -0
  2. cuthbert/discrete/filter.py +140 -0
  3. cuthbert/discrete/smoother.py +123 -0
  4. cuthbert/discrete/types.py +53 -0
  5. cuthbert/gaussian/__init__.py +0 -0
  6. cuthbert/gaussian/kalman.py +337 -0
  7. cuthbert/gaussian/moments/__init__.py +11 -0
  8. cuthbert/gaussian/moments/associative_filter.py +180 -0
  9. cuthbert/gaussian/moments/filter.py +95 -0
  10. cuthbert/gaussian/moments/non_associative_filter.py +161 -0
  11. cuthbert/gaussian/moments/smoother.py +118 -0
  12. cuthbert/gaussian/moments/types.py +51 -0
  13. cuthbert/gaussian/taylor/__init__.py +14 -0
  14. cuthbert/gaussian/taylor/associative_filter.py +222 -0
  15. cuthbert/gaussian/taylor/filter.py +129 -0
  16. cuthbert/gaussian/taylor/non_associative_filter.py +246 -0
  17. cuthbert/gaussian/taylor/smoother.py +158 -0
  18. cuthbert/gaussian/taylor/types.py +86 -0
  19. cuthbert/gaussian/types.py +57 -0
  20. cuthbert/gaussian/utils.py +41 -0
  21. cuthbert/smc/__init__.py +0 -0
  22. cuthbert/smc/backward_sampler.py +193 -0
  23. cuthbert/smc/marginal_particle_filter.py +237 -0
  24. cuthbert/smc/particle_filter.py +234 -0
  25. cuthbert/smc/types.py +67 -0
  26. {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/METADATA +2 -2
  27. cuthbert-0.0.3.dist-info/RECORD +76 -0
  28. {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/licenses/LICENSE +1 -1
  29. cuthbertlib/discrete/__init__.py +0 -0
  30. cuthbertlib/discrete/filtering.py +49 -0
  31. cuthbertlib/discrete/smoothing.py +35 -0
  32. cuthbertlib/kalman/__init__.py +4 -0
  33. cuthbertlib/kalman/filtering.py +213 -0
  34. cuthbertlib/kalman/generate.py +85 -0
  35. cuthbertlib/kalman/sampling.py +68 -0
  36. cuthbertlib/kalman/smoothing.py +121 -0
  37. cuthbertlib/linalg/__init__.py +7 -0
  38. cuthbertlib/linalg/collect_nans_chol.py +90 -0
  39. cuthbertlib/linalg/marginal_sqrt_cov.py +34 -0
  40. cuthbertlib/linalg/symmetric_inv_sqrt.py +126 -0
  41. cuthbertlib/linalg/tria.py +21 -0
  42. cuthbertlib/linearize/__init__.py +7 -0
  43. cuthbertlib/linearize/log_density.py +175 -0
  44. cuthbertlib/linearize/moments.py +94 -0
  45. cuthbertlib/linearize/taylor.py +83 -0
  46. cuthbertlib/quadrature/__init__.py +4 -0
  47. cuthbertlib/quadrature/common.py +102 -0
  48. cuthbertlib/quadrature/cubature.py +73 -0
  49. cuthbertlib/quadrature/gauss_hermite.py +62 -0
  50. cuthbertlib/quadrature/linearize.py +143 -0
  51. cuthbertlib/quadrature/unscented.py +79 -0
  52. cuthbertlib/quadrature/utils.py +109 -0
  53. cuthbertlib/resampling/__init__.py +3 -0
  54. cuthbertlib/resampling/killing.py +79 -0
  55. cuthbertlib/resampling/multinomial.py +53 -0
  56. cuthbertlib/resampling/protocols.py +92 -0
  57. cuthbertlib/resampling/systematic.py +78 -0
  58. cuthbertlib/resampling/utils.py +82 -0
  59. cuthbertlib/smc/__init__.py +0 -0
  60. cuthbertlib/smc/ess.py +24 -0
  61. cuthbertlib/smc/smoothing/__init__.py +0 -0
  62. cuthbertlib/smc/smoothing/exact_sampling.py +111 -0
  63. cuthbertlib/smc/smoothing/mcmc.py +76 -0
  64. cuthbertlib/smc/smoothing/protocols.py +44 -0
  65. cuthbertlib/smc/smoothing/tracing.py +45 -0
  66. cuthbertlib/stats/__init__.py +0 -0
  67. cuthbertlib/stats/multivariate_normal.py +102 -0
  68. cuthbert-0.0.1.dist-info/RECORD +0 -12
  69. {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/WHEEL +0 -0
  70. {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/top_level.txt +0 -0
cuthbertlib/resampling/systematic.py ADDED
@@ -0,0 +1,78 @@
+ """Implements systematic resampling."""
+
+ from functools import partial
+
+ from jax import numpy as jnp
+ from jax import random
+ from jax.lax import cond, select
+ from jax.scipy.special import logsumexp
+
+ from cuthbertlib.resampling.protocols import (
+     conditional_resampling_decorator,
+     resampling_decorator,
+ )
+ from cuthbertlib.resampling.utils import inverse_cdf
+ from cuthbertlib.types import Array, ArrayLike
+
+ _DESCRIPTION = """
+ Systematic resampling is a variance-reduction scheme that places marginally
+ uniform samples in the [0, 1] interval while requiring only a single uniform random variable.
+ """
+
+
+ @partial(resampling_decorator, name="Systematic", desc=_DESCRIPTION)
+ def resampling(key: Array, logits: ArrayLike, n: int) -> Array:
+     us = (random.uniform(key, ()) + jnp.arange(n)) / n
+     return inverse_cdf(us, logits)
+
+
+ @partial(conditional_resampling_decorator, name="Systematic", desc=_DESCRIPTION)
+ def conditional_resampling(
+     key: Array, logits: ArrayLike, n: int, pivot_in: int, pivot_out: int
+ ) -> Array:
+     logits = jnp.asarray(logits)
+     # FIXME: no need for normalizing in theory
+     N = logits.shape[0]
+     logits -= logsumexp(logits)
+
+     # FIXME: this rolling should be done in a single function.
+     arange = jnp.arange(N)
+     logits = jnp.roll(logits, -pivot_out)
+     arange = jnp.roll(arange, -pivot_out)
+
+     idx = conditional_resampling_0_to_0(key, logits, n)
+     idx = arange[idx]
+     idx = jnp.roll(idx, pivot_in)
+     return idx
+
+
+ def conditional_resampling_0_to_0(
+     key: Array,
+     logits: ArrayLike,
+     n: int,
+ ) -> Array:
+     logits = jnp.asarray(logits)
+
+     N = logits.shape[0]
+     weights = jnp.exp(logits - logsumexp(logits))
+     tmp = n * weights[0]
+     tmp_floor = jnp.floor(tmp)
+
+     U, V, W = random.uniform(key, (3,))
+
+     def _otherwise():
+         rem = tmp - tmp_floor
+         p_cond = rem * (tmp_floor + 1) / tmp
+         return select(V < p_cond, rem * U, rem + (1.0 - rem) * U)
+
+     uniform = cond(tmp <= 1, lambda: tmp * U, _otherwise)
+
+     linspace = (jnp.arange(n) + uniform) / n
+     idx = inverse_cdf(linspace, logits)
+
+     n_zero = jnp.sum(idx == 0)
+     zero_loc = jnp.flatnonzero(idx == 0, size=n, fill_value=-1)
+     roll_idx = jnp.floor(n_zero * W).astype(int)
+
+     idx = select(n_zero == 1, idx, jnp.roll(idx, -zero_loc[roll_idx]))
+     return jnp.clip(idx, 0, N - 1)
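Usage sketch (illustrative, not part of the wheel): assuming the `resampling_decorator` preserves the call signature of the `resampling` function above, the systematic resampler can be driven as follows.

    import jax.numpy as jnp
    from jax import random
    from cuthbertlib.resampling.systematic import resampling

    key = random.PRNGKey(0)
    log_weights = jnp.log(jnp.array([0.1, 0.2, 0.3, 0.4]))  # unnormalized log-weights are fine
    idx = resampling(key, log_weights, 8)  # 8 ancestor indices drawn in [0, 4)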
cuthbertlib/resampling/utils.py ADDED
@@ -0,0 +1,82 @@
+ """Utility functions (inverse CDF sampling) for resampling algorithms."""
+
+ import jax
+ import jax.numpy as jnp
+ import numba as nb
+ import numpy as np
+ from jax.lax import platform_dependent
+ from jax.scipy.special import logsumexp
+
+ from cuthbertlib.types import Array, ArrayLike
+
+
+ @jax.jit
+ def inverse_cdf(sorted_uniforms: ArrayLike, logits: ArrayLike) -> Array:
+     """Inverse CDF sampling for resampling algorithms.
+
+     The implementation branches on the platform being CPU or GPU (and other parallel backends):
+     1. The CPU implementation is a numba-compiled searchsorted(arr, vals) specialized for *sorted*
+        uniforms, guaranteed to run in O(N + M), where N is the size of logits and M that of sorted_uniforms.
+        It could be replaced mutatis mutandis by np.searchsorted (not jnp.searchsorted!), but the latter
+        does not guarantee this execution time.
+     2. On GPU, we use searchsorted with a sorting strategy: this proceeds by merge (arg)sorting,
+        which runs in O((N + M) log(N + M)) and is efficient for large arrays of sorted uniforms, typically our setting.
+
+     Args:
+         sorted_uniforms: Sorted uniforms.
+         logits: Log-weights, possibly unnormalized.
+
+     Returns:
+         Indices of the particles to be resampled.
+     """
+     weights = jnp.exp(logits - logsumexp(logits))
+     return platform_dependent(
+         sorted_uniforms, weights, cpu=_inverse_cdf_cpu, default=_inverse_cdf_default
+     )
+
+
+ @jax.jit
+ def _inverse_cdf_default(sorted_uniforms: ArrayLike, weights: ArrayLike) -> Array:
+     weights = jnp.asarray(weights)
+     M = weights.shape[0]
+     cs = jnp.cumsum(weights)
+     idx = jnp.searchsorted(cs, sorted_uniforms, method="sort")
+     return jnp.clip(idx, 0, M - 1).astype(
+         int
+     )  # Ensure indices are integers of the same dtype as default jax ints.
+
+
+ @jax.jit
+ def _inverse_cdf_cpu(sorted_uniforms: ArrayLike, weights: ArrayLike) -> Array:
+     sorted_uniforms = jnp.asarray(sorted_uniforms)
+     weights = jnp.asarray(weights)
+     M = weights.shape[0]
+     N = sorted_uniforms.shape[0]
+     idx = jnp.zeros(N, dtype=int)
+
+     def callback(args):
+         su, w, idx_ = args
+         idx_ = np.array(idx_, dtype=idx.dtype)
+         su = np.asarray(su)
+         w = np.asarray(w)
+         _inverse_cdf_numba(su, w, idx_)
+         return idx_
+
+     idx = jax.pure_callback(
+         callback, idx, (sorted_uniforms, weights, idx), vmap_method="sequential"
+     )
+     return jnp.clip(idx, 0, M - 1)
+
+
+ @nb.njit
+ def _inverse_cdf_numba(su, ws, idx):
+     j = 0
+     s = ws[0]
+     M = su.shape[0]
+     N = ws.shape[0]
+
+     for n in range(M):
+         while su[n] > s and j < N - 1:
+             j += 1
+             s += ws[j]
+         idx[n] = j
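A quick behavioral sketch of `inverse_cdf` (illustrative, not part of the wheel): with many sorted uniforms, the empirical frequency of each returned index approaches the corresponding normalized weight.

    import jax.numpy as jnp
    from jax import random
    from cuthbertlib.resampling.utils import inverse_cdf

    key = random.PRNGKey(0)
    logits = jnp.log(jnp.array([0.5, 0.25, 0.25]))
    us = jnp.sort(random.uniform(key, (1000,)))   # inverse_cdf expects *sorted* uniforms
    idx = inverse_cdf(us, logits)
    print(jnp.bincount(idx, length=3) / 1000)     # approximately [0.5, 0.25, 0.25]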
cuthbertlib/smc/__init__.py File without changes
cuthbertlib/smc/ess.py ADDED
@@ -0,0 +1,24 @@
+ """Importance sampling effective sample size (ESS) computation."""
+
+ import jax
+ import jax.numpy as jnp
+ from jax import Array
+ from jax.typing import ArrayLike
+
+
+ def log_ess(log_weights: ArrayLike) -> Array:
+     """Compute the logarithm of the effective sample size (ESS).
+
+     Args:
+         log_weights: Array of log weights for the particles.
+     """
+     return 2 * jax.nn.logsumexp(log_weights) - jax.nn.logsumexp(2 * log_weights)
+
+
+ def ess(log_weights: ArrayLike) -> Array:
+     """Compute the effective sample size (ESS).
+
+     Args:
+         log_weights: Array of log weights for the particles.
+     """
+     return jnp.exp(log_ess(log_weights))
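`log_ess` computes the standard identity $\mathrm{ESS} = (\sum_i w_i)^2 / \sum_i w_i^2$ stably in log space. A short worked example (illustrative, not part of the wheel):

    import jax.numpy as jnp
    from cuthbertlib.smc.ess import ess

    log_w = jnp.log(jnp.array([0.7, 0.1, 0.1, 0.1]))
    print(ess(log_w))  # 1 / (0.7**2 + 3 * 0.1**2) ~= 1.92 effective particles out of 4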
cuthbertlib/smc/smoothing/__init__.py File without changes
cuthbertlib/smc/smoothing/exact_sampling.py ADDED
@@ -0,0 +1,111 @@
+ """Implements exact backward sampling for smoothing in SMC."""
+
+ import jax
+ from jax import numpy as jnp
+ from jax import random, vmap
+ from jax.scipy.special import logsumexp
+
+ from cuthbertlib.types import (
+     Array,
+     ArrayLike,
+     ArrayTree,
+     ArrayTreeLike,
+     KeyArray,
+     LogConditionalDensity,
+ )
+
+
+ def log_weights_single(
+     x0: ArrayTreeLike,
+     x1: ArrayTreeLike,
+     log_weight_x0: ArrayLike,
+     log_density: LogConditionalDensity,
+ ) -> Array:
+     """Compute the log smoothing weight for a single sample x0 given a single sample x1.
+
+     Args:
+         x0: The previous state.
+         x1: The current state.
+         log_weight_x0: The log weight of the previous state.
+         log_density: The log density function of x1 given x0.
+
+     Returns:
+         The log smoothing weight for sample x0 given a single sample x1.
+     """
+     return jnp.asarray(log_weight_x0) + log_density(x0, x1)
+
+
+ def log_weights(x0_all, x1, log_weight_x0_all, log_density) -> Array:
+     """Compute log smoothing weights over a collection of x0 given a single x1.
+
+     Args:
+         x0_all: Collection of previous states.
+         x1: The current state.
+         log_weight_x0_all: Collection of log weights of the previous states.
+         log_density: The log density function of x1 given x0.
+
+     Returns:
+         Log-normalized smoothing weights for each sample x0 given a single sample x1.
+     """
+     backward_log_weights_all = vmap(
+         lambda x0, log_weight_x0: log_weights_single(x0, x1, log_weight_x0, log_density)
+     )(x0_all, log_weight_x0_all)
+
+     # Log-normalize
+     backward_log_weights_all = backward_log_weights_all - logsumexp(
+         backward_log_weights_all, axis=0
+     )
+     return backward_log_weights_all
+
+
+ def simulate_single(
+     key, x0_all, x1, log_weight_x0_all, log_density
+ ) -> tuple[ArrayTree, Array]:
+     """Sample x0 from a collection given a single x1.
+
+     Args:
+         key: A JAX random key.
+         x0_all: Collection of previous states.
+         x1: The current state.
+         log_weight_x0_all: Collection of log weights of the previous states.
+         log_density: The log density function of x1 given x0.
+
+     Returns:
+         A single sample x0 from the smoothing trajectory along with its index.
+     """
+     backward_log_weights_all = log_weights(x0_all, x1, log_weight_x0_all, log_density)
+     sampled_index = random.categorical(key, backward_log_weights_all)
+     return jax.tree.map(lambda z: z[sampled_index], x0_all), sampled_index
+
+
+ def simulate(
+     key: KeyArray,
+     x0_all: ArrayTreeLike,
+     x1_all: ArrayTreeLike,
+     log_weight_x0_all: ArrayLike,
+     log_density: LogConditionalDensity,
+     x1_ancestor_indices: ArrayLike,
+ ) -> tuple[ArrayTree, Array]:
+     """Implements the exact backward sampling algorithm for smoothing in SMC.
+
+     Some arguments are only included for protocol compatibility and are not used
+     in this implementation.
+
+     Args:
+         key: JAX PRNG key.
+         x0_all: A collection of previous states $x_0$.
+         x1_all: A collection of current states $x_1$.
+         log_weight_x0_all: The log weights of $x_0$.
+         log_density: The log density function of $x_1$ given $x_0$.
+         x1_ancestor_indices: The ancestor indices of $x_1$. Not used.
+
+     Returns:
+         A collection of samples $x_0$ and their sampled indices.
+     """
+     log_weight_x0_all = jnp.asarray(log_weight_x0_all)
+     n_smoother_particles = jax.tree.leaves(x1_all)[0].shape[0]
+     keys = random.split(key, n_smoother_particles)
+
+     return vmap(
+         lambda k, x1: simulate_single(k, x0_all, x1, log_weight_x0_all, log_density)
+     )(keys, x1_all)
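A minimal sketch of exact backward sampling with a toy Gaussian random-walk transition (the `log_density` below is illustrative, not part of the package):

    import jax.numpy as jnp
    from jax import random
    from cuthbertlib.smc.smoothing import exact_sampling

    def log_density(x0, x1):
        # toy transition x1 ~ N(x0, 1), up to an additive constant
        return -0.5 * (x1 - x0) ** 2

    k0, k1, k2 = random.split(random.PRNGKey(0), 3)
    x0_all = random.normal(k0, (100,))   # filtering particles at time t
    x1_all = random.normal(k1, (100,))   # smoothing particles at time t + 1
    log_w0 = jnp.zeros(100)              # uniform filtering log-weights
    x0_samples, idx = exact_sampling.simulate(
        k2, x0_all, x1_all, log_w0, log_density, jnp.arange(100)
    )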
cuthbertlib/smc/smoothing/mcmc.py ADDED
@@ -0,0 +1,76 @@
+ """Implements MCMC backward smoothing in SMC."""
+
+ import jax
+ from jax import numpy as jnp
+ from jax import random
+
+ from cuthbertlib.resampling import multinomial
+ from cuthbertlib.smc.smoothing.tracing import simulate as ancestor_tracing_simulate
+ from cuthbertlib.types import (
+     Array,
+     ArrayLike,
+     ArrayTree,
+     ArrayTreeLike,
+     KeyArray,
+     LogConditionalDensity,
+ )
+
+
+ def simulate(
+     key: KeyArray,
+     x0_all: ArrayTreeLike,
+     x1_all: ArrayTreeLike,
+     log_weight_x0_all: ArrayLike,
+     log_density: LogConditionalDensity,
+     x1_ancestor_indices: ArrayLike,
+     n_steps: int,
+ ) -> tuple[ArrayTree, Array]:
+     """Implements the independent Metropolis-Hastings (IMH) algorithm for smoothing in SMC.
+
+     Args:
+         key: JAX PRNG key.
+         x0_all: A collection of previous states $x_0$.
+         x1_all: A collection of current states $x_1$.
+         log_weight_x0_all: The log weights of $x_0$.
+         log_density: The log density function of $x_1$ given $x_0$.
+         x1_ancestor_indices: The ancestor indices of $x_1$.
+         n_steps: Number of MCMC steps to perform.
+
+     Returns:
+         A collection of samples $x_0$ and their sampled indices.
+
+     References:
+         https://arxiv.org/abs/2207.00976
+     """
+     key, subkey = random.split(key)
+     x0_init, x1_ancestor_indices = ancestor_tracing_simulate(
+         subkey, x0_all, x1_all, log_weight_x0_all, log_density, x1_ancestor_indices
+     )
+     n_samples = x1_ancestor_indices.shape[0]
+
+     keys = random.split(key, (n_steps * 2)).reshape((n_steps, 2))
+
+     def body(carry, keys_t):
+         # IMH proposal
+         idx, x0_res, idx_log_p = carry
+         key_prop, key_acc = keys_t
+
+         prop_idx = multinomial.resampling(key_prop, log_weight_x0_all, n_samples)
+         x0_prop = jax.tree.map(lambda z: z[prop_idx], x0_all)
+         prop_log_p = jax.vmap(log_density)(x0_prop, x1_all)
+
+         log_alpha = prop_log_p - idx_log_p
+
+         lu = jnp.log(random.uniform(key_acc, (n_samples,)))
+         acc = lu < log_alpha
+
+         idx: Array = jnp.where(acc, prop_idx, idx)
+         x0_res: ArrayTreeLike = jax.tree.map(lambda z: z[idx], x0_all)
+         idx_log_p: Array = jnp.where(acc, prop_log_p, idx_log_p)
+         return (idx, x0_res, idx_log_p), None
+
+     x0_init = jax.tree.map(lambda z: z[x1_ancestor_indices], x0_all)
+     init_log_p = jax.vmap(log_density)(x0_init, x1_all)
+     init = (x1_ancestor_indices, x0_init, init_log_p)
+     (out_index, out_samples, _), _ = jax.lax.scan(body, init, keys)
+     return out_samples, out_index
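For reference, the acceptance test in `body` is the standard IMH ratio for the target $\pi(i) \propto w_i\, f(x_1 \mid x_0^i)$ with independent proposal $q(i) \propto w_i$: the weight factors cancel, leaving $\alpha = f(x_1 \mid x_0^{\mathrm{prop}}) / f(x_1 \mid x_0^{\mathrm{cur}})$, which is exactly `log_alpha = prop_log_p - idx_log_p` above. (Note: the original `init_log_p` call passed `(x1_all, x0_init)`, which is inconsistent with the `(x0, x1)` argument order used everywhere else for `log_density`; it is fixed above.)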
cuthbertlib/smc/smoothing/protocols.py ADDED
@@ -0,0 +1,44 @@
+ """Shared protocols for backward smoothing functions in SMC."""
+
+ from typing import Protocol, runtime_checkable
+
+ from cuthbertlib.types import (
+     Array,
+     ArrayLike,
+     ArrayTree,
+     ArrayTreeLike,
+     KeyArray,
+     LogConditionalDensity,
+ )
+
+
+ @runtime_checkable
+ class BackwardSampling(Protocol):
+     """Protocol for backward sampling functions."""
+
+     def __call__(
+         self,
+         key: KeyArray,
+         x0_all: ArrayTreeLike,
+         x1_all: ArrayTreeLike,
+         log_weight_x0_all: ArrayLike,
+         log_density: LogConditionalDensity,
+         x1_ancestor_indices: ArrayLike,
+     ) -> tuple[ArrayTree, Array]:
+         """Performs a backward sampling step.
+
+         Samples a collection of $x_0$ that combine with the provided $x_1$ to
+         give a collection of pairs $(x_0, x_1)$ from the smoothing distribution.
+
+         Args:
+             key: JAX PRNG key.
+             x0_all: A collection of previous states $x_0$.
+             x1_all: A collection of current states $x_1$.
+             log_weight_x0_all: The log weights of $x_0$.
+             log_density: The log density function of $x_1$ given $x_0$.
+             x1_ancestor_indices: The ancestor indices of $x_1$.
+
+         Returns:
+             A collection of samples $x_0$ and their sampled indices.
+         """
+         ...
cuthbertlib/smc/smoothing/tracing.py ADDED
@@ -0,0 +1,45 @@
+ """Implements the ancestor/genealogy tracing algorithm for smoothing in SMC."""
+
+ import jax
+ from jax import numpy as jnp
+
+ from cuthbertlib.types import (
+     Array,
+     ArrayLike,
+     ArrayTree,
+     ArrayTreeLike,
+     KeyArray,
+     LogConditionalDensity,
+ )
+
+
+ def simulate(
+     key: KeyArray,
+     x0_all: ArrayTreeLike,
+     x1_all: ArrayTreeLike,
+     log_weight_x0_all: ArrayLike,
+     log_density: LogConditionalDensity,
+     x1_ancestor_indices: ArrayLike,
+ ) -> tuple[ArrayTree, Array]:
+     """Implements the ancestor/genealogy tracing algorithm for smoothing in SMC.
+
+     Some arguments are only included for protocol compatibility and are not used
+     in this implementation.
+
+     Args:
+         key: JAX PRNG key. Not used.
+         x0_all: A collection of previous states $x_0$.
+         x1_all: A collection of current states $x_1$. Not used.
+         log_weight_x0_all: The log weights of $x_0$. Not used.
+         log_density: The log density function of $x_1$ given $x_0$. Not used.
+         x1_ancestor_indices: The ancestor indices of $x_1$.
+
+     Returns:
+         A collection of samples $x_0$ and their sampled indices.
+
+     References:
+         https://arxiv.org/abs/2207.00976
+     """
+     x1_ancestor_indices = jnp.asarray(x1_ancestor_indices)
+     x0 = jax.tree.map(lambda z: z[x1_ancestor_indices], x0_all)
+     return x0, x1_ancestor_indices
cuthbertlib/stats/__init__.py File without changes
cuthbertlib/stats/multivariate_normal.py ADDED
@@ -0,0 +1,102 @@
+ """Multivariate normal distribution functions with chol_cov input."""
+
+ from functools import partial
+
+ import numpy as np
+ from jax import lax
+ from jax import numpy as jnp
+ from jax._src.numpy.util import promote_dtypes_inexact
+ from jax._src.typing import Array, ArrayLike
+
+ from cuthbertlib.linalg import collect_nans_chol
+
+
+ def logpdf(
+     x: ArrayLike, mean: ArrayLike, chol_cov: ArrayLike, nan_support: bool = True
+ ) -> Array:
+     """Multivariate normal log probability density function with chol_cov input.
+
+     Here `chol_cov` is the (generalized) Cholesky factor of the covariance matrix.
+     Modified version of `jax.scipy.stats.multivariate_normal.logpdf`, which takes
+     the full covariance matrix as input; this version takes its Cholesky factor.
+
+     Args:
+         x: Value at which to evaluate the log PDF.
+         mean: Mean of the distribution.
+         chol_cov: Generalized Cholesky factor of the covariance matrix of the distribution.
+         nan_support: If `True`, ignores NaNs in `x` by projecting the distribution onto the
+             lower-dimensional subspace spanned by the non-NaN entries of `x`. Note that
+             `nan_support=True` uses the [tria][cuthbertlib.linalg.tria] operation (QR
+             decomposition), and therefore increases the internal complexity of the function
+             from $O(n^2)$ to $O(n^3)$, where $n$ is the dimension of `x`.
+
+     Returns:
+         Array of logpdf values.
+     """
+     x, mean, chol_cov = promote_dtypes_inexact(x, mean, chol_cov)
+
+     # If nan_support is True, collect the NaN entries at the top of the covariance
+     # matrix; this uses a QR decomposition, so it is more expensive.
+     if nan_support:
+         flag = jnp.isnan(x)
+         flag, chol_cov, x, mean = collect_nans_chol(flag, chol_cov, x, mean)
+         mean = jnp.asarray(mean)
+         x = jnp.asarray(x)
+         chol_cov = jnp.asarray(chol_cov)
+
+     if not mean.shape and not np.shape(x):
+         # Both mean and x are scalars
+         return -1 / 2 * jnp.square(x - mean) / chol_cov**2 - 1 / 2 * (
+             jnp.log(2 * np.pi) + 2 * jnp.log(chol_cov)
+         )
+     else:
+         n = mean.shape[-1] if mean.shape else x.shape[-1]
+         if not np.shape(chol_cov):
+             y = x - mean
+             return -1 / 2 * jnp.einsum("...i,...i->...", y, y) / chol_cov**2 - n / 2 * (
+                 jnp.log(2 * np.pi) + 2 * jnp.log(chol_cov)
+             )
+         elif chol_cov.ndim == 1:
+             y = (x - mean) / chol_cov
+             return (
+                 -1 / 2 * jnp.einsum("...i,...i->...", y, y)
+                 - n / 2 * jnp.log(2 * np.pi)
+                 - jnp.log(jnp.abs(chol_cov)).sum(-1)
+             )
+         else:
+             if chol_cov.ndim < 2 or chol_cov.shape[-2:] != (n, n):
+                 raise ValueError("multivariate_normal.logpdf got incompatible shapes")
+             y = jnp.vectorize(
+                 partial(lax.linalg.triangular_solve, lower=True, transpose_a=True),
+                 signature="(n,n),(n)->(n)",
+             )(chol_cov, x - mean)
+             return (
+                 -1 / 2 * jnp.einsum("...i,...i->...", y, y)
+                 - n / 2 * jnp.log(2 * np.pi)
+                 - jnp.log(jnp.abs(chol_cov.diagonal(axis1=-1, axis2=-2))).sum(-1)
+             )
+
+
+ def pdf(
+     x: ArrayLike, mean: ArrayLike, chol_cov: ArrayLike, nan_support: bool = True
+ ) -> Array:
+     """Multivariate normal probability density function with chol_cov input.
+
+     Here `chol_cov` is the (generalized) Cholesky factor of the covariance matrix.
+     Modified version of `jax.scipy.stats.multivariate_normal.pdf`, which takes
+     the full covariance matrix as input; this version takes its Cholesky factor.
+
+     Args:
+         x: Value at which to evaluate the PDF.
+         mean: Mean of the distribution.
+         chol_cov: Generalized Cholesky factor of the covariance matrix of the distribution.
+         nan_support: If `True`, ignores NaNs in `x` by projecting the distribution onto the
+             lower-dimensional subspace spanned by the non-NaN entries of `x`. Note that
+             `nan_support=True` uses the [tria][cuthbertlib.linalg.tria] operation (QR
+             decomposition), and therefore increases the internal complexity of the function
+             from $O(n^2)$ to $O(n^3)$, where $n$ is the dimension of `x`.
+
+     Returns:
+         Array of pdf values.
+     """
+     return lax.exp(logpdf(x, mean, chol_cov, nan_support))
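A usage sketch of the NaN-aware `logpdf` (illustrative values; module path as listed in this wheel):

    import jax.numpy as jnp
    from cuthbertlib.stats import multivariate_normal

    mean = jnp.zeros(3)
    chol_cov = jnp.linalg.cholesky(2.0 * jnp.eye(3))
    x = jnp.array([0.5, jnp.nan, -0.3])  # the NaN entry is projected out
    lp = multivariate_normal.logpdf(x, mean, chol_cov)  # nan_support=True by default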
cuthbert-0.0.1.dist-info/RECORD DELETED
@@ -1,12 +0,0 @@
- cuthbert/__init__.py,sha256=60_FB1nfduZLPphbjEc7WRpebCnVYiwMyKuD-7yZBvw,281
- cuthbert/filtering.py,sha256=HaUPJWhBO8P5IWlRYhLwCeOEtlYenAvWjYsQaF_cPX0,2700
- cuthbert/inference.py,sha256=u02wVKGu7mIdqn2XhcSZ93xXyPIQkxzSvxJXMH7Zo5k,7600
- cuthbert/smoothing.py,sha256=qvNAWYTGGaEUEBp7HAtYLtF1UOQK9E7u3V_l4XgxeaI,4960
- cuthbert/utils.py,sha256=0JQgRiyVs4SXZ0ullh4OmAMtyoOtuCzvYuDmpu9jXOE,1023
- cuthbert-0.0.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- cuthbertlib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- cuthbertlib/types.py,sha256=sgHC0J7W_0wOzMm5dlLZXwJB3W1xvzq2J6cxl6U831w,1020
- cuthbert-0.0.1.dist-info/METADATA,sha256=Q-MikVBB27Lne_4q7HvrBVNgln6zY23RdY5PdnIeE4M,7040
- cuthbert-0.0.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
- cuthbert-0.0.1.dist-info/top_level.txt,sha256=R7-G6fUQZSMNMcM4-IcpHKtY9CZOQeejEpVW9q-Sarw,21
- cuthbert-0.0.1.dist-info/RECORD,,