cuthbert 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cuthbert/discrete/__init__.py +2 -0
- cuthbert/discrete/filter.py +140 -0
- cuthbert/discrete/smoother.py +123 -0
- cuthbert/discrete/types.py +53 -0
- cuthbert/gaussian/__init__.py +0 -0
- cuthbert/gaussian/kalman.py +337 -0
- cuthbert/gaussian/moments/__init__.py +11 -0
- cuthbert/gaussian/moments/associative_filter.py +180 -0
- cuthbert/gaussian/moments/filter.py +95 -0
- cuthbert/gaussian/moments/non_associative_filter.py +161 -0
- cuthbert/gaussian/moments/smoother.py +118 -0
- cuthbert/gaussian/moments/types.py +51 -0
- cuthbert/gaussian/taylor/__init__.py +14 -0
- cuthbert/gaussian/taylor/associative_filter.py +222 -0
- cuthbert/gaussian/taylor/filter.py +129 -0
- cuthbert/gaussian/taylor/non_associative_filter.py +246 -0
- cuthbert/gaussian/taylor/smoother.py +158 -0
- cuthbert/gaussian/taylor/types.py +86 -0
- cuthbert/gaussian/types.py +57 -0
- cuthbert/gaussian/utils.py +41 -0
- cuthbert/smc/__init__.py +0 -0
- cuthbert/smc/backward_sampler.py +193 -0
- cuthbert/smc/marginal_particle_filter.py +237 -0
- cuthbert/smc/particle_filter.py +234 -0
- cuthbert/smc/types.py +67 -0
- {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/METADATA +2 -2
- cuthbert-0.0.3.dist-info/RECORD +76 -0
- {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/licenses/LICENSE +1 -1
- cuthbertlib/discrete/__init__.py +0 -0
- cuthbertlib/discrete/filtering.py +49 -0
- cuthbertlib/discrete/smoothing.py +35 -0
- cuthbertlib/kalman/__init__.py +4 -0
- cuthbertlib/kalman/filtering.py +213 -0
- cuthbertlib/kalman/generate.py +85 -0
- cuthbertlib/kalman/sampling.py +68 -0
- cuthbertlib/kalman/smoothing.py +121 -0
- cuthbertlib/linalg/__init__.py +7 -0
- cuthbertlib/linalg/collect_nans_chol.py +90 -0
- cuthbertlib/linalg/marginal_sqrt_cov.py +34 -0
- cuthbertlib/linalg/symmetric_inv_sqrt.py +126 -0
- cuthbertlib/linalg/tria.py +21 -0
- cuthbertlib/linearize/__init__.py +7 -0
- cuthbertlib/linearize/log_density.py +175 -0
- cuthbertlib/linearize/moments.py +94 -0
- cuthbertlib/linearize/taylor.py +83 -0
- cuthbertlib/quadrature/__init__.py +4 -0
- cuthbertlib/quadrature/common.py +102 -0
- cuthbertlib/quadrature/cubature.py +73 -0
- cuthbertlib/quadrature/gauss_hermite.py +62 -0
- cuthbertlib/quadrature/linearize.py +143 -0
- cuthbertlib/quadrature/unscented.py +79 -0
- cuthbertlib/quadrature/utils.py +109 -0
- cuthbertlib/resampling/__init__.py +3 -0
- cuthbertlib/resampling/killing.py +79 -0
- cuthbertlib/resampling/multinomial.py +53 -0
- cuthbertlib/resampling/protocols.py +92 -0
- cuthbertlib/resampling/systematic.py +78 -0
- cuthbertlib/resampling/utils.py +82 -0
- cuthbertlib/smc/__init__.py +0 -0
- cuthbertlib/smc/ess.py +24 -0
- cuthbertlib/smc/smoothing/__init__.py +0 -0
- cuthbertlib/smc/smoothing/exact_sampling.py +111 -0
- cuthbertlib/smc/smoothing/mcmc.py +76 -0
- cuthbertlib/smc/smoothing/protocols.py +44 -0
- cuthbertlib/smc/smoothing/tracing.py +45 -0
- cuthbertlib/stats/__init__.py +0 -0
- cuthbertlib/stats/multivariate_normal.py +102 -0
- cuthbert-0.0.1.dist-info/RECORD +0 -12
- {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/WHEEL +0 -0
- {cuthbert-0.0.1.dist-info → cuthbert-0.0.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
"""Implements the marginal particle filter.
|
|
2
|
+
|
|
3
|
+
See [Klaas et al. (2005)](https://www.cs.ubc.ca/~arnaud/klass_defreitas_doucet_marginalparticlefilterUAI2005.pdf)
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from functools import partial
|
|
7
|
+
from typing import NamedTuple
|
|
8
|
+
|
|
9
|
+
import jax
|
|
10
|
+
import jax.numpy as jnp
|
|
11
|
+
from jax import Array, random, tree
|
|
12
|
+
|
|
13
|
+
from cuthbert.inference import Filter
|
|
14
|
+
from cuthbert.smc.types import InitSample, LogPotential, PropagateSample
|
|
15
|
+
from cuthbert.utils import dummy_tree_like
|
|
16
|
+
from cuthbertlib.resampling import Resampling
|
|
17
|
+
from cuthbertlib.smc.ess import log_ess
|
|
18
|
+
from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray, ScalarArray
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class MarginalParticleFilterState(NamedTuple):
    """Container for everything the marginal particle filter carries per step.

    Ancestor indices are deliberately absent: they do not make sense for
    marginal particle filters.
    """

    # JAX PRNG key stored with the state.
    key: KeyArray
    # Pytree of particles.
    particles: ArrayTree
    # Log weights of the particles.
    log_weights: Array
    # Model inputs associated with this state.
    model_inputs: ArrayTree
    # Log normalizing constant stored with this state.
    log_normalizing_constant: ScalarArray
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def build_filter(
    init_sample: InitSample,
    propagate_sample: PropagateSample,
    log_potential: LogPotential,
    n_filter_particles: int,
    resampling_fn: Resampling,
    ess_threshold: float,
) -> Filter:
    r"""Assemble a `Filter` object implementing the marginal particle filter.

    The three stage functions of this module (`init_prepare`,
    `filter_prepare`, `filter_combine`) are partially applied with the
    model-specific callables and settings given here.

    Args:
        init_sample: Samples from the initial distribution $M_0(x_0)$.
        propagate_sample: Samples from the Markov kernel
            $M_t(x_t \mid x_{t-1})$.
        log_potential: Evaluates the log potential $\log G_t(x_{t-1}, x_t)$.
        n_filter_particles: Number of particles used by the filter.
        resampling_fn: Resampling algorithm (e.g., systematic, multinomial).
        ess_threshold: Fraction of the particle count below which resampling
            happens, i.e. resampling is triggered when the effective sample
            size (ESS) < ess_threshold * n_filter_particles.

    Returns:
        Filter object for the particle filter.
    """
    bound_init_prepare = partial(
        init_prepare,
        init_sample=init_sample,
        log_potential=log_potential,
        n_filter_particles=n_filter_particles,
    )
    bound_filter_prepare = partial(
        filter_prepare,
        init_sample=init_sample,
        n_filter_particles=n_filter_particles,
    )
    bound_filter_combine = partial(
        filter_combine,
        propagate_sample=propagate_sample,
        log_potential=log_potential,
        resampling_fn=resampling_fn,
        ess_threshold=ess_threshold,
    )
    # The filter is declared non-associative.
    return Filter(
        init_prepare=bound_init_prepare,
        filter_prepare=bound_filter_prepare,
        filter_combine=bound_filter_combine,
        associative=False,
    )
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def init_prepare(
    model_inputs: ArrayTreeLike,
    init_sample: InitSample,
    log_potential: LogPotential,
    n_filter_particles: int,
    key: KeyArray | None = None,
) -> MarginalParticleFilterState:
    """Build the initial marginal particle filter state.

    Draws `n_filter_particles` particles from the initial distribution,
    weights them with the log potential (called with `None` as the previous
    state) and computes the initial log normalizing constant.

    Args:
        model_inputs: Model inputs.
        init_sample: Function to sample from the initial distribution M_0(x_0).
        log_potential: Function to compute the log potential
            log G_t(x_{t-1}, x_t). x_{t-1} is None since there is no previous
            state at t=0.
        n_filter_particles: Number of particles to sample.
        key: JAX random key.

    Returns:
        Initial state for the filter.

    Raises:
        ValueError: If `key` is None.
    """
    model_inputs = tree.map(jnp.asarray, model_inputs)
    if key is None:
        raise ValueError("A JAX PRNG key must be provided.")

    # Sample: one subkey per particle.
    particle_keys = random.split(key, n_filter_particles)
    sample_all = jax.vmap(init_sample, in_axes=(0, None))
    particles = sample_all(particle_keys, model_inputs)

    # Weight: only the particle argument is batched.
    weigh_all = jax.vmap(log_potential, in_axes=(None, 0, None))
    log_weights = weigh_all(None, particles, model_inputs)

    # Log of the average (unnormalized) weight.
    log_total_weight = jax.nn.logsumexp(log_weights)
    log_normalizing_constant = log_total_weight - jnp.log(n_filter_particles)

    return MarginalParticleFilterState(
        key=key,
        particles=particles,
        log_weights=log_weights,
        model_inputs=model_inputs,
        log_normalizing_constant=log_normalizing_constant,
    )
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def filter_prepare(
    model_inputs: ArrayTreeLike,
    init_sample: InitSample,
    n_filter_particles: int,
    key: KeyArray | None = None,
) -> MarginalParticleFilterState:
    """Prepare a state for a marginal particle filter step.

    No sampling happens here: `init_sample` is only traced with
    `jax.eval_shape` to learn the structure of a single particle, and the
    returned state holds placeholder particles, zero log weights and a zero
    log normalizing constant.

    Args:
        model_inputs: Model inputs.
        init_sample: Function to sample from the initial distribution M_0(x_0).
            Only used to infer particle shapes and dtypes.
        n_filter_particles: Number of particles for the filter.
        key: JAX random key.

    Returns:
        Prepared state for the filter.

    Raises:
        ValueError: If `key` is None.
    """
    model_inputs = tree.map(jnp.asarray, model_inputs)
    if key is None:
        raise ValueError("A JAX PRNG key must be provided.")

    # Abstract evaluation: yields shape/dtype information without sampling.
    particle_spec = jax.eval_shape(init_sample, key, model_inputs)
    # Keep each leaf's dtype so the placeholder matches what `init_sample`
    # would produce, rather than silently falling back to the default float.
    particles = tree.map(
        lambda leaf: jnp.empty(
            (n_filter_particles,) + leaf.shape, dtype=leaf.dtype
        ),
        particle_spec,
    )
    particles = dummy_tree_like(particles)
    return MarginalParticleFilterState(
        key=key,
        particles=particles,
        log_weights=jnp.zeros(n_filter_particles),
        model_inputs=model_inputs,
        log_normalizing_constant=jnp.array(0.0),
    )
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def filter_combine(
    state_1: MarginalParticleFilterState,
    state_2: MarginalParticleFilterState,
    propagate_sample: PropagateSample,
    log_potential: LogPotential,
    resampling_fn: Resampling,
    ess_threshold: float,
) -> MarginalParticleFilterState:
    """Combine previous filter state with the state prepared for the current step.

    Performs one marginal particle filter update: conditional resampling,
    propagation through the state dynamics, and an N^2 reweighting in which
    every new particle is compared against every previous particle.

    Args:
        state_1: Filter state from the previous time step.
        state_2: Filter state prepared for the current step. Only its `key`
            and `model_inputs` are read.
        propagate_sample: Function to sample from the Markov kernel
            M_t(x_t | x_{t-1}).
        log_potential: Function to compute the log potential
            log G_t(x_{t-1}, x_t).
        resampling_fn: Resampling algorithm to use (e.g., systematic,
            multinomial).
        ess_threshold: Fraction of particle count specifying when to resample.
            Resampling is triggered when the effective sample size
            (ESS) < ess_threshold * N.

    Returns:
        The filtered state at the current time step.
    """
    n_particles = state_1.log_weights.shape[0]
    # One key for the resampling step, then one per propagated particle.
    step_keys = random.split(state_1.key, n_particles + 1)
    resample_key = step_keys[0]
    propagate_keys = step_keys[1:]

    # Previous log weights shifted so that they are normalized.
    normalized_prev_log_weights = state_1.log_weights - jax.nn.logsumexp(
        state_1.log_weights
    )

    # Resample only when the ESS drops below the threshold.
    def _resample():
        indices = resampling_fn(resample_key, state_1.log_weights, n_particles)
        return indices, jnp.zeros(n_particles)

    def _keep():
        return jnp.arange(n_particles), state_1.log_weights

    ess_too_low = log_ess(state_1.log_weights) < jnp.log(
        ess_threshold * n_particles
    )
    ancestor_indices, log_weights = jax.lax.cond(ess_too_low, _resample, _keep)
    ancestors = tree.map(lambda leaf: leaf[ancestor_indices], state_1.particles)

    # Propagate every selected ancestor with its own key.
    propagate_all = jax.vmap(propagate_sample, in_axes=(0, 0, None))
    next_particles = propagate_all(
        propagate_keys, ancestors, state_2.model_inputs
    )

    # N^2 reweight: the inner vmap runs over previous particles, the outer one
    # over next particles, so rows index next particles and columns index
    # previous particles.
    over_previous = jax.vmap(
        log_potential, in_axes=(0, None, None), out_axes=0
    )
    pairwise_log_potential = jax.vmap(
        over_previous, in_axes=(None, 0, None), out_axes=0
    )
    log_potentials = pairwise_log_potential(
        state_1.particles, next_particles, state_2.model_inputs
    )
    next_log_weights = jax.nn.logsumexp(
        log_potentials + normalized_prev_log_weights[None, :], axis=1
    )

    # Accumulate the log normalizing constant.
    log_normalizing_constant_incr = jax.nn.logsumexp(
        next_log_weights
    ) - jax.nn.logsumexp(log_weights)
    log_normalizing_constant = (
        log_normalizing_constant_incr + state_1.log_normalizing_constant
    )

    return MarginalParticleFilterState(
        key=state_2.key,
        particles=next_particles,
        log_weights=next_log_weights,
        model_inputs=state_2.model_inputs,
        log_normalizing_constant=log_normalizing_constant,
    )
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
"""Implements the generic particle filter.
|
|
2
|
+
|
|
3
|
+
See Algorithm 10.1, [Chopin and Papaspiliopoulos, 2020](https://doi.org/10.1007/978-3-030-47845-2).
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from functools import partial
|
|
7
|
+
from typing import NamedTuple
|
|
8
|
+
|
|
9
|
+
import jax
|
|
10
|
+
import jax.numpy as jnp
|
|
11
|
+
from jax import Array, random, tree
|
|
12
|
+
|
|
13
|
+
from cuthbert.inference import Filter
|
|
14
|
+
from cuthbert.smc.types import InitSample, LogPotential, PropagateSample
|
|
15
|
+
from cuthbert.utils import dummy_tree_like
|
|
16
|
+
from cuthbertlib.resampling import Resampling
|
|
17
|
+
from cuthbertlib.smc.ess import log_ess
|
|
18
|
+
from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray, ScalarArray
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ParticleFilterState(NamedTuple):
    """Container for everything the particle filter carries per step."""

    # JAX PRNG key stored with the state.
    key: KeyArray
    # Pytree of particles.
    particles: ArrayTree
    # Log weights of the particles.
    log_weights: Array
    # Indices of the ancestors selected for each particle.
    ancestor_indices: Array
    # Model inputs associated with this state.
    model_inputs: ArrayTree
    # Log normalizing constant stored with this state.
    log_normalizing_constant: ScalarArray

    @property
    def n_particles(self) -> int:
        """Return the particle count, read from the last axis of `log_weights`."""
        particle_count = self.log_weights.shape[-1]
        return particle_count
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def build_filter(
    init_sample: InitSample,
    propagate_sample: PropagateSample,
    log_potential: LogPotential,
    n_filter_particles: int,
    resampling_fn: Resampling,
    ess_threshold: float,
) -> Filter:
    r"""Assemble a `Filter` object implementing the generic particle filter.

    The three stage functions of this module (`init_prepare`,
    `filter_prepare`, `filter_combine`) are partially applied with the
    model-specific callables and settings given here.

    Args:
        init_sample: Samples from the initial distribution $M_0(x_0)$.
        propagate_sample: Samples from the Markov kernel
            $M_t(x_t \mid x_{t-1})$.
        log_potential: Evaluates the log potential $\log G_t(x_{t-1}, x_t)$.
        n_filter_particles: Number of particles used by the filter.
        resampling_fn: Resampling algorithm (e.g., systematic, multinomial).
        ess_threshold: Fraction of the particle count below which resampling
            happens, i.e. resampling is triggered when the effective sample
            size (ESS) < ess_threshold * n_filter_particles.

    Returns:
        Filter object for the particle filter.
    """
    bound_init_prepare = partial(
        init_prepare,
        init_sample=init_sample,
        log_potential=log_potential,
        n_filter_particles=n_filter_particles,
    )
    bound_filter_prepare = partial(
        filter_prepare,
        init_sample=init_sample,
        n_filter_particles=n_filter_particles,
    )
    bound_filter_combine = partial(
        filter_combine,
        propagate_sample=propagate_sample,
        log_potential=log_potential,
        resampling_fn=resampling_fn,
        ess_threshold=ess_threshold,
    )
    # The filter is declared non-associative.
    return Filter(
        init_prepare=bound_init_prepare,
        filter_prepare=bound_filter_prepare,
        filter_combine=bound_filter_combine,
        associative=False,
    )
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def init_prepare(
    model_inputs: ArrayTreeLike,
    init_sample: InitSample,
    log_potential: LogPotential,
    n_filter_particles: int,
    key: KeyArray | None = None,
) -> ParticleFilterState:
    """Build the initial particle filter state.

    Draws `n_filter_particles` particles from the initial distribution,
    weights them with the log potential (called with `None` as the previous
    state) and computes the initial log normalizing constant. Ancestor
    indices are initialised to `0, 1, ..., n_filter_particles - 1`.

    Args:
        model_inputs: Model inputs.
        init_sample: Function to sample from the initial distribution M_0(x_0).
        log_potential: Function to compute the log potential
            log G_t(x_{t-1}, x_t). x_{t-1} is None since there is no previous
            state at t=0.
        n_filter_particles: Number of particles to sample.
        key: JAX random key.

    Returns:
        Initial state for the particle filter.

    Raises:
        ValueError: If `key` is None.
    """
    model_inputs = tree.map(jnp.asarray, model_inputs)
    if key is None:
        raise ValueError("A JAX PRNG key must be provided.")

    # Sample: one subkey per particle.
    particle_keys = random.split(key, n_filter_particles)
    sample_all = jax.vmap(init_sample, in_axes=(0, None))
    particles = sample_all(particle_keys, model_inputs)

    # Weight: only the particle argument is batched.
    weigh_all = jax.vmap(log_potential, in_axes=(None, 0, None))
    log_weights = weigh_all(None, particles, model_inputs)

    # Log of the average (unnormalized) weight.
    log_total_weight = jax.nn.logsumexp(log_weights)
    log_normalizing_constant = log_total_weight - jnp.log(n_filter_particles)

    return ParticleFilterState(
        key=key,
        particles=particles,
        log_weights=log_weights,
        ancestor_indices=jnp.arange(n_filter_particles),
        model_inputs=model_inputs,
        log_normalizing_constant=log_normalizing_constant,
    )
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def filter_prepare(
    model_inputs: ArrayTreeLike,
    init_sample: InitSample,
    n_filter_particles: int,
    key: KeyArray | None = None,
) -> ParticleFilterState:
    """Prepare a state for a particle filter step.

    No sampling happens here: `init_sample` is only traced with
    `jax.eval_shape` to learn the structure of a single particle, and the
    returned state holds placeholder particles, zero log weights, identity
    ancestor indices and a zero log normalizing constant.

    Args:
        model_inputs: Model inputs.
        init_sample: Function to sample from the initial distribution M_0(x_0).
            Only used to infer particle shapes and dtypes.
        n_filter_particles: Number of particles for the filter.
        key: JAX random key.

    Returns:
        Prepared state for the filter.

    Raises:
        ValueError: If `key` is None.
    """
    model_inputs = tree.map(jnp.asarray, model_inputs)
    if key is None:
        raise ValueError("A JAX PRNG key must be provided.")

    # Abstract evaluation: yields shape/dtype information without sampling.
    particle_spec = jax.eval_shape(init_sample, key, model_inputs)
    # Keep each leaf's dtype so the placeholder matches what `init_sample`
    # would produce, rather than silently falling back to the default float.
    particles = tree.map(
        lambda leaf: jnp.empty(
            (n_filter_particles,) + leaf.shape, dtype=leaf.dtype
        ),
        particle_spec,
    )
    particles = dummy_tree_like(particles)
    return ParticleFilterState(
        key=key,
        particles=particles,
        log_weights=jnp.zeros(n_filter_particles),
        ancestor_indices=jnp.arange(n_filter_particles),
        model_inputs=model_inputs,
        log_normalizing_constant=jnp.array(0.0),
    )
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def filter_combine(
    state_1: ParticleFilterState,
    state_2: ParticleFilterState,
    propagate_sample: PropagateSample,
    log_potential: LogPotential,
    resampling_fn: Resampling,
    ess_threshold: float,
) -> ParticleFilterState:
    """Combine previous filter state with the state prepared for the current step.

    Performs one particle filter update: conditional resampling, propagation
    through the state dynamics, and reweighting of each new particle against
    its own ancestor using the potential function.

    Args:
        state_1: Filter state from the previous time step.
        state_2: Filter state prepared for the current step. Only its `key`
            and `model_inputs` are read.
        propagate_sample: Function to sample from the Markov kernel
            M_t(x_t | x_{t-1}).
        log_potential: Function to compute the log potential
            log G_t(x_{t-1}, x_t).
        resampling_fn: Resampling algorithm to use (e.g., systematic,
            multinomial).
        ess_threshold: Fraction of particle count specifying when to resample.
            Resampling is triggered when the effective sample size
            (ESS) < ess_threshold * N.

    Returns:
        The filtered state at the current time step.
    """
    n_particles = state_1.log_weights.shape[0]
    # One key for the resampling step, then one per propagated particle.
    step_keys = random.split(state_1.key, n_particles + 1)
    resample_key = step_keys[0]
    propagate_keys = step_keys[1:]

    # Resample only when the ESS drops below the threshold; resampled
    # particles restart from zero log weights.
    def _resample():
        indices = resampling_fn(resample_key, state_1.log_weights, n_particles)
        return indices, jnp.zeros(n_particles)

    def _keep():
        return jnp.arange(n_particles), state_1.log_weights

    ess_too_low = log_ess(state_1.log_weights) < jnp.log(
        ess_threshold * n_particles
    )
    ancestor_indices, log_weights = jax.lax.cond(ess_too_low, _resample, _keep)
    ancestors = tree.map(lambda leaf: leaf[ancestor_indices], state_1.particles)

    # Propagate every selected ancestor with its own key.
    propagate_all = jax.vmap(propagate_sample, in_axes=(0, 0, None))
    next_particles = propagate_all(
        propagate_keys, ancestors, state_2.model_inputs
    )

    # Reweight each (ancestor, next particle) pair.
    weigh_pairs = jax.vmap(log_potential, in_axes=(0, 0, None))
    log_potentials = weigh_pairs(ancestors, next_particles, state_2.model_inputs)
    next_log_weights = log_potentials + log_weights

    # Accumulate the log normalizing constant.
    log_normalizing_constant_incr = jax.nn.logsumexp(
        next_log_weights
    ) - jax.nn.logsumexp(log_weights)
    log_normalizing_constant = (
        log_normalizing_constant_incr + state_1.log_normalizing_constant
    )

    return ParticleFilterState(
        key=state_2.key,
        particles=next_particles,
        log_weights=next_log_weights,
        ancestor_indices=ancestor_indices,
        model_inputs=state_2.model_inputs,
        log_normalizing_constant=log_normalizing_constant,
    )
|
cuthbert/smc/types.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
r"""Provides types for representing generic Feynman--Kac models.
|
|
2
|
+
|
|
3
|
+
$$
|
|
4
|
+
\mathbb{Q}_{t}(x_{0:t}) \propto \mathbb{M}_0(x_0) \, G_0(x_0) \prod_{s=1}^{t} M_s(x_s \mid x_{s-1}) \, G_s(x_{s-1}, x_s).
|
|
5
|
+
$$
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Protocol
|
|
9
|
+
|
|
10
|
+
from cuthbertlib.types import ArrayTree, ArrayTreeLike, KeyArray, ScalarArray
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class InitSample(Protocol):
    """Callable protocol: draw a sample from the initial distribution $M_0(x_0)$."""

    def __call__(
        self,
        key: KeyArray,
        model_inputs: ArrayTreeLike,
    ) -> ArrayTree:
        """Draw one sample from the initial distribution $M_0(x_0)$.

        Args:
            key: JAX PRNG key.
            model_inputs: Model inputs.

        Returns:
            A sample $x_0$.
        """
        ...
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class PropagateSample(Protocol):
    r"""Callable protocol: draw a sample from the Markov kernel $M_t(x_t \mid x_{t-1})$."""

    def __call__(
        self,
        key: KeyArray,
        state: ArrayTreeLike,
        model_inputs: ArrayTreeLike,
    ) -> ArrayTree:
        r"""Draw one sample from the Markov kernel $M_t(x_t \mid x_{t-1})$.

        Args:
            key: JAX PRNG key.
            state: State at the previous step $x_{t-1}$.
            model_inputs: Model inputs.

        Returns:
            A sample $x_t$.
        """
        ...
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class LogPotential(Protocol):
    r"""Callable protocol: evaluate the log potential $\log G_t(x_{t-1}, x_t)$."""

    def __call__(
        self,
        state_prev: ArrayTreeLike,
        state: ArrayTreeLike,
        model_inputs: ArrayTreeLike,
    ) -> ScalarArray:
        r"""Evaluate the log potential function $\log G_t(x_{t-1}, x_t)$.

        Args:
            state_prev: State at the previous step $x_{t-1}$.
            state: State at the current step $x_{t}$.
            model_inputs: Model inputs.

        Returns:
            A scalar value $\log G_t(x_{t-1}, x_t)$.
        """
        ...
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cuthbert
|
|
3
|
-
Version: 0.0.1
|
|
3
|
+
Version: 0.0.3
|
|
4
4
|
Summary: State-space model inference with JAX
|
|
5
5
|
Author-email: Sam Duffield <s@mduffield.com>, Sahel Iqbal <sahel13miqbal@proton.me>, Adrien Corenflos <adrien.corenflos.stats@gmail.com>
|
|
6
6
|
License: Apache-2.0
|
|
@@ -33,7 +33,7 @@ Dynamic: license-file
|
|
|
33
33
|
|
|
34
34
|
<!--intro-start-->
|
|
35
35
|
<div align="center">
|
|
36
|
-
<img src="docs/assets/cuthbert.png" alt="logo"></img>
|
|
36
|
+
<img src="https://raw.githubusercontent.com/state-space-models/cuthbert/main/docs/assets/cuthbert.png" alt="cuthbert logo"></img>
|
|
37
37
|
</div>
|
|
38
38
|
|
|
39
39
|
A JAX library for state-space model inference
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
cuthbert/__init__.py,sha256=60_FB1nfduZLPphbjEc7WRpebCnVYiwMyKuD-7yZBvw,281
|
|
2
|
+
cuthbert/filtering.py,sha256=HaUPJWhBO8P5IWlRYhLwCeOEtlYenAvWjYsQaF_cPX0,2700
|
|
3
|
+
cuthbert/inference.py,sha256=u02wVKGu7mIdqn2XhcSZ93xXyPIQkxzSvxJXMH7Zo5k,7600
|
|
4
|
+
cuthbert/smoothing.py,sha256=qvNAWYTGGaEUEBp7HAtYLtF1UOQK9E7u3V_l4XgxeaI,4960
|
|
5
|
+
cuthbert/utils.py,sha256=0JQgRiyVs4SXZ0ullh4OmAMtyoOtuCzvYuDmpu9jXOE,1023
|
|
6
|
+
cuthbert/discrete/__init__.py,sha256=AI-GYq2QRyg2hwrEJm-LJRV9aMZ7MY-QkE0wG41Qg_I,104
|
|
7
|
+
cuthbert/discrete/filter.py,sha256=wpefQ3Q6eN6PdXKon93E5tjJLxBxPjqE4ctrd-QBbHA,4674
|
|
8
|
+
cuthbert/discrete/smoother.py,sha256=sMS2NKt-507LYnUSzqd3GMCArcSO2-eDiINX_hvLNzY,4291
|
|
9
|
+
cuthbert/discrete/types.py,sha256=hdfxqBGRza45ni0_Lu2NNvqLga_3o0PQWQgJ08BcxhM,1412
|
|
10
|
+
cuthbert/gaussian/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
+
cuthbert/gaussian/kalman.py,sha256=WldTKTCzlAWX1btAD6kD63Qoz1RpTufOiq2riT6ND44,11567
|
|
12
|
+
cuthbert/gaussian/types.py,sha256=IjC8AWUpjLMgm1co2sc7knb3Otwz-pnUSJEw8p5ohGc,1725
|
|
13
|
+
cuthbert/gaussian/utils.py,sha256=AnRdw_Ju7rWXpbCaOB0CUkoVctJ8gScchgURv5mQ0BU,1301
|
|
14
|
+
cuthbert/gaussian/moments/__init__.py,sha256=cYUxe1Sksf0qqBYO-zwtKnuoZ_gOE_Hr7aGsrzeNkOQ,327
|
|
15
|
+
cuthbert/gaussian/moments/associative_filter.py,sha256=abFEs2KVbXrnIMUJMVKD32hRP8Q3A4NkQP_VmfSGeaI,6188
|
|
16
|
+
cuthbert/gaussian/moments/filter.py,sha256=8dG_Aqdvs_WC23JGpC3B5mTjG5uJMtd15LxJAjeyddY,3951
|
|
17
|
+
cuthbert/gaussian/moments/non_associative_filter.py,sha256=wV_IBeaA-EnWLnX5Wkp6-Kx9WDnpeVWtdUiqlXeGS1U,5942
|
|
18
|
+
cuthbert/gaussian/moments/smoother.py,sha256=L79R8WPwI30aHyCdZE1paYM9KNZtmQlRsvFhDyFsoxs,4262
|
|
19
|
+
cuthbert/gaussian/moments/types.py,sha256=UXUaeakgyO-WLcY8ECmFcf16RwOQZ6pz0B5qI621xV8,1727
|
|
20
|
+
cuthbert/gaussian/taylor/__init__.py,sha256=xOdFMzy59MfF1WuM7Oy4_H7Hn0TtekyoK-ECM5CVuEE,428
|
|
21
|
+
cuthbert/gaussian/taylor/associative_filter.py,sha256=daIY693BefoeV15T4hnOiJ2_e-eH2bGDrmA4pEwNB0A,8318
|
|
22
|
+
cuthbert/gaussian/taylor/filter.py,sha256=ljlHidVuFAlrVVSfmFM6sfdxQuYsmCVsSMucHCuFd3M,5303
|
|
23
|
+
cuthbert/gaussian/taylor/non_associative_filter.py,sha256=qhT31jWyCct_C0Ae3wNLm3BD75eBYX16xNDaA8uw-wg,9446
|
|
24
|
+
cuthbert/gaussian/taylor/smoother.py,sha256=hDAOTJc71oZZIH_ZRkVDXnbLj8riwNqncRy-N372wXA,5814
|
|
25
|
+
cuthbert/gaussian/taylor/types.py,sha256=h45seSNRUSgT4tpQVjqW_AJny2HCz6kJi9yM7zLEWOc,2733
|
|
26
|
+
cuthbert/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
27
|
+
cuthbert/smc/backward_sampler.py,sha256=rFsJvSvx3S2hRl2ojTwvwQqERtu_gaBEWzjQtO2GIXQ,6737
|
|
28
|
+
cuthbert/smc/marginal_particle_filter.py,sha256=91WaqsSmoo3Q5pPanYaLSnCOyKLMWSDH1aV3h1SkO2g,8125
|
|
29
|
+
cuthbert/smc/particle_filter.py,sha256=PebY0rt0C76lAZgoWkpIGzUkR_BxoRMCZtOFsRRVSok,7770
|
|
30
|
+
cuthbert/smc/types.py,sha256=jen13y7RWylzIAmrpPGpp980V7uukG06kkoEkcCGZJA,1841
|
|
31
|
+
cuthbert-0.0.3.dist-info/licenses/LICENSE,sha256=4bgVKT6ovc1ZpaqBTr3nDmQrN0rNe1Yx-WNW_q3Er7s,11390
|
|
32
|
+
cuthbertlib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
33
|
+
cuthbertlib/types.py,sha256=sgHC0J7W_0wOzMm5dlLZXwJB3W1xvzq2J6cxl6U831w,1020
|
|
34
|
+
cuthbertlib/discrete/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
35
|
+
cuthbertlib/discrete/filtering.py,sha256=_TMGwAeDXHivANpck3HUXFFm8SUalm72N7NHPzYTan8,1466
|
|
36
|
+
cuthbertlib/discrete/smoothing.py,sha256=FZJkSpUhAo7-iM3FiudKvR5XNvNEFXhfzdkf_kIoNQM,1214
|
|
37
|
+
cuthbertlib/kalman/__init__.py,sha256=Tf2dJBk9qavIe_0OL_TJ2zfCTC0hBJrPS0OMbnx_p_o,221
|
|
38
|
+
cuthbertlib/kalman/filtering.py,sha256=YQvxzsF5XsYz8JEJIoHiETtcNHmcMCM4djanNn-pe4w,6757
|
|
39
|
+
cuthbertlib/kalman/generate.py,sha256=AeF-NXIpxN-E6YmiKIaUg4EK5wzWLkTtdg3K9RCemg4,3148
|
|
40
|
+
cuthbertlib/kalman/sampling.py,sha256=Cx4EPCXwMJlPt_48aLsCowk3t70dfA0Fp_lDUQih4ok,1914
|
|
41
|
+
cuthbertlib/kalman/smoothing.py,sha256=iCd5hEVrwVZ3Q2aQtJeNe1IKbpcXD7E1fqHy1P6bX_U,3928
|
|
42
|
+
cuthbertlib/linalg/__init__.py,sha256=ldnznmTBCJeyFfiz1nLQ1hN0X9PwcjnQOAXjpmL6Ag8,284
|
|
43
|
+
cuthbertlib/linalg/collect_nans_chol.py,sha256=7pC7lh5Oux3IiNndIOSYLUdMTAjrVIwkeK52hGnbmJg,3137
|
|
44
|
+
cuthbertlib/linalg/marginal_sqrt_cov.py,sha256=Hq0ofXdMzBANIhfwqKXUSU0IL6LwGZIvU3Uq7EfaZ7E,1188
|
|
45
|
+
cuthbertlib/linalg/symmetric_inv_sqrt.py,sha256=8CWJzaS9VGl3tlJHdw7B7q-DiwbWA5aqAab4i5ZJKU8,5043
|
|
46
|
+
cuthbertlib/linalg/tria.py,sha256=twzyyachwNGDlHA6i6mBl504vVqOYjeC_9a8fa_X_X4,571
|
|
47
|
+
cuthbertlib/linearize/__init__.py,sha256=JzeMqWch_mXJ4a1WaAYuQqQg7Z4gCvTJpZas4RqXwSE,287
|
|
48
|
+
cuthbertlib/linearize/log_density.py,sha256=VUwHFCWsvZwU22aPWD_YrUb3aGOQFP_NhyUqDDYmpJo,6073
|
|
49
|
+
cuthbertlib/linearize/moments.py,sha256=DiU8osX4x9htR3G_cxqMN9EcfBJ0xxSoR0Rn-2bY5WA,3205
|
|
50
|
+
cuthbertlib/linearize/taylor.py,sha256=9LUzoo0wuU5mReO4DUggwyFbSJpMuZ7UMqpmvnJyyrA,3238
|
|
51
|
+
cuthbertlib/quadrature/__init__.py,sha256=79moFPswxvR9InZrL3RysITfjVnYDUpVa5E2CJcdvW4,275
|
|
52
|
+
cuthbertlib/quadrature/common.py,sha256=HpNFVXgP_L4vzqpNtRE6XCmiFrBAsIj-9KThAENUSbM,2743
|
|
53
|
+
cuthbertlib/quadrature/cubature.py,sha256=1yiGLaij6JDPtttFBbqgNuvG6ZUgdHSbWek8d8HGmyw,1877
|
|
54
|
+
cuthbertlib/quadrature/gauss_hermite.py,sha256=LuapZIklRb7CpZI5StJ9BD7oBevD0JU4cEEkqqM10dE,1740
|
|
55
|
+
cuthbertlib/quadrature/linearize.py,sha256=y_GcnER8Btd-mn2GN52Wowq893PMcSxcJtnDyxbKNoA,5403
|
|
56
|
+
cuthbertlib/quadrature/unscented.py,sha256=SMI51HCESHLKEW4GGrYtVCmejkiP8WUxWM_BHWVCeJw,2399
|
|
57
|
+
cuthbertlib/quadrature/utils.py,sha256=15s5na_MhRJOeaorvvetw1NnoYWD-9-xMdK4ZLIo9kg,3652
|
|
58
|
+
cuthbertlib/resampling/__init__.py,sha256=Po73bCYwl-yAfYMRkrxeFmXMXdhwQ8btZqJ4fiH5qPc,200
|
|
59
|
+
cuthbertlib/resampling/killing.py,sha256=sPw0r86wfR9ERjbGuXiIwYo9nIqX-Y83e7magv53J7I,2817
|
|
60
|
+
cuthbertlib/resampling/multinomial.py,sha256=pGSKkfPuW3Hh815u_w0x47stg033H6XkeJbHzqcmJCE,2002
|
|
61
|
+
cuthbertlib/resampling/protocols.py,sha256=i2EUyLOxxKonbx-9xU8ZjjpDimIu1O8zdreEUPftdzI,2457
|
|
62
|
+
cuthbertlib/resampling/systematic.py,sha256=9eRRQjgShZUs98ATAUz4BYvudWEk0z1P73sLLlVZMR0,2338
|
|
63
|
+
cuthbertlib/resampling/utils.py,sha256=hmf1PmcDi1kWOFBUS0GWoP8t2Q5T_PkFWgcNnjD6sXs,2689
|
|
64
|
+
cuthbertlib/smc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
65
|
+
cuthbertlib/smc/ess.py,sha256=Lkafpy2QVt7EuxZYjLQHGzff_qfsVdouA4MUpRhMuS8,645
|
|
66
|
+
cuthbertlib/smc/smoothing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
67
|
+
cuthbertlib/smc/smoothing/exact_sampling.py,sha256=lUqM64l4nGe34w7ex2l5AXg-Oa67Bkzm07ow4mf0ZyM,3590
|
|
68
|
+
cuthbertlib/smc/smoothing/mcmc.py,sha256=fsLOzRTYRkgJgSxtCfp8hznG48tRXuA04VCxVuLN9Bo,2490
|
|
69
|
+
cuthbertlib/smc/smoothing/protocols.py,sha256=461TVhQbYIqDWBl7QRjGA35YpbDk4f6-BwtgGfgT-tU,1302
|
|
70
|
+
cuthbertlib/smc/smoothing/tracing.py,sha256=E8GFM0cOQiFspOFaAaK2xtyA5oMD6pbO9JqEk-O8RPk,1343
|
|
71
|
+
cuthbertlib/stats/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
72
|
+
cuthbertlib/stats/multivariate_normal.py,sha256=H5wGxtuHKJldJHs_PicsCX94hczeX9Xek625RdXrd18,4314
|
|
73
|
+
cuthbert-0.0.3.dist-info/METADATA,sha256=9zXQmkEBc-fq8-2vymg9ThVhYKdOLswtnuS9e2TkN2s,7116
|
|
74
|
+
cuthbert-0.0.3.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
75
|
+
cuthbert-0.0.3.dist-info/top_level.txt,sha256=R7-G6fUQZSMNMcM4-IcpHKtY9CZOQeejEpVW9q-Sarw,21
|
|
76
|
+
cuthbert-0.0.3.dist-info/RECORD,,
|
|
@@ -186,7 +186,7 @@
|
|
|
186
186
|
same "printed page" as the copyright notice for easier
|
|
187
187
|
identification within third-party archives.
|
|
188
188
|
|
|
189
|
-
Copyright
|
|
189
|
+
Copyright 2026 state-space-models org https://github.com/state-space-models
|
|
190
190
|
|
|
191
191
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
192
192
|
you may not use this file except in compliance with the License.
|
|
File without changes
|