pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
@@ -0,0 +1,13 @@
+ # Copyright 2023 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
@@ -0,0 +1,451 @@
+ # Copyright 2023 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ import time
+ import warnings
+
+ from collections.abc import Callable
+ from typing import NamedTuple, cast
+
+ import arviz as az
+ import blackjax
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ from blackjax.smc import extend_params
+ from blackjax.smc.resampling import systematic
+ from pymc import draw, modelcontext, to_inference_data
+ from pymc.backends import NDArray
+ from pymc.backends.base import MultiTrace
+ from pymc.initial_point import make_initial_point_expression
+ from pymc.sampling.jax import get_jaxified_graph
+ from pymc.util import RandomState, _get_seeds_per_chain
+
+ log = logging.getLogger(__name__)
+
+
+ def sample_smc_blackjax(
+     n_particles: int = 2000,
+     random_seed: RandomState = None,
+     kernel: str = "HMC",
+     target_essn: float = 0.5,
+     num_mcmc_steps: int = 10,
+     inner_kernel_params: dict | None = None,
+     model=None,
+     iterations_to_diagnose: int = 100,
+ ):
+     r"""Sample using BlackJax's implementation of Sequential Monte Carlo.
+
+     Parameters
+     ----------
+     n_particles: int
+         Number of particles used to sample from the posterior. This is also the number of
+         draws. Defaults to 2000.
+     random_seed: RandomState
+         Seed for the random number generator; set it for reproducibility. By default a
+         random seed is used.
+     kernel: str
+         Either 'HMC' (default) or 'NUTS'. The kernel used to mutate the particles in each
+         SMC iteration.
+     target_essn: float
+         Proportion (0 < target_essn < 1) of the total number of particles, used to increment
+         the exponent of the tempered posterior between iterations. The higher the value, the
+         smaller each increment, leading to more steps and a higher computational cost.
+         Defaults to 0.5. See https://arxiv.org/abs/1602.03572
+     num_mcmc_steps: int
+         Fixed number of steps that each inner-kernel Markov chain runs in each SMC mutation step.
+     inner_kernel_params: Optional[dict]
+         A dictionary with parameters for the inner kernel.
+         For HMC it must contain 'step_size' and 'integration_steps';
+         for NUTS it must contain 'step_size'.
+         These parameters are fixed across iterations.
+     model:
+         PyMC model to sample from.
+     iterations_to_diagnose: int
+         Number of iterations to generate diagnostics for. By default, the first 100 iterations
+         are diagnosed. Increase this number for further diagnosis (it can be larger than the
+         actual number of iterations executed by the algorithm, at the expense of allocating
+         memory to store the diagnostics).
+
+     Returns
+     -------
+     An ArviZ InferenceData object.
+
+     Note
+     ----
+     A summary of the algorithm is:
+
+     1. Initialize :math:`\beta` at zero and stage at zero.
+     2. Generate N samples :math:`S_{\beta}` from the prior (because when :math:`\beta = 0`
+        the tempered posterior is the prior).
+     3. Increase :math:`\beta` so that the effective sample size equals some predefined
+        value (target_essn).
+     4. Compute a set of N importance weights W. The weights are computed as the ratio of the
+        likelihoods of a sample at stage i+1 and stage i.
+     5. Obtain :math:`S_{w}` by re-sampling according to W.
+     6. Run N independent MCMC chains, starting each one from a different sample
+        in :math:`S_{w}`. For that, set the kernel and inner_kernel_params.
+     7. The N chains are run for num_mcmc_steps each.
+     8. Repeat from step 3 until :math:`\beta \ge 1`.
+     9. The final result is a collection of N samples from the posterior.
+     """
+
+     model = modelcontext(model)
+     random_seed = np.random.default_rng(seed=random_seed)
+
+     if inner_kernel_params is None:
+         inner_kernel_params = {}
+
+     log.info(
+         f"Will only diagnose the first {iterations_to_diagnose} SMC iterations; "
+         "this number can be increased by setting the iterations_to_diagnose parameter "
+         "of sample_smc_blackjax"
+     )
+
+     key = jax.random.PRNGKey(_get_seeds_per_chain(random_seed, 1)[0])
+
+     key, initial_particles_key, iterations_key = jax.random.split(key, 3)
+
+     initial_particles = blackjax_particles_from_pymc_population(
+         model, initialize_population(model, n_particles, random_seed)
+     )
+
+     var_map = var_map_from_model(
+         model, model.initial_point(random_seed=random_seed.integers(2**30))
+     )
+
+     posterior_dimensions = sum(var_map[k][1] for k in var_map)
+
+     if kernel == "HMC":
+         mcmc_kernel = blackjax.mcmc.hmc
+         mcmc_parameters = extend_params(
+             dict(
+                 step_size=inner_kernel_params["step_size"],
+                 inverse_mass_matrix=jnp.eye(posterior_dimensions),
+                 num_integration_steps=inner_kernel_params["integration_steps"],
+             )
+         )
+     elif kernel == "NUTS":
+         mcmc_kernel = blackjax.mcmc.nuts
+         mcmc_parameters = extend_params(
+             dict(
+                 step_size=inner_kernel_params["step_size"],
+                 inverse_mass_matrix=jnp.eye(posterior_dimensions),
+             )
+         )
+     else:
+         raise ValueError(f"Invalid kernel {kernel}, valid options are 'HMC' and 'NUTS'")
+
+     sampler = build_smc_with_kernel(
+         prior_log_prob=get_jaxified_logprior(model),
+         loglikelihood=get_jaxified_loglikelihood(model),
+         target_ess=target_essn,
+         num_mcmc_steps=num_mcmc_steps,
+         kernel_parameters=mcmc_parameters,
+         mcmc_kernel=mcmc_kernel,
+     )
+
+     start = time.time()
+     total_iterations, particles, diagnosis = inference_loop(
+         iterations_key,
+         sampler.init(initial_particles),
+         sampler,
+         iterations_to_diagnose,
+         n_particles,
+     )
+     end = time.time()
+     running_time = end - start
+
+     inference_data = arviz_from_particles(model, particles)
+
+     add_to_inference_data(
+         inference_data,
+         n_particles,
+         target_essn,
+         num_mcmc_steps,
+         kernel,
+         diagnosis,
+         total_iterations,
+         iterations_to_diagnose,
+         inner_kernel_params,
+         running_time,
+     )
+
+     if total_iterations > iterations_to_diagnose:
+         log.warning(
+             f"Only the first {iterations_to_diagnose} iterations out of {total_iterations} "
+             "were included in the diagnosed quantities."
+         )
+
+     return inference_data
+
+
+ def arviz_from_particles(model, particles):
+     """
+     Given particles in BlackJax format, build an ArviZ InferenceData object.
+     To do this consistently, particles are assumed to be encoded in
+     model.value_vars order.
+
+     Parameters
+     ----------
+     model: PyMC model
+     particles: output of BlackJax SMC
+
+     Returns
+     -------
+     An ArviZ InferenceData object
+     """
+     n_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0]
+     by_varname = {
+         k.name: v.squeeze()[np.newaxis, :].astype(k.dtype)
+         for k, v in zip(model.value_vars, particles)
+     }
+     varnames = [v.name for v in model.value_vars]
+     with model:
+         strace = NDArray(name=model.name)
+         strace.setup(n_particles, 0)
+     for particle_index in range(n_particles):
+         strace.record(point={k: np.asarray(by_varname[k][0][particle_index]) for k in varnames})
+     multitrace = MultiTrace((strace,))
+     return to_inference_data(multitrace, log_likelihood=False)
+
+
+ class SMCDiagnostics(NamedTuple):
+     """
+     A JAX-compilable object that tracks quantities of interest of an SMC run.
+     Note that initial_diagnosis and update_diagnosis must return copies
+     and not modify in place for the class to be JAX-compilable, which is
+     why they are static methods.
+     """
+
+     lmbda_evolution: jax.Array
+     log_likelihood_increment_evolution: jax.Array
+     ancestors_evolution: jax.Array
+     weights_evolution: jax.Array
+
+     @staticmethod
+     def update_diagnosis(i, history, info, state):
+         le, lli, ancestors, weights_evolution = history
+         return SMCDiagnostics(
+             le.at[i].set(state.lmbda),
+             lli.at[i].set(info.log_likelihood_increment),
+             ancestors.at[i].set(info.ancestors),
+             weights_evolution.at[i].set(state.weights),
+         )
+
+     @staticmethod
+     def initial_diagnosis(iterations_to_diagnose, n_particles):
+         return SMCDiagnostics(
+             jnp.zeros(iterations_to_diagnose),
+             jnp.zeros(iterations_to_diagnose),
+             jnp.zeros((iterations_to_diagnose, n_particles)),
+             jnp.zeros((iterations_to_diagnose, n_particles)),
+         )
+
+
+ def flatten_single_particle(particle):
+     return jnp.hstack([v.squeeze() for v in particle])
+
+
+ def inference_loop(rng_key, initial_state, kernel, iterations_to_diagnose, n_particles):
+     """
+     SMC inference loop that keeps track of diagnostic quantities.
+     """
+
+     def cond(carry):
+         i, state, _, _ = carry
+         return state.lmbda < 1
+
+     def one_step(carry):
+         i, state, k, previous_info = carry
+         k, subk = jax.random.split(k, 2)
+         state, info = kernel.step(subk, state)
+         full_info = SMCDiagnostics.update_diagnosis(i, previous_info, info, state)
+
+         return i + 1, state, k, full_info
+
+     n_iter, final_state, _, diagnosis = jax.lax.while_loop(
+         cond,
+         one_step,
+         (
+             0,
+             initial_state,
+             rng_key,
+             SMCDiagnostics.initial_diagnosis(iterations_to_diagnose, n_particles),
+         ),
+     )
+
+     return n_iter, final_state.particles, diagnosis
+
+
+ def blackjax_particles_from_pymc_population(model, pymc_population):
+     """
+     Transform a PyMC population of particles into the format accepted
+     by BlackJax. Particles must be a PyTree; each leaf represents a
+     variable from the posterior and is an array of size n_particles
+     * the variable's dimensionality.
+     Note that the order in which variables are stored in the PyTree
+     must be the same order used to calculate the logprior and loglikelihood.
+
+     Parameters
+     ----------
+     pymc_population : A dictionary with variables as keys, and arrays
+         with samples as values.
+     """
+
+     order_of_vars = model.value_vars
+
+     def _format(var):
+         variable = pymc_population[var.name]
+         if len(variable.shape) == 1:
+             return variable[:, np.newaxis]
+         else:
+             return variable
+
+     return [_format(var) for var in order_of_vars]
+
+
+ def add_to_inference_data(
+     inference_data: az.InferenceData,
+     n_particles: int,
+     target_ess: float,
+     num_mcmc_steps: int,
+     kernel: str,
+     diagnosis: SMCDiagnostics,
+     total_iterations: int,
+     iterations_to_diagnose: int,
+     kernel_parameters: dict,
+     running_time_seconds: float,
+ ):
+     """
+     Add several SMC parameters and diagnostics to the az.InferenceData result.
+
+     Parameters
+     ----------
+     inference_data: ArviZ object to add attributes to
+     n_particles: number of particles present in the result
+     target_ess: target effective sample size between SMC iterations, used
+         to calculate the tempering exponent
+     num_mcmc_steps: number of steps of the inner kernel when mutating particles
+     kernel: string representing the kernel used to mutate particles
+     diagnosis: SMCDiagnostics containing quantities of interest for the full
+         SMC run
+     total_iterations: the total number of iterations executed by the sampler
+     iterations_to_diagnose: the number of iterations represented in the diagnosed
+         quantities
+     kernel_parameters: dict of inner-kernel parameters used to mutate particles
+     running_time_seconds: sampling time in seconds
+     """
+     experiment_parameters = {
+         "particles": n_particles,
+         "target_ess": target_ess,
+         "num_mcmc_steps": num_mcmc_steps,
+         "iterations": total_iterations,
+         "iterations_to_diagnose": iterations_to_diagnose,
+         "sampler": f"Blackjax SMC with {kernel} kernel",
+     }
+
+     inference_data.posterior.attrs["lambda_evolution"] = np.array(diagnosis.lmbda_evolution)[
+         :iterations_to_diagnose
+     ]
+     inference_data.posterior.attrs["log_likelihood_increments"] = np.array(
+         diagnosis.log_likelihood_increment_evolution
+     )[:iterations_to_diagnose]
+     inference_data.posterior.attrs["ancestors_evolution"] = np.array(diagnosis.ancestors_evolution)[
+         :iterations_to_diagnose
+     ]
+     inference_data.posterior.attrs["weights_evolution"] = np.array(diagnosis.weights_evolution)[
+         :iterations_to_diagnose
+     ]
+
+     for k in experiment_parameters:
+         inference_data.posterior.attrs[k] = experiment_parameters[k]
+
+     for k in kernel_parameters:
+         inference_data.posterior.attrs[k] = kernel_parameters[k]
+
+     inference_data.posterior.attrs["running_time_seconds"] = running_time_seconds
+
+     return inference_data
+
+
+ def get_jaxified_logprior(model) -> Callable:
+     return get_jaxified_particles_fn(model, model.varlogp)
+
+
+ def get_jaxified_loglikelihood(model) -> Callable:
+     return get_jaxified_particles_fn(model, model.datalogp)
+
+
+ def get_jaxified_particles_fn(model, graph_outputs):
+     """
+     Build a jaxified version of a value_vars function
+     that can be applied to the BlackJax particles format.
+     """
+     logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[graph_outputs])
+
+     def logp_fn_wrap(particles):
+         return logp_fn(*[p.squeeze() for p in particles])[0]
+
+     return logp_fn_wrap
+
+
+ def initialize_population(model, draws, random_seed) -> dict[str, np.ndarray]:
+     with warnings.catch_warnings():
+         warnings.filterwarnings("ignore", category=UserWarning, message="The effect of Potentials")
+
+         prior_expression = make_initial_point_expression(
+             free_rvs=model.free_RVs,
+             rvs_to_transforms=model.rvs_to_transforms,
+             initval_strategies={},
+             default_strategy="prior",
+             return_transformed=True,
+         )
+         prior_values = draw(prior_expression, draws=draws, random_seed=random_seed)
+
+         names = [model.rvs_to_values[rv].name for rv in model.free_RVs]
+         dict_prior = {k: np.stack(v) for k, v in zip(names, prior_values)}
+
+     return cast(dict[str, np.ndarray], dict_prior)
+
+
+ def var_map_from_model(model, initial_point) -> dict:
+     """
+     Compute a dictionary that maps variable names
+     to tuples (shape, size).
+     """
+     var_info = {}
+     for v in model.value_vars:
+         var_info[v.name] = (initial_point[v.name].shape, initial_point[v.name].size)
+     return var_info
+
+
+ def build_smc_with_kernel(
+     prior_log_prob,
+     loglikelihood,
+     target_ess,
+     num_mcmc_steps,
+     kernel_parameters,
+     mcmc_kernel,
+ ):
+     return blackjax.adaptive_tempered_smc(
+         prior_log_prob,
+         loglikelihood,
+         mcmc_kernel.build_kernel(),
+         mcmc_kernel.init,
+         mcmc_parameters=kernel_parameters,
+         resampling_fn=systematic,
+         target_ess=target_ess,
+         num_mcmc_steps=num_mcmc_steps,
+     )
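
For orientation, a minimal usage sketch of the sampler added above follows. The toy model, its observed data, and the inner-kernel values are illustrative assumptions, not part of the package; the import path mirrors the file location shown in the file list.

import pymc as pm

from pymc_extras.inference.smc.sampling import sample_smc_blackjax

# A toy model: any PyMC model with free random variables will do.
with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=[0.1, -0.3, 0.2])

# With kernel="HMC", inner_kernel_params must supply 'step_size' and
# 'integration_steps'; with kernel="NUTS", only 'step_size'.
idata = sample_smc_blackjax(
    n_particles=1000,
    kernel="HMC",
    inner_kernel_params={"step_size": 0.1, "integration_steps": 10},
    model=model,
)

# Diagnostics are attached as posterior attributes by add_to_inference_data.
print(idata.posterior.attrs["iterations"], idata.posterior.attrs["running_time_seconds"])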
@@ -0,0 +1,130 @@
+ import numpy as np
+ import pandas as pd
+ import pymc as pm
+
+ from pymc_extras.model_builder import ModelBuilder
+
+
+ class LinearModel(ModelBuilder):
+     """
+     An implementation of a single-input linear regression model in PyMC, built on the
+     ModelBuilder base class for interoperability with scikit-learn.
+     """
+
+     _model_type = "LinearModel"
+     version = "0.1"
+
+     def __init__(
+         self, model_config: dict | None = None, sampler_config: dict | None = None, nsamples=100
+     ):
+         self.nsamples = nsamples
+         super().__init__(model_config, sampler_config)
+
+     @staticmethod
+     def get_default_model_config():
+         return {
+             "intercept": {"loc": 0, "scale": 10},
+             "slope": {"loc": 0, "scale": 10},
+             "obs_error": 2,
+         }
+
+     @staticmethod
+     def get_default_sampler_config():
+         return {
+             "draws": 1_000,
+             "tune": 1_000,
+             "chains": 3,
+             "target_accept": 0.95,
+         }
+
+     @property
+     def _serializable_model_config(self) -> dict:
+         return self.model_config
+
+     @property
+     def output_var(self):
+         return "y_hat"
+
+     def build_model(self, X: pd.DataFrame, y: pd.Series):
+         """
+         Build the PyMC model.
+
+         Returns
+         -------
+         None
+
+         Examples
+         --------
+         >>> self.build_model(X, y)
+         >>> assert self.model is not None
+         >>> assert isinstance(self.model, pm.Model)
+         >>> assert "intercept" in self.model.named_vars
+         >>> assert "slope" in self.model.named_vars
+         >>> assert "σ_model_fmc" in self.model.named_vars
+         >>> assert "y_model" in self.model.named_vars
+         >>> assert "y_hat" in self.model.named_vars
+         >>> assert self.output_var == "y_hat"
+         """
+         cfg = self.model_config
+
+         # The data array size can change, but the number of dimensions must stay the same.
+         with pm.Model() as self.model:
+             x = pm.Data("x", np.zeros((1,)), dims="observation")
+             y_data = pm.Data("y_data", np.zeros((1,)), dims="observation")
+
+             # Priors
+             intercept = pm.Normal(
+                 "intercept", cfg["intercept"]["loc"], sigma=cfg["intercept"]["scale"]
+             )
+             slope = pm.Normal("slope", cfg["slope"]["loc"], sigma=cfg["slope"]["scale"])
+             obs_error = pm.HalfNormal("σ_model_fmc", cfg["obs_error"])
+
+             # Model
+             y_model = pm.Deterministic("y_model", intercept + slope * x, dims="observation")
+
+             # Observed data
+             pm.Normal(
+                 "y_hat",
+                 y_model,
+                 sigma=obs_error,
+                 shape=x.shape,
+                 observed=y_data,
+                 dims="observation",
+             )
+
+         self._data_setter(X, y)
+
+     def _data_setter(self, X: pd.DataFrame, y: pd.DataFrame | pd.Series | None = None):
+         with self.model:
+             pm.set_data({"x": X.squeeze()})
+             if y is not None:
+                 pm.set_data({"y_data": y.squeeze()})
+
+     def _generate_and_preprocess_model_data(
+         self, X: pd.DataFrame | pd.Series, y: pd.Series
+     ) -> None:
+         """
+         Store the training data for the linear regression model.
+
+         Parameters
+         ----------
+         X : pd.DataFrame | pd.Series
+             The feature matrix.
+         y : pd.Series
+             The target vector.
+
+         Returns
+         -------
+         None
+         """
+         self.X, self.y = X, y
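
A usage sketch for the new LinearModel follows. It assumes the fit/predict interface inherited from the ModelBuilder base class (per the scikit-learn interoperability noted in the class docstring); the synthetic data is illustrative.

import numpy as np
import pandas as pd

from pymc_extras.linearmodel import LinearModel

# Synthetic straight-line data with Gaussian noise (illustrative only).
rng = np.random.default_rng(0)
X = pd.DataFrame({"input": np.linspace(0.0, 1.0, 100)})
y = pd.Series(0.5 + 2.0 * X["input"] + rng.normal(0.0, 0.1, size=100))

lm = LinearModel()
idata = lm.fit(X, y)   # samples the posterior using the default sampler config
preds = lm.predict(X)  # predictions for the output variable "y_hat"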