pymc-extras 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +29 -0
- pymc_extras/distributions/__init__.py +40 -0
- pymc_extras/distributions/continuous.py +351 -0
- pymc_extras/distributions/discrete.py +399 -0
- pymc_extras/distributions/histogram_utils.py +163 -0
- pymc_extras/distributions/multivariate/__init__.py +3 -0
- pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
- pymc_extras/distributions/timeseries.py +356 -0
- pymc_extras/gp/__init__.py +18 -0
- pymc_extras/gp/latent_approx.py +183 -0
- pymc_extras/inference/__init__.py +18 -0
- pymc_extras/inference/find_map.py +431 -0
- pymc_extras/inference/fit.py +44 -0
- pymc_extras/inference/laplace.py +570 -0
- pymc_extras/inference/pathfinder.py +134 -0
- pymc_extras/inference/smc/__init__.py +13 -0
- pymc_extras/inference/smc/sampling.py +451 -0
- pymc_extras/linearmodel.py +130 -0
- pymc_extras/model/__init__.py +0 -0
- pymc_extras/model/marginal/__init__.py +0 -0
- pymc_extras/model/marginal/distributions.py +276 -0
- pymc_extras/model/marginal/graph_analysis.py +372 -0
- pymc_extras/model/marginal/marginal_model.py +595 -0
- pymc_extras/model/model_api.py +56 -0
- pymc_extras/model/transforms/__init__.py +0 -0
- pymc_extras/model/transforms/autoreparam.py +434 -0
- pymc_extras/model_builder.py +759 -0
- pymc_extras/preprocessing/__init__.py +0 -0
- pymc_extras/preprocessing/standard_scaler.py +17 -0
- pymc_extras/printing.py +182 -0
- pymc_extras/statespace/__init__.py +13 -0
- pymc_extras/statespace/core/__init__.py +7 -0
- pymc_extras/statespace/core/compile.py +48 -0
- pymc_extras/statespace/core/representation.py +438 -0
- pymc_extras/statespace/core/statespace.py +2268 -0
- pymc_extras/statespace/filters/__init__.py +15 -0
- pymc_extras/statespace/filters/distributions.py +453 -0
- pymc_extras/statespace/filters/kalman_filter.py +820 -0
- pymc_extras/statespace/filters/kalman_smoother.py +126 -0
- pymc_extras/statespace/filters/utilities.py +59 -0
- pymc_extras/statespace/models/ETS.py +670 -0
- pymc_extras/statespace/models/SARIMAX.py +536 -0
- pymc_extras/statespace/models/VARMAX.py +393 -0
- pymc_extras/statespace/models/__init__.py +6 -0
- pymc_extras/statespace/models/structural.py +1651 -0
- pymc_extras/statespace/models/utilities.py +387 -0
- pymc_extras/statespace/utils/__init__.py +0 -0
- pymc_extras/statespace/utils/constants.py +74 -0
- pymc_extras/statespace/utils/coord_tools.py +0 -0
- pymc_extras/statespace/utils/data_tools.py +182 -0
- pymc_extras/utils/__init__.py +23 -0
- pymc_extras/utils/linear_cg.py +290 -0
- pymc_extras/utils/pivoted_cholesky.py +69 -0
- pymc_extras/utils/prior.py +200 -0
- pymc_extras/utils/spline.py +131 -0
- pymc_extras/version.py +11 -0
- pymc_extras/version.txt +1 -0
- pymc_extras-0.2.0.dist-info/LICENSE +212 -0
- pymc_extras-0.2.0.dist-info/METADATA +99 -0
- pymc_extras-0.2.0.dist-info/RECORD +101 -0
- pymc_extras-0.2.0.dist-info/WHEEL +5 -0
- pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +13 -0
- tests/distributions/__init__.py +19 -0
- tests/distributions/test_continuous.py +185 -0
- tests/distributions/test_discrete.py +210 -0
- tests/distributions/test_discrete_markov_chain.py +258 -0
- tests/distributions/test_multivariate.py +304 -0
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +131 -0
- tests/model/marginal/test_graph_analysis.py +182 -0
- tests/model/marginal/test_marginal_model.py +867 -0
- tests/model/test_model_api.py +29 -0
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +411 -0
- tests/statespace/test_SARIMAX.py +405 -0
- tests/statespace/test_VARMAX.py +184 -0
- tests/statespace/test_coord_assignment.py +116 -0
- tests/statespace/test_distributions.py +270 -0
- tests/statespace/test_kalman_filter.py +326 -0
- tests/statespace/test_representation.py +175 -0
- tests/statespace/test_statespace.py +818 -0
- tests/statespace/test_statespace_JAX.py +156 -0
- tests/statespace/test_structural.py +829 -0
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +9 -0
- tests/statespace/utilities/statsmodel_local_level.py +42 -0
- tests/statespace/utilities/test_helpers.py +310 -0
- tests/test_blackjax_smc.py +222 -0
- tests/test_find_map.py +98 -0
- tests/test_histogram_approximation.py +109 -0
- tests/test_laplace.py +238 -0
- tests/test_linearmodel.py +208 -0
- tests/test_model_builder.py +306 -0
- tests/test_pathfinder.py +45 -0
- tests/test_pivoted_cholesky.py +24 -0
- tests/test_printing.py +98 -0
- tests/test_prior_from_trace.py +172 -0
- tests/test_splines.py +77 -0
- tests/utils.py +31 -0
@@ -0,0 +1,13 @@ pymc_extras/inference/smc/__init__.py
+# Copyright 2023 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,451 @@ pymc_extras/inference/smc/sampling.py
+# Copyright 2023 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import time
+import warnings
+
+from collections.abc import Callable
+from typing import NamedTuple, cast
+
+import arviz as az
+import blackjax
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from blackjax.smc import extend_params
+from blackjax.smc.resampling import systematic
+from pymc import draw, modelcontext, to_inference_data
+from pymc.backends import NDArray
+from pymc.backends.base import MultiTrace
+from pymc.initial_point import make_initial_point_expression
+from pymc.sampling.jax import get_jaxified_graph
+from pymc.util import RandomState, _get_seeds_per_chain
+
+log = logging.getLogger(__name__)
+
+
+def sample_smc_blackjax(
+    n_particles: int = 2000,
+    random_seed: RandomState = None,
+    kernel: str = "HMC",
+    target_essn: float = 0.5,
+    num_mcmc_steps: int = 10,
+    inner_kernel_params: dict | None = None,
+    model=None,
+    iterations_to_diagnose: int = 100,
+):
+    r"""Sample using BlackJax's implementation of Sequential Monte Carlo.
+
+    Parameters
+    ----------
+    n_particles: int
+        Number of particles used to sample from the posterior; this is also the number of draws. Defaults to 2000.
+    random_seed: RandomState
+        Seed for the random number generator; set it for reproducibility. Otherwise a random seed is used (default).
+    kernel: str
+        Either 'HMC' (default) or 'NUTS'. The kernel used to mutate the particles in each SMC iteration.
+    target_essn: float
+        Proportion (0 < target_essn < 1) of the total number of particles, used for incrementing the exponent
+        of the tempered posterior between iterations. The higher the number, the smaller each increment,
+        leading to more steps and higher computational cost. Defaults to 0.5. See https://arxiv.org/abs/1602.03572
+    num_mcmc_steps: int
+        Fixed number of steps of each inner-kernel Markov chain for each SMC mutation step.
+    inner_kernel_params: dict | None
+        A dictionary with parameters for the inner kernel.
+        For HMC it must have 'step_size' and 'integration_steps';
+        for NUTS it must have 'step_size'.
+        These parameters are fixed for all iterations.
+    model:
+        PyMC model to sample from.
+    iterations_to_diagnose: int
+        Number of iterations to generate diagnostics for. By default, the first 100 iterations are diagnosed.
+        Increase this number for further diagnostics (it can be bigger than the actual number of iterations
+        executed by the algorithm, at the expense of allocating memory to store the diagnostics).
+
+    Returns
+    -------
+    An ArviZ InferenceData object.
+
+    Note
+    ----
+    A summary of the algorithm is:
+
+    1. Initialize :math:`\beta` at zero and stage at zero.
+    2. Generate N samples :math:`S_{\beta}` from the prior (because when :math:`\beta = 0` the
+       tempered posterior is the prior).
+    3. Increase :math:`\beta` in order to make the effective sample size equal some predefined
+       value (target_essn).
+    4. Compute a set of N importance weights W. The weights are computed as the ratio of the
+       likelihoods of a sample at stage i+1 and stage i.
+    5. Obtain :math:`S_{w}` by re-sampling according to W.
+    6. Run N independent MCMC chains, starting each one from a different sample
+       in :math:`S_{w}`. For that, set the kernel and inner_kernel_params.
+    7. The N chains are run for num_mcmc_steps each.
+    8. Repeat from step 3 until :math:`\beta \ge 1`.
+    9. The final result is a collection of N samples from the posterior.
+    """
+    model = modelcontext(model)
+    random_seed = np.random.default_rng(seed=random_seed)
+
+    if inner_kernel_params is None:
+        inner_kernel_params = {}
+
+    log.info(
+        f"Will only diagnose the first {iterations_to_diagnose} SMC iterations; "
+        "this number can be increased by setting the iterations_to_diagnose parameter "
+        "in sample_smc_blackjax"
+    )
+
+    key = jax.random.PRNGKey(_get_seeds_per_chain(random_seed, 1)[0])
+
+    key, initial_particles_key, iterations_key = jax.random.split(key, 3)
+
+    initial_particles = blackjax_particles_from_pymc_population(
+        model, initialize_population(model, n_particles, random_seed)
+    )
+
+    var_map = var_map_from_model(
+        model, model.initial_point(random_seed=random_seed.integers(2**30))
+    )
+
+    posterior_dimensions = sum(var_map[k][1] for k in var_map)
+
+    if kernel == "HMC":
+        mcmc_kernel = blackjax.mcmc.hmc
+        mcmc_parameters = extend_params(
+            dict(
+                step_size=inner_kernel_params["step_size"],
+                inverse_mass_matrix=jnp.eye(posterior_dimensions),
+                num_integration_steps=inner_kernel_params["integration_steps"],
+            )
+        )
+    elif kernel == "NUTS":
+        mcmc_kernel = blackjax.mcmc.nuts
+        mcmc_parameters = extend_params(
+            dict(
+                step_size=inner_kernel_params["step_size"],
+                inverse_mass_matrix=jnp.eye(posterior_dimensions),
+            )
+        )
+    else:
+        raise ValueError(f"Invalid kernel {kernel}, valid options are 'HMC' and 'NUTS'")
+
+    sampler = build_smc_with_kernel(
+        prior_log_prob=get_jaxified_logprior(model),
+        loglikelihood=get_jaxified_loglikelihood(model),
+        target_ess=target_essn,
+        num_mcmc_steps=num_mcmc_steps,
+        kernel_parameters=mcmc_parameters,
+        mcmc_kernel=mcmc_kernel,
+    )
+
+    start = time.time()
+    total_iterations, particles, diagnosis = inference_loop(
+        iterations_key,
+        sampler.init(initial_particles),
+        sampler,
+        iterations_to_diagnose,
+        n_particles,
+    )
+    end = time.time()
+    running_time = end - start
+
+    inference_data = arviz_from_particles(model, particles)
+
+    add_to_inference_data(
+        inference_data,
+        n_particles,
+        target_essn,
+        num_mcmc_steps,
+        kernel,
+        diagnosis,
+        total_iterations,
+        iterations_to_diagnose,
+        inner_kernel_params,
+        running_time,
+    )
+
+    if total_iterations > iterations_to_diagnose:
+        log.warning(
+            f"Only the first {iterations_to_diagnose} of {total_iterations} iterations "
+            "were included in the diagnosed quantities."
+        )
+
+    return inference_data
+
+
+def arviz_from_particles(model, particles):
+    """
+    Given particles in Blackjax format, build an ArviZ InferenceData object.
+
+    In order to do so in a consistent way, particles are assumed to be
+    encoded in model.value_vars order.
+
+    Parameters
+    ----------
+    model: PyMC model
+    particles: output of Blackjax SMC
+
+    Returns
+    -------
+    An ArviZ InferenceData object.
+    """
+    n_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
+    by_varname = {
+        k.name: v.squeeze()[np.newaxis, :].astype(k.dtype)
+        for k, v in zip(model.value_vars, particles)
+    }
+    varnames = [v.name for v in model.value_vars]
+    with model:
+        strace = NDArray(name=model.name)
+        strace.setup(n_particles, 0)
+    for particle_index in range(0, n_particles):
+        strace.record(point={k: np.asarray(by_varname[k][0][particle_index]) for k in varnames})
+    multitrace = MultiTrace((strace,))
+    return to_inference_data(multitrace, log_likelihood=False)
+
+
+class SMCDiagnostics(NamedTuple):
+    """
+    A JAX-compilable object to track quantities of interest of an SMC run.
+    Note that initial_diagnosis and update_diagnosis must return copies
+    rather than modify in place for the class to be JAX-compilable, which
+    is why they are static methods.
+    """
+
+    lmbda_evolution: jax.Array
+    log_likelihood_increment_evolution: jax.Array
+    ancestors_evolution: jax.Array
+    weights_evolution: jax.Array
+
+    @staticmethod
+    def update_diagnosis(i, history, info, state):
+        le, lli, ancestors, weights_evolution = history
+        return SMCDiagnostics(
+            le.at[i].set(state.lmbda),
+            lli.at[i].set(info.log_likelihood_increment),
+            ancestors.at[i].set(info.ancestors),
+            weights_evolution.at[i].set(state.weights),
+        )
+
+    @staticmethod
+    def initial_diagnosis(iterations_to_diagnose, n_particles):
+        return SMCDiagnostics(
+            jnp.zeros(iterations_to_diagnose),
+            jnp.zeros(iterations_to_diagnose),
+            jnp.zeros((iterations_to_diagnose, n_particles)),
+            jnp.zeros((iterations_to_diagnose, n_particles)),
+        )
+
+
+def flatten_single_particle(particle):
+    return jnp.hstack([v.squeeze() for v in particle])
+
+
+def inference_loop(rng_key, initial_state, kernel, iterations_to_diagnose, n_particles):
+    """
+    SMC inference loop that keeps track of diagnostic quantities.
+    """
+
+    def cond(carry):
+        i, state, _, _ = carry
+        return state.lmbda < 1
+
+    def one_step(carry):
+        i, state, k, previous_info = carry
+        k, subk = jax.random.split(k, 2)
+        state, info = kernel.step(subk, state)
+        full_info = SMCDiagnostics.update_diagnosis(i, previous_info, info, state)
+
+        return i + 1, state, k, full_info
+
+    n_iter, final_state, _, diagnosis = jax.lax.while_loop(
+        cond,
+        one_step,
+        (
+            0,
+            initial_state,
+            rng_key,
+            SMCDiagnostics.initial_diagnosis(iterations_to_diagnose, n_particles),
+        ),
+    )
+
+    return n_iter, final_state.particles, diagnosis
+
+
+def blackjax_particles_from_pymc_population(model, pymc_population):
+    """
+    Transform a PyMC population of particles into the format accepted by
+    Blackjax. Particles must be a PyTree; each leaf represents a variable
+    from the posterior and is an array of size n_particles * the
+    variable's dimensionality.
+    Note that the order in which variables are stored in the PyTree
+    must be the same order used to calculate the logprior and loglikelihood.
+
+    Parameters
+    ----------
+    pymc_population : A dictionary with variables as keys, and arrays
+        with samples as values.
+    """
+
+    order_of_vars = model.value_vars
+
+    def _format(var):
+        variable = pymc_population[var.name]
+        if len(variable.shape) == 1:
+            return variable[:, np.newaxis]
+        else:
+            return variable
+
+    return [_format(var) for var in order_of_vars]
+
+
+def add_to_inference_data(
+    inference_data: az.InferenceData,
+    n_particles: int,
+    target_ess: float,
+    num_mcmc_steps: int,
+    kernel: str,
+    diagnosis: SMCDiagnostics,
+    total_iterations: int,
+    iterations_to_diagnose: int,
+    kernel_parameters: dict,
+    running_time_seconds: float,
+):
+    """
+    Add several SMC parameters to the az.InferenceData result.
+
+    Parameters
+    ----------
+    inference_data: ArviZ object to add attributes to.
+    n_particles: number of particles present in the result.
+    target_ess: target effective sample size between SMC iterations, used
+        to calculate the tempering exponent.
+    num_mcmc_steps: number of steps of the inner kernel when mutating particles.
+    kernel: string representing the kernel used to mutate particles.
+    diagnosis: SMCDiagnostics, containing quantities of interest for the full
+        SMC run.
+    total_iterations: the total number of iterations executed by the sampler.
+    iterations_to_diagnose: the number of iterations represented in the diagnosed
+        quantities.
+    kernel_parameters: dict of parameters for the inner kernel used to mutate particles.
+    running_time_seconds: float sampling time.
+    """
+    experiment_parameters = {
+        "particles": n_particles,
+        "target_ess": target_ess,
+        "num_mcmc_steps": num_mcmc_steps,
+        "iterations": total_iterations,
+        "iterations_to_diagnose": iterations_to_diagnose,
+        "sampler": f"Blackjax SMC with {kernel} kernel",
+    }
+
+    inference_data.posterior.attrs["lambda_evolution"] = np.array(diagnosis.lmbda_evolution)[
+        :iterations_to_diagnose
+    ]
+    inference_data.posterior.attrs["log_likelihood_increments"] = np.array(
+        diagnosis.log_likelihood_increment_evolution
+    )[:iterations_to_diagnose]
+    inference_data.posterior.attrs["ancestors_evolution"] = np.array(diagnosis.ancestors_evolution)[
+        :iterations_to_diagnose
+    ]
+    inference_data.posterior.attrs["weights_evolution"] = np.array(diagnosis.weights_evolution)[
+        :iterations_to_diagnose
+    ]
+
+    for k in experiment_parameters:
+        inference_data.posterior.attrs[k] = experiment_parameters[k]
+
+    for k in kernel_parameters:
+        inference_data.posterior.attrs[k] = kernel_parameters[k]
+
+    inference_data.posterior.attrs["running_time_seconds"] = running_time_seconds
+
+    return inference_data
+
+
+def get_jaxified_logprior(model) -> Callable:
+    return get_jaxified_particles_fn(model, model.varlogp)
+
+
+def get_jaxified_loglikelihood(model) -> Callable:
+    return get_jaxified_particles_fn(model, model.datalogp)
+
+
+def get_jaxified_particles_fn(model, graph_outputs):
+    """
+    Build a jaxified version of a value_vars function that can be applied
+    to the Blackjax particles format.
+    """
+    logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[graph_outputs])
+
+    def logp_fn_wrap(particles):
+        return logp_fn(*[p.squeeze() for p in particles])[0]
+
+    return logp_fn_wrap
+
+
+def initialize_population(model, draws, random_seed) -> dict[str, np.ndarray]:
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", category=UserWarning, message="The effect of Potentials")
+
+        prior_expression = make_initial_point_expression(
+            free_rvs=model.free_RVs,
+            rvs_to_transforms=model.rvs_to_transforms,
+            initval_strategies={},
+            default_strategy="prior",
+            return_transformed=True,
+        )
+        prior_values = draw(prior_expression, draws=draws, random_seed=random_seed)
+
+    names = [model.rvs_to_values[rv].name for rv in model.free_RVs]
+    dict_prior = {k: np.stack(v) for k, v in zip(names, prior_values)}
+
+    return cast(dict[str, np.ndarray], dict_prior)
+
+
+def var_map_from_model(model, initial_point) -> dict:
+    """
+    Compute a dictionary that maps variable names to tuples (shape, size).
+    """
+    var_info = {}
+    for v in model.value_vars:
+        var_info[v.name] = (initial_point[v.name].shape, initial_point[v.name].size)
+    return var_info
+
+
+def build_smc_with_kernel(
+    prior_log_prob,
+    loglikelihood,
+    target_ess,
+    num_mcmc_steps,
+    kernel_parameters,
+    mcmc_kernel,
+):
+    return blackjax.adaptive_tempered_smc(
+        prior_log_prob,
+        loglikelihood,
+        mcmc_kernel.build_kernel(),
+        mcmc_kernel.init,
+        mcmc_parameters=kernel_parameters,
+        resampling_fn=systematic,
+        target_ess=target_ess,
+        num_mcmc_steps=num_mcmc_steps,
+    )
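For orientation, here is a minimal usage sketch of `sample_smc_blackjax` from the file above on a toy model. The model, data, and kernel parameter values are illustrative, not part of the package; only the function, its keyword arguments, and the posterior attributes come from the code shown.

```python
import numpy as np
import pymc as pm

from pymc_extras.inference.smc.sampling import sample_smc_blackjax

# Toy data: noisy observations around an unknown mean (illustrative only).
rng = np.random.default_rng(0)
observed = rng.normal(loc=1.5, scale=0.5, size=200)

with pm.Model():
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=observed)

    # The HMC inner kernel requires 'step_size' and 'integration_steps';
    # NUTS would only require 'step_size'.
    idata = sample_smc_blackjax(
        n_particles=1_000,
        kernel="HMC",
        inner_kernel_params={"step_size": 0.1, "integration_steps": 10},
        random_seed=42,
    )

# Diagnostics are attached as posterior attributes by add_to_inference_data.
print(idata.posterior.attrs["iterations"], idata.posterior.attrs["running_time_seconds"])
```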
@@ -0,0 +1,130 @@ pymc_extras/linearmodel.py
+import numpy as np
+import pandas as pd
+import pymc as pm
+
+from pymc_extras.model_builder import ModelBuilder
+
+
+class LinearModel(ModelBuilder):
+    """
+    An implementation of a single-input linear regression model in PyMC using the
+    ModelBuilder base class for interoperability with scikit-learn.
+    """
+
+    _model_type = "LinearModel"
+    version = "0.1"
+
+    def __init__(
+        self, model_config: dict | None = None, sampler_config: dict | None = None, nsamples=100
+    ):
+        self.nsamples = nsamples
+        super().__init__(model_config, sampler_config)
+
+    @staticmethod
+    def get_default_model_config():
+        return {
+            "intercept": {"loc": 0, "scale": 10},
+            "slope": {"loc": 0, "scale": 10},
+            "obs_error": 2,
+        }
+
+    @staticmethod
+    def get_default_sampler_config():
+        return {
+            "draws": 1_000,
+            "tune": 1_000,
+            "chains": 3,
+            "target_accept": 0.95,
+        }
+
+    @property
+    def _serializable_model_config(self) -> dict:
+        return self.model_config
+
+    @property
+    def output_var(self):
+        return "y_hat"
+
+    def build_model(self, X: pd.DataFrame, y: pd.Series):
+        """
+        Build the PyMC model.
+
+        Returns
+        -------
+        None
+
+        Examples
+        --------
+        >>> self.build_model()
+        >>> assert self.model is not None
+        >>> assert isinstance(self.model, pm.Model)
+        >>> assert "intercept" in self.model.named_vars
+        >>> assert "slope" in self.model.named_vars
+        >>> assert "σ_model_fmc" in self.model.named_vars
+        >>> assert "y_model" in self.model.named_vars
+        >>> assert "y_hat" in self.model.named_vars
+        >>> assert self.output_var == "y_hat"
+        """
+        cfg = self.model_config
+
+        # Data array size can change but number of dimensions must stay the same.
+        with pm.Model() as self.model:
+            x = pm.Data("x", np.zeros((1,)), dims="observation")
+            y_data = pm.Data("y_data", np.zeros((1,)), dims="observation")
+
+            # priors
+            intercept = pm.Normal(
+                "intercept", cfg["intercept"]["loc"], sigma=cfg["intercept"]["scale"]
+            )
+            slope = pm.Normal("slope", cfg["slope"]["loc"], sigma=cfg["slope"]["scale"])
+            obs_error = pm.HalfNormal("σ_model_fmc", cfg["obs_error"])
+
+            # Model
+            y_model = pm.Deterministic("y_model", intercept + slope * x, dims="observation")
+
+            # observed data
+            pm.Normal(
+                "y_hat",
+                y_model,
+                sigma=obs_error,
+                shape=x.shape,
+                observed=y_data,
+                dims="observation",
+            )
+
+        self._data_setter(X, y)
+
+    def _data_setter(self, X: pd.DataFrame, y: pd.DataFrame | pd.Series | None = None):
+        with self.model:
+            pm.set_data({"x": X.squeeze()})
+            if y is not None:
+                pm.set_data({"y_data": y.squeeze()})
+
+    def _generate_and_preprocess_model_data(
+        self, X: pd.DataFrame | pd.Series, y: pd.Series
+    ) -> None:
+        """
+        Store the training data on the instance for later use by build_model.
+        """
+        self.X, self.y = X, y
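And a usage sketch for `LinearModel`, assuming the scikit-learn-style `fit`/`predict` interface that the ModelBuilder base class is documented to provide; the synthetic data and the `fit`/`predict` call signatures here are illustrative assumptions, not taken from the diff above.

```python
import numpy as np
import pandas as pd

from pymc_extras.linearmodel import LinearModel

# Synthetic single-feature data (illustrative only).
rng = np.random.default_rng(0)
X = pd.DataFrame({"input": rng.uniform(0, 10, size=100)})
y = pd.Series(1.0 + 0.5 * X["input"] + rng.normal(0, 0.3, size=100), name="output")

model = LinearModel()
idata = model.fit(X, y)    # sample the posterior with the default sampler config
y_pred = model.predict(X)  # point predictions for the "y_hat" output variable
```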