cuthbert 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. cuthbert/discrete/__init__.py +2 -0
  2. cuthbert/discrete/filter.py +140 -0
  3. cuthbert/discrete/smoother.py +123 -0
  4. cuthbert/discrete/types.py +53 -0
  5. cuthbert/gaussian/__init__.py +0 -0
  6. cuthbert/gaussian/kalman.py +337 -0
  7. cuthbert/gaussian/moments/__init__.py +11 -0
  8. cuthbert/gaussian/moments/associative_filter.py +180 -0
  9. cuthbert/gaussian/moments/filter.py +95 -0
  10. cuthbert/gaussian/moments/non_associative_filter.py +161 -0
  11. cuthbert/gaussian/moments/smoother.py +118 -0
  12. cuthbert/gaussian/moments/types.py +51 -0
  13. cuthbert/gaussian/taylor/__init__.py +14 -0
  14. cuthbert/gaussian/taylor/associative_filter.py +222 -0
  15. cuthbert/gaussian/taylor/filter.py +129 -0
  16. cuthbert/gaussian/taylor/non_associative_filter.py +246 -0
  17. cuthbert/gaussian/taylor/smoother.py +158 -0
  18. cuthbert/gaussian/taylor/types.py +86 -0
  19. cuthbert/gaussian/types.py +57 -0
  20. cuthbert/gaussian/utils.py +41 -0
  21. cuthbert/smc/__init__.py +0 -0
  22. cuthbert/smc/backward_sampler.py +193 -0
  23. cuthbert/smc/marginal_particle_filter.py +237 -0
  24. cuthbert/smc/particle_filter.py +234 -0
  25. cuthbert/smc/types.py +67 -0
  26. {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/METADATA +2 -2
  27. cuthbert-0.0.3.dist-info/RECORD +76 -0
  28. {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/licenses/LICENSE +1 -1
  29. cuthbertlib/discrete/__init__.py +0 -0
  30. cuthbertlib/discrete/filtering.py +49 -0
  31. cuthbertlib/discrete/smoothing.py +35 -0
  32. cuthbertlib/kalman/__init__.py +4 -0
  33. cuthbertlib/kalman/filtering.py +213 -0
  34. cuthbertlib/kalman/generate.py +85 -0
  35. cuthbertlib/kalman/sampling.py +68 -0
  36. cuthbertlib/kalman/smoothing.py +121 -0
  37. cuthbertlib/linalg/__init__.py +7 -0
  38. cuthbertlib/linalg/collect_nans_chol.py +90 -0
  39. cuthbertlib/linalg/marginal_sqrt_cov.py +34 -0
  40. cuthbertlib/linalg/symmetric_inv_sqrt.py +126 -0
  41. cuthbertlib/linalg/tria.py +21 -0
  42. cuthbertlib/linearize/__init__.py +7 -0
  43. cuthbertlib/linearize/log_density.py +175 -0
  44. cuthbertlib/linearize/moments.py +94 -0
  45. cuthbertlib/linearize/taylor.py +83 -0
  46. cuthbertlib/quadrature/__init__.py +4 -0
  47. cuthbertlib/quadrature/common.py +102 -0
  48. cuthbertlib/quadrature/cubature.py +73 -0
  49. cuthbertlib/quadrature/gauss_hermite.py +62 -0
  50. cuthbertlib/quadrature/linearize.py +143 -0
  51. cuthbertlib/quadrature/unscented.py +79 -0
  52. cuthbertlib/quadrature/utils.py +109 -0
  53. cuthbertlib/resampling/__init__.py +3 -0
  54. cuthbertlib/resampling/killing.py +79 -0
  55. cuthbertlib/resampling/multinomial.py +53 -0
  56. cuthbertlib/resampling/protocols.py +92 -0
  57. cuthbertlib/resampling/systematic.py +78 -0
  58. cuthbertlib/resampling/utils.py +82 -0
  59. cuthbertlib/smc/__init__.py +0 -0
  60. cuthbertlib/smc/ess.py +24 -0
  61. cuthbertlib/smc/smoothing/__init__.py +0 -0
  62. cuthbertlib/smc/smoothing/exact_sampling.py +111 -0
  63. cuthbertlib/smc/smoothing/mcmc.py +76 -0
  64. cuthbertlib/smc/smoothing/protocols.py +44 -0
  65. cuthbertlib/smc/smoothing/tracing.py +45 -0
  66. cuthbertlib/stats/__init__.py +0 -0
  67. cuthbertlib/stats/multivariate_normal.py +102 -0
  68. cuthbert-0.0.1.dist-info/RECORD +0 -12
  69. {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/WHEEL +0 -0
  70. {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,237 @@
1
+ """Implements the marginal particle filter.
2
+
3
+ See [Klaas et al. (2005)](https://www.cs.ubc.ca/~arnaud/klass_defreitas_doucet_marginalparticlefilterUAI2005.pdf)
4
+ """
5
+
6
+ from functools import partial
7
+ from typing import NamedTuple
8
+
9
+ import jax
10
+ import jax.numpy as jnp
11
+ from jax import Array, random, tree
12
+
13
+ from cuthbert.inference import Filter
14
+ from cuthbert.smc.types import InitSample, LogPotential, PropagateSample
15
+ from cuthbert.utils import dummy_tree_like
16
+ from cuthbertlib.resampling import Resampling
17
+ from cuthbertlib.smc.ess import log_ess
18
+ from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray, ScalarArray
19
+
20
+
21
class MarginalParticleFilterState(NamedTuple):
    """Marginal particle filter state."""

    # No ancestor indices are stored: in the marginal particle filter the new
    # weights are computed against *all* previous particles (N^2 reweighting),
    # so per-particle ancestry is not meaningful.
    key: KeyArray  # JAX PRNG key threaded through the filter steps
    particles: ArrayTree  # particle values; leading axis indexes particles
    log_weights: Array  # unnormalized log-weights, shape (n_particles,)
    model_inputs: ArrayTree  # model inputs for this time step
    log_normalizing_constant: ScalarArray  # running log-normalizing-constant estimate
30
+
31
+
32
def build_filter(
    init_sample: InitSample,
    propagate_sample: PropagateSample,
    log_potential: LogPotential,
    n_filter_particles: int,
    resampling_fn: Resampling,
    ess_threshold: float,
) -> Filter:
    r"""Construct a marginal particle filter.

    Args:
        init_sample: Sampler for the initial distribution $M_0(x_0)$.
        propagate_sample: Sampler for the Markov kernel $M_t(x_t \mid x_{t-1})$.
        log_potential: Log potential $\log G_t(x_{t-1}, x_t)$.
        n_filter_particles: Number of particles maintained by the filter.
        resampling_fn: Resampling scheme (e.g., systematic, multinomial).
        ess_threshold: Fraction of the particle count below which resampling
            triggers, i.e. resample when
            ESS < ess_threshold * n_filter_particles.

    Returns:
        Filter object for the marginal particle filter.
    """
    # Bind the model-specific pieces into the three generic filter callbacks.
    bound_init_prepare = partial(
        init_prepare,
        init_sample=init_sample,
        log_potential=log_potential,
        n_filter_particles=n_filter_particles,
    )
    bound_filter_prepare = partial(
        filter_prepare,
        init_sample=init_sample,
        n_filter_particles=n_filter_particles,
    )
    bound_filter_combine = partial(
        filter_combine,
        propagate_sample=propagate_sample,
        log_potential=log_potential,
        resampling_fn=resampling_fn,
        ess_threshold=ess_threshold,
    )
    # Particle filtering is inherently sequential, hence non-associative.
    return Filter(
        init_prepare=bound_init_prepare,
        filter_prepare=bound_filter_prepare,
        filter_combine=bound_filter_combine,
        associative=False,
    )
76
+
77
+
78
def init_prepare(
    model_inputs: ArrayTreeLike,
    init_sample: InitSample,
    log_potential: LogPotential,
    n_filter_particles: int,
    key: KeyArray | None = None,
) -> MarginalParticleFilterState:
    """Prepare the initial state for the marginal particle filter.

    Args:
        model_inputs: Model inputs.
        init_sample: Function to sample from the initial distribution M_0(x_0).
        log_potential: Function to compute the log potential log G_t(x_{t-1}, x_t).
            x_{t-1} is None since there is no previous state at t=0.
        n_filter_particles: Number of particles to sample.
        key: JAX random key.

    Returns:
        Initial state for the filter.

    Raises:
        ValueError: If `key` is None.
    """
    # Fail fast: validate before doing any work on the inputs.
    if key is None:
        raise ValueError("A JAX PRNG key must be provided.")
    model_inputs = tree.map(jnp.asarray, model_inputs)

    # Sample one particle per key from M_0.
    keys = random.split(key, n_filter_particles)
    particles = jax.vmap(init_sample, (0, None))(keys, model_inputs)

    # Weight: at t=0 there is no previous state, so the potential receives None.
    log_weights = jax.vmap(log_potential, (None, 0, None))(
        None, particles, model_inputs
    )

    # log Z_0 estimate: log of the mean of the unnormalized weights.
    log_normalizing_constant = jax.nn.logsumexp(log_weights) - jnp.log(
        n_filter_particles
    )

    return MarginalParticleFilterState(
        key=key,
        particles=particles,
        log_weights=log_weights,
        model_inputs=model_inputs,
        log_normalizing_constant=log_normalizing_constant,
    )
126
+
127
+
128
def filter_prepare(
    model_inputs: ArrayTreeLike,
    init_sample: InitSample,
    n_filter_particles: int,
    key: KeyArray | None = None,
) -> MarginalParticleFilterState:
    """Prepare a state for a marginal particle filter step.

    Args:
        model_inputs: Model inputs.
        init_sample: Function to sample from the initial distribution M_0(x_0).
            Only used to infer particle shapes.
        n_filter_particles: Number of particles for the filter.
        key: JAX random key.

    Returns:
        Prepared state for the filter.

    Raises:
        ValueError: If `key` is None.
    """
    # Fail fast: validate before doing any work on the inputs.
    if key is None:
        raise ValueError("A JAX PRNG key must be provided.")
    model_inputs = tree.map(jnp.asarray, model_inputs)

    # Infer per-particle shapes without actually running init_sample.
    dummy_particle = jax.eval_shape(init_sample, key, model_inputs)
    particles = tree.map(
        lambda x: jnp.empty((n_filter_particles,) + x.shape), dummy_particle
    )
    # Replace the placeholder arrays with the project's dummy representation.
    particles = dummy_tree_like(particles)

    return MarginalParticleFilterState(
        key=key,
        particles=particles,
        log_weights=jnp.zeros(n_filter_particles),
        model_inputs=model_inputs,
        log_normalizing_constant=jnp.array(0.0),
    )
164
+
165
+
166
def filter_combine(
    state_1: MarginalParticleFilterState,
    state_2: MarginalParticleFilterState,
    propagate_sample: PropagateSample,
    log_potential: LogPotential,
    resampling_fn: Resampling,
    ess_threshold: float,
) -> MarginalParticleFilterState:
    """Combine the previous filter state with the state prepared for this step.

    Performs one marginal particle filter update: conditional resampling,
    propagation through the state dynamics, and an N^2 reweighting of every
    new particle against every previous particle via the potential function.

    Args:
        state_1: Filter state from the previous time step.
        state_2: Filter state prepared for the current step.
        propagate_sample: Function to sample from the Markov kernel M_t(x_t | x_{t-1}).
        log_potential: Function to compute the log potential log G_t(x_{t-1}, x_t).
        resampling_fn: Resampling algorithm to use (e.g., systematic, multinomial).
        ess_threshold: Fraction of particle count specifying when to resample.
            Resampling is triggered when the effective sample size (ESS) < ess_threshold * N.

    Returns:
        The filtered state at the current time step.
    """
    n_particles = state_1.log_weights.shape[0]
    all_keys = random.split(state_1.key, n_particles + 1)
    resample_key, propagate_keys = all_keys[0], all_keys[1:]

    # Normalized previous log-weights, used for the N^2 reweighting below.
    normalized_prev_lw = state_1.log_weights - jax.nn.logsumexp(
        state_1.log_weights
    )

    # Conditional resampling: only resample when the ESS drops below threshold.
    def _resample():
        indices = resampling_fn(resample_key, state_1.log_weights, n_particles)
        return indices, jnp.zeros(n_particles)

    def _keep():
        return jnp.arange(n_particles), state_1.log_weights

    should_resample = log_ess(state_1.log_weights) < jnp.log(
        ess_threshold * n_particles
    )
    ancestor_indices, log_weights = jax.lax.cond(
        should_resample, _resample, _keep
    )
    ancestors = tree.map(lambda leaf: leaf[ancestor_indices], state_1.particles)

    # Propagate each (possibly resampled) ancestor through the Markov kernel.
    next_particles = jax.vmap(propagate_sample, (0, 0, None))(
        propagate_keys, ancestors, state_2.model_inputs
    )

    # N^2 reweight: pairwise_log_potentials[i, j] = log G(x_prev_j, x_next_i).
    pairwise_log_potentials = jax.vmap(
        jax.vmap(log_potential, (0, None, None), out_axes=0),
        (None, 0, None),
        out_axes=0,
    )(state_1.particles, next_particles, state_2.model_inputs)
    next_log_weights = jax.nn.logsumexp(
        pairwise_log_potentials + normalized_prev_lw[None, :], axis=1
    )

    # Accumulate the running log normalizing constant estimate.
    increment = jax.nn.logsumexp(next_log_weights) - jax.nn.logsumexp(
        log_weights
    )
    log_normalizing_constant = state_1.log_normalizing_constant + increment

    return MarginalParticleFilterState(
        key=state_2.key,
        particles=next_particles,
        log_weights=next_log_weights,
        model_inputs=state_2.model_inputs,
        log_normalizing_constant=log_normalizing_constant,
    )
@@ -0,0 +1,234 @@
1
+ """Implements the generic particle filter.
2
+
3
+ See Algorithm 10.1, [Chopin and Papaspiliopoulos, 2020](https://doi.org/10.1007/978-3-030-47845-2).
4
+ """
5
+
6
+ from functools import partial
7
+ from typing import NamedTuple
8
+
9
+ import jax
10
+ import jax.numpy as jnp
11
+ from jax import Array, random, tree
12
+
13
+ from cuthbert.inference import Filter
14
+ from cuthbert.smc.types import InitSample, LogPotential, PropagateSample
15
+ from cuthbert.utils import dummy_tree_like
16
+ from cuthbertlib.resampling import Resampling
17
+ from cuthbertlib.smc.ess import log_ess
18
+ from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray, ScalarArray
19
+
20
+
21
class ParticleFilterState(NamedTuple):
    """Particle filter state."""

    key: KeyArray  # JAX PRNG key threaded through the filter steps
    particles: ArrayTree  # particle values; leading axis indexes particles
    log_weights: Array  # unnormalized log-weights; last axis indexes particles
    ancestor_indices: Array  # index of each particle's ancestor at the previous step
    model_inputs: ArrayTree  # model inputs for this time step
    log_normalizing_constant: ScalarArray  # running log-normalizing-constant estimate

    @property
    def n_particles(self) -> int:
        """Number of particles in the filter state."""
        return self.log_weights.shape[-1]
35
+
36
+
37
def build_filter(
    init_sample: InitSample,
    propagate_sample: PropagateSample,
    log_potential: LogPotential,
    n_filter_particles: int,
    resampling_fn: Resampling,
    ess_threshold: float,
) -> Filter:
    r"""Construct a bootstrap-style particle filter.

    Args:
        init_sample: Sampler for the initial distribution $M_0(x_0)$.
        propagate_sample: Sampler for the Markov kernel $M_t(x_t \mid x_{t-1})$.
        log_potential: Log potential $\log G_t(x_{t-1}, x_t)$.
        n_filter_particles: Number of particles maintained by the filter.
        resampling_fn: Resampling scheme (e.g., systematic, multinomial).
        ess_threshold: Fraction of the particle count below which resampling
            triggers, i.e. resample when
            ESS < ess_threshold * n_filter_particles.

    Returns:
        Filter object for the particle filter.
    """
    # Bind the model-specific pieces into the three generic filter callbacks.
    bound_init_prepare = partial(
        init_prepare,
        init_sample=init_sample,
        log_potential=log_potential,
        n_filter_particles=n_filter_particles,
    )
    bound_filter_prepare = partial(
        filter_prepare,
        init_sample=init_sample,
        n_filter_particles=n_filter_particles,
    )
    bound_filter_combine = partial(
        filter_combine,
        propagate_sample=propagate_sample,
        log_potential=log_potential,
        resampling_fn=resampling_fn,
        ess_threshold=ess_threshold,
    )
    # Particle filtering is inherently sequential, hence non-associative.
    return Filter(
        init_prepare=bound_init_prepare,
        filter_prepare=bound_filter_prepare,
        filter_combine=bound_filter_combine,
        associative=False,
    )
81
+
82
+
83
def init_prepare(
    model_inputs: ArrayTreeLike,
    init_sample: InitSample,
    log_potential: LogPotential,
    n_filter_particles: int,
    key: KeyArray | None = None,
) -> ParticleFilterState:
    """Prepare the initial state for the particle filter.

    Args:
        model_inputs: Model inputs.
        init_sample: Function to sample from the initial distribution M_0(x_0).
        log_potential: Function to compute the log potential log G_t(x_{t-1}, x_t).
            x_{t-1} is None since there is no previous state at t=0.
        n_filter_particles: Number of particles to sample.
        key: JAX random key.

    Returns:
        Initial state for the particle filter.

    Raises:
        ValueError: If `key` is None.
    """
    # Fail fast: validate before doing any work on the inputs.
    if key is None:
        raise ValueError("A JAX PRNG key must be provided.")
    model_inputs = tree.map(jnp.asarray, model_inputs)

    # Sample one particle per key from M_0.
    keys = random.split(key, n_filter_particles)
    particles = jax.vmap(init_sample, (0, None))(keys, model_inputs)

    # Weight: at t=0 there is no previous state, so the potential receives None.
    log_weights = jax.vmap(log_potential, (None, 0, None))(
        None, particles, model_inputs
    )

    # log Z_0 estimate: log of the mean of the unnormalized weights.
    log_normalizing_constant = jax.nn.logsumexp(log_weights) - jnp.log(
        n_filter_particles
    )

    return ParticleFilterState(
        key=key,
        particles=particles,
        log_weights=log_weights,
        ancestor_indices=jnp.arange(n_filter_particles),
        model_inputs=model_inputs,
        log_normalizing_constant=log_normalizing_constant,
    )
132
+
133
+
134
def filter_prepare(
    model_inputs: ArrayTreeLike,
    init_sample: InitSample,
    n_filter_particles: int,
    key: KeyArray | None = None,
) -> ParticleFilterState:
    """Prepare a state for a particle filter step.

    Args:
        model_inputs: Model inputs.
        init_sample: Function to sample from the initial distribution M_0(x_0).
            Only used to infer particle shapes.
        n_filter_particles: Number of particles for the filter.
        key: JAX random key.

    Returns:
        Prepared state for the filter.

    Raises:
        ValueError: If `key` is None.
    """
    # Fail fast: validate before doing any work on the inputs.
    if key is None:
        raise ValueError("A JAX PRNG key must be provided.")
    model_inputs = tree.map(jnp.asarray, model_inputs)

    # Infer per-particle shapes without actually running init_sample.
    dummy_particle = jax.eval_shape(init_sample, key, model_inputs)
    particles = tree.map(
        lambda x: jnp.empty((n_filter_particles,) + x.shape), dummy_particle
    )
    # Replace the placeholder arrays with the project's dummy representation.
    particles = dummy_tree_like(particles)

    return ParticleFilterState(
        key=key,
        particles=particles,
        log_weights=jnp.zeros(n_filter_particles),
        ancestor_indices=jnp.arange(n_filter_particles),
        model_inputs=model_inputs,
        log_normalizing_constant=jnp.array(0.0),
    )
171
+
172
+
173
def filter_combine(
    state_1: ParticleFilterState,
    state_2: ParticleFilterState,
    propagate_sample: PropagateSample,
    log_potential: LogPotential,
    resampling_fn: Resampling,
    ess_threshold: float,
) -> ParticleFilterState:
    """Combine previous filter state with the state prepared for the current step.

    Implements the particle filter update: conditional resampling,
    propagation through state dynamics, and reweighting based on the potential function.

    Args:
        state_1: Filter state from the previous time step.
        state_2: Filter state prepared for the current step.
        propagate_sample: Function to sample from the Markov kernel M_t(x_t | x_{t-1}).
        log_potential: Function to compute the log potential log G_t(x_{t-1}, x_t).
        resampling_fn: Resampling algorithm to use (e.g., systematic, multinomial).
        ess_threshold: Fraction of particle count specifying when to resample.
            Resampling is triggered when the effective sample size (ESS) < ess_threshold * N.

    Returns:
        The filtered state at the current time step.
    """
    N = state_1.log_weights.shape[0]
    keys = random.split(state_1.key, N + 1)

    # Conditional resampling: resample (resetting the weights to uniform) only
    # when the ESS drops below the threshold; otherwise keep particles and
    # carry the weights forward unchanged.
    ancestor_indices, log_weights = jax.lax.cond(
        log_ess(state_1.log_weights) < jnp.log(ess_threshold * N),
        lambda: (resampling_fn(keys[0], state_1.log_weights, N), jnp.zeros(N)),
        lambda: (jnp.arange(N), state_1.log_weights),
    )
    ancestors = tree.map(lambda x: x[ancestor_indices], state_1.particles)

    # Propagate each (possibly resampled) ancestor through the Markov kernel.
    next_particles = jax.vmap(propagate_sample, (0, 0, None))(
        keys[1:], ancestors, state_2.model_inputs
    )

    # Reweight each particle against its own ancestor with the potential.
    log_potentials = jax.vmap(log_potential, (0, 0, None))(
        ancestors, next_particles, state_2.model_inputs
    )
    next_log_weights = log_potentials + log_weights

    # Accumulate the running log normalizing constant estimate.
    logsum_weights = jax.nn.logsumexp(next_log_weights)
    log_normalizing_constant_incr = logsum_weights - jax.nn.logsumexp(log_weights)
    log_normalizing_constant = (
        log_normalizing_constant_incr + state_1.log_normalizing_constant
    )

    # Construct with keyword arguments for consistency with init_prepare and
    # filter_prepare (the original used positional arguments here).
    return ParticleFilterState(
        key=state_2.key,
        particles=next_particles,
        log_weights=next_log_weights,
        ancestor_indices=ancestor_indices,
        model_inputs=state_2.model_inputs,
        log_normalizing_constant=log_normalizing_constant,
    )
cuthbert/smc/types.py ADDED
@@ -0,0 +1,67 @@
1
+ r"""Provides types for representing generic Feynman--Kac models.
2
+
3
+ $$
4
+ \mathbb{Q}_{t}(x_{0:t}) \propto M_0(x_0) \, G_0(x_0) \prod_{s=1}^{t} M_s(x_s \mid x_{s-1}) \, G_s(x_{s-1}, x_s).
5
+ $$
6
+ """
7
+
8
+ from typing import Protocol
9
+
10
+ from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray, ScalarArray
11
+
12
+
13
class InitSample(Protocol):
    """Protocol for sampling from the initial distribution $M_0(x_0)$."""

    def __call__(self, key: KeyArray, model_inputs: ArrayTreeLike) -> ArrayTree:
        """Samples from the initial distribution $M_0(x_0)$.

        Implementations take a single key and return a single sample; the
        filters vmap this callable over a batch of keys to draw many
        particles at once.

        Args:
            key: JAX PRNG key.
            model_inputs: Model inputs.

        Returns:
            A sample $x_0$.
        """
        ...
27
+
28
+
29
class PropagateSample(Protocol):
    r"""Protocol for sampling from the Markov kernel $M_t(x_t \mid x_{t-1})$."""

    def __call__(
        self, key: KeyArray, state: ArrayTreeLike, model_inputs: ArrayTreeLike
    ) -> ArrayTree:
        r"""Samples from the Markov kernel $M_t(x_t \mid x_{t-1})$.

        Implementations take a single (key, state) pair and return a single
        sample; the filters vmap this callable over batches of keys and
        ancestor states.

        Args:
            key: JAX PRNG key.
            state: State at the previous step $x_{t-1}$.
            model_inputs: Model inputs.

        Returns:
            A sample $x_t$.
        """
        ...
46
+
47
+
48
class LogPotential(Protocol):
    r"""Protocol for computing the log potential function $\log G_t(x_{t-1}, x_t)$."""

    def __call__(
        self,
        state_prev: ArrayTreeLike,
        state: ArrayTreeLike,
        model_inputs: ArrayTreeLike,
    ) -> ScalarArray:
        r"""Computes the log potential function $\log G_t(x_{t-1}, x_t)$.

        At the initial step ($t = 0$) the filters call this with
        `state_prev=None`, since there is no previous state; implementations
        must handle that case.

        Args:
            state_prev: State at the previous step $x_{t-1}$, or None at $t = 0$.
            state: State at the current step $x_{t}$.
            model_inputs: Model inputs.

        Returns:
            A scalar value $\log G_t(x_{t-1}, x_t)$.
        """
        ...
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cuthbert
3
- Version: 0.0.1
3
+ Version: 0.0.3
4
4
  Summary: State-space model inference with JAX
5
5
  Author-email: Sam Duffield <s@mduffield.com>, Sahel Iqbal <sahel13miqbal@proton.me>, Adrien Corenflos <adrien.corenflos.stats@gmail.com>
6
6
  License: Apache-2.0
@@ -33,7 +33,7 @@ Dynamic: license-file
33
33
 
34
34
  <!--intro-start-->
35
35
  <div align="center">
36
- <img src="docs/assets/cuthbert.png" alt="logo"></img>
36
+ <img src="https://raw.githubusercontent.com/state-space-models/cuthbert/main/docs/assets/cuthbert.png" alt="cuthbert logo"></img>
37
37
  </div>
38
38
 
39
39
  A JAX library for state-space model inference
@@ -0,0 +1,76 @@
1
+ cuthbert/__init__.py,sha256=60_FB1nfduZLPphbjEc7WRpebCnVYiwMyKuD-7yZBvw,281
2
+ cuthbert/filtering.py,sha256=HaUPJWhBO8P5IWlRYhLwCeOEtlYenAvWjYsQaF_cPX0,2700
3
+ cuthbert/inference.py,sha256=u02wVKGu7mIdqn2XhcSZ93xXyPIQkxzSvxJXMH7Zo5k,7600
4
+ cuthbert/smoothing.py,sha256=qvNAWYTGGaEUEBp7HAtYLtF1UOQK9E7u3V_l4XgxeaI,4960
5
+ cuthbert/utils.py,sha256=0JQgRiyVs4SXZ0ullh4OmAMtyoOtuCzvYuDmpu9jXOE,1023
6
+ cuthbert/discrete/__init__.py,sha256=AI-GYq2QRyg2hwrEJm-LJRV9aMZ7MY-QkE0wG41Qg_I,104
7
+ cuthbert/discrete/filter.py,sha256=wpefQ3Q6eN6PdXKon93E5tjJLxBxPjqE4ctrd-QBbHA,4674
8
+ cuthbert/discrete/smoother.py,sha256=sMS2NKt-507LYnUSzqd3GMCArcSO2-eDiINX_hvLNzY,4291
9
+ cuthbert/discrete/types.py,sha256=hdfxqBGRza45ni0_Lu2NNvqLga_3o0PQWQgJ08BcxhM,1412
10
+ cuthbert/gaussian/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ cuthbert/gaussian/kalman.py,sha256=WldTKTCzlAWX1btAD6kD63Qoz1RpTufOiq2riT6ND44,11567
12
+ cuthbert/gaussian/types.py,sha256=IjC8AWUpjLMgm1co2sc7knb3Otwz-pnUSJEw8p5ohGc,1725
13
+ cuthbert/gaussian/utils.py,sha256=AnRdw_Ju7rWXpbCaOB0CUkoVctJ8gScchgURv5mQ0BU,1301
14
+ cuthbert/gaussian/moments/__init__.py,sha256=cYUxe1Sksf0qqBYO-zwtKnuoZ_gOE_Hr7aGsrzeNkOQ,327
15
+ cuthbert/gaussian/moments/associative_filter.py,sha256=abFEs2KVbXrnIMUJMVKD32hRP8Q3A4NkQP_VmfSGeaI,6188
16
+ cuthbert/gaussian/moments/filter.py,sha256=8dG_Aqdvs_WC23JGpC3B5mTjG5uJMtd15LxJAjeyddY,3951
17
+ cuthbert/gaussian/moments/non_associative_filter.py,sha256=wV_IBeaA-EnWLnX5Wkp6-Kx9WDnpeVWtdUiqlXeGS1U,5942
18
+ cuthbert/gaussian/moments/smoother.py,sha256=L79R8WPwI30aHyCdZE1paYM9KNZtmQlRsvFhDyFsoxs,4262
19
+ cuthbert/gaussian/moments/types.py,sha256=UXUaeakgyO-WLcY8ECmFcf16RwOQZ6pz0B5qI621xV8,1727
20
+ cuthbert/gaussian/taylor/__init__.py,sha256=xOdFMzy59MfF1WuM7Oy4_H7Hn0TtekyoK-ECM5CVuEE,428
21
+ cuthbert/gaussian/taylor/associative_filter.py,sha256=daIY693BefoeV15T4hnOiJ2_e-eH2bGDrmA4pEwNB0A,8318
22
+ cuthbert/gaussian/taylor/filter.py,sha256=ljlHidVuFAlrVVSfmFM6sfdxQuYsmCVsSMucHCuFd3M,5303
23
+ cuthbert/gaussian/taylor/non_associative_filter.py,sha256=qhT31jWyCct_C0Ae3wNLm3BD75eBYX16xNDaA8uw-wg,9446
24
+ cuthbert/gaussian/taylor/smoother.py,sha256=hDAOTJc71oZZIH_ZRkVDXnbLj8riwNqncRy-N372wXA,5814
25
+ cuthbert/gaussian/taylor/types.py,sha256=h45seSNRUSgT4tpQVjqW_AJny2HCz6kJi9yM7zLEWOc,2733
26
+ cuthbert/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
+ cuthbert/smc/backward_sampler.py,sha256=rFsJvSvx3S2hRl2ojTwvwQqERtu_gaBEWzjQtO2GIXQ,6737
28
+ cuthbert/smc/marginal_particle_filter.py,sha256=91WaqsSmoo3Q5pPanYaLSnCOyKLMWSDH1aV3h1SkO2g,8125
29
+ cuthbert/smc/particle_filter.py,sha256=PebY0rt0C76lAZgoWkpIGzUkR_BxoRMCZtOFsRRVSok,7770
30
+ cuthbert/smc/types.py,sha256=jen13y7RWylzIAmrpPGpp980V7uukG06kkoEkcCGZJA,1841
31
+ cuthbert-0.0.3.dist-info/licenses/LICENSE,sha256=4bgVKT6ovc1ZpaqBTr3nDmQrN0rNe1Yx-WNW_q3Er7s,11390
32
+ cuthbertlib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
33
+ cuthbertlib/types.py,sha256=sgHC0J7W_0wOzMm5dlLZXwJB3W1xvzq2J6cxl6U831w,1020
34
+ cuthbertlib/discrete/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
35
+ cuthbertlib/discrete/filtering.py,sha256=_TMGwAeDXHivANpck3HUXFFm8SUalm72N7NHPzYTan8,1466
36
+ cuthbertlib/discrete/smoothing.py,sha256=FZJkSpUhAo7-iM3FiudKvR5XNvNEFXhfzdkf_kIoNQM,1214
37
+ cuthbertlib/kalman/__init__.py,sha256=Tf2dJBk9qavIe_0OL_TJ2zfCTC0hBJrPS0OMbnx_p_o,221
38
+ cuthbertlib/kalman/filtering.py,sha256=YQvxzsF5XsYz8JEJIoHiETtcNHmcMCM4djanNn-pe4w,6757
39
+ cuthbertlib/kalman/generate.py,sha256=AeF-NXIpxN-E6YmiKIaUg4EK5wzWLkTtdg3K9RCemg4,3148
40
+ cuthbertlib/kalman/sampling.py,sha256=Cx4EPCXwMJlPt_48aLsCowk3t70dfA0Fp_lDUQih4ok,1914
41
+ cuthbertlib/kalman/smoothing.py,sha256=iCd5hEVrwVZ3Q2aQtJeNe1IKbpcXD7E1fqHy1P6bX_U,3928
42
+ cuthbertlib/linalg/__init__.py,sha256=ldnznmTBCJeyFfiz1nLQ1hN0X9PwcjnQOAXjpmL6Ag8,284
43
+ cuthbertlib/linalg/collect_nans_chol.py,sha256=7pC7lh5Oux3IiNndIOSYLUdMTAjrVIwkeK52hGnbmJg,3137
44
+ cuthbertlib/linalg/marginal_sqrt_cov.py,sha256=Hq0ofXdMzBANIhfwqKXUSU0IL6LwGZIvU3Uq7EfaZ7E,1188
45
+ cuthbertlib/linalg/symmetric_inv_sqrt.py,sha256=8CWJzaS9VGl3tlJHdw7B7q-DiwbWA5aqAab4i5ZJKU8,5043
46
+ cuthbertlib/linalg/tria.py,sha256=twzyyachwNGDlHA6i6mBl504vVqOYjeC_9a8fa_X_X4,571
47
+ cuthbertlib/linearize/__init__.py,sha256=JzeMqWch_mXJ4a1WaAYuQqQg7Z4gCvTJpZas4RqXwSE,287
48
+ cuthbertlib/linearize/log_density.py,sha256=VUwHFCWsvZwU22aPWD_YrUb3aGOQFP_NhyUqDDYmpJo,6073
49
+ cuthbertlib/linearize/moments.py,sha256=DiU8osX4x9htR3G_cxqMN9EcfBJ0xxSoR0Rn-2bY5WA,3205
50
+ cuthbertlib/linearize/taylor.py,sha256=9LUzoo0wuU5mReO4DUggwyFbSJpMuZ7UMqpmvnJyyrA,3238
51
+ cuthbertlib/quadrature/__init__.py,sha256=79moFPswxvR9InZrL3RysITfjVnYDUpVa5E2CJcdvW4,275
52
+ cuthbertlib/quadrature/common.py,sha256=HpNFVXgP_L4vzqpNtRE6XCmiFrBAsIj-9KThAENUSbM,2743
53
+ cuthbertlib/quadrature/cubature.py,sha256=1yiGLaij6JDPtttFBbqgNuvG6ZUgdHSbWek8d8HGmyw,1877
54
+ cuthbertlib/quadrature/gauss_hermite.py,sha256=LuapZIklRb7CpZI5StJ9BD7oBevD0JU4cEEkqqM10dE,1740
55
+ cuthbertlib/quadrature/linearize.py,sha256=y_GcnER8Btd-mn2GN52Wowq893PMcSxcJtnDyxbKNoA,5403
56
+ cuthbertlib/quadrature/unscented.py,sha256=SMI51HCESHLKEW4GGrYtVCmejkiP8WUxWM_BHWVCeJw,2399
57
+ cuthbertlib/quadrature/utils.py,sha256=15s5na_MhRJOeaorvvetw1NnoYWD-9-xMdK4ZLIo9kg,3652
58
+ cuthbertlib/resampling/__init__.py,sha256=Po73bCYwl-yAfYMRkrxeFmXMXdhwQ8btZqJ4fiH5qPc,200
59
+ cuthbertlib/resampling/killing.py,sha256=sPw0r86wfR9ERjbGuXiIwYo9nIqX-Y83e7magv53J7I,2817
60
+ cuthbertlib/resampling/multinomial.py,sha256=pGSKkfPuW3Hh815u_w0x47stg033H6XkeJbHzqcmJCE,2002
61
+ cuthbertlib/resampling/protocols.py,sha256=i2EUyLOxxKonbx-9xU8ZjjpDimIu1O8zdreEUPftdzI,2457
62
+ cuthbertlib/resampling/systematic.py,sha256=9eRRQjgShZUs98ATAUz4BYvudWEk0z1P73sLLlVZMR0,2338
63
+ cuthbertlib/resampling/utils.py,sha256=hmf1PmcDi1kWOFBUS0GWoP8t2Q5T_PkFWgcNnjD6sXs,2689
64
+ cuthbertlib/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
65
+ cuthbertlib/smc/ess.py,sha256=Lkafpy2QVt7EuxZYjLQHGzff_qfsVdouA4MUpRhMuS8,645
66
+ cuthbertlib/smc/smoothing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
67
+ cuthbertlib/smc/smoothing/exact_sampling.py,sha256=lUqM64l4nGe34w7ex2l5AXg-Oa67Bkzm07ow4mf0ZyM,3590
68
+ cuthbertlib/smc/smoothing/mcmc.py,sha256=fsLOzRTYRkgJgSxtCfp8hznG48tRXuA04VCxVuLN9Bo,2490
69
+ cuthbertlib/smc/smoothing/protocols.py,sha256=461TVhQbYIqDWBl7QRjGA35YpbDk4f6-BwtgGfgT-tU,1302
70
+ cuthbertlib/smc/smoothing/tracing.py,sha256=E8GFM0cOQiFspOFaAaK2xtyA5oMD6pbO9JqEk-O8RPk,1343
71
+ cuthbertlib/stats/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
72
+ cuthbertlib/stats/multivariate_normal.py,sha256=H5wGxtuHKJldJHs_PicsCX94hczeX9Xek625RdXrd18,4314
73
+ cuthbert-0.0.3.dist-info/METADATA,sha256=9zXQmkEBc-fq8-2vymg9ThVhYKdOLswtnuS9e2TkN2s,7116
74
+ cuthbert-0.0.3.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
75
+ cuthbert-0.0.3.dist-info/top_level.txt,sha256=R7-G6fUQZSMNMcM4-IcpHKtY9CZOQeejEpVW9q-Sarw,21
76
+ cuthbert-0.0.3.dist-info/RECORD,,
@@ -186,7 +186,7 @@
186
186
  same "printed page" as the copyright notice for easier
187
187
  identification within third-party archives.
188
188
 
189
- Copyright [yyyy] [name of copyright owner]
189
+ Copyright 2026 state-space-models org https://github.com/state-space-models
190
190
 
191
191
  Licensed under the Apache License, Version 2.0 (the "License");
192
192
  you may not use this file except in compliance with the License.
File without changes