pymc-extras 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
pymc_extras/inference/laplace.py
@@ -0,0 +1,570 @@
+ # Copyright 2024 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import logging
+
+ from functools import reduce
+ from itertools import product
+ from typing import Literal
+
+ import arviz as az
+ import numpy as np
+ import pymc as pm
+ import pytensor
+ import pytensor.tensor as pt
+ import xarray as xr
+
+ from arviz import dict_to_dataset
+ from better_optimize.constants import minimize_method
+ from pymc import DictToArrayBijection
+ from pymc.backends.arviz import (
+     coords_and_dims_for_inferencedata,
+     find_constants,
+     find_observations,
+ )
+ from pymc.blocking import RaveledVars
+ from pymc.model.transform.conditioning import remove_value_transforms
+ from pymc.model.transform.optimization import freeze_dims_and_data
+ from pymc.util import get_default_varnames
+ from scipy import stats
+
+ from pymc_extras.inference.find_map import (
+     GradientBackend,
+     _unconstrained_vector_to_constrained_rvs,
+     find_MAP,
+     get_nearest_psd,
+     scipy_optimize_funcs_from_loss,
+ )
+
+ _log = logging.getLogger(__name__)
+
+
+ def laplace_draws_to_inferencedata(
+     posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None
+ ) -> az.InferenceData:
+     """
+     Convert draws from a posterior estimated with the Laplace approximation to an InferenceData object.
+
+
+     Parameters
+     ----------
+     posterior_draws: list of np.ndarray
+         A list of arrays containing the posterior draws. Each array should have shape (chains, draws, *shape), where
+         shape is the shape of the variable in the posterior.
+     model: Model, optional
+         A PyMC model. If None, the model is taken from the current model context.
+
+     Returns
+     -------
+     idata: az.InferenceData
+         An InferenceData object containing the approximated posterior samples
+     """
+     model = pm.modelcontext(model)
+     chains, draws, *_ = posterior_draws[0].shape
+
+     def make_rv_coords(name):
+         coords = {"chain": range(chains), "draw": range(draws)}
+         extra_dims = model.named_vars_to_dims.get(name)
+         if extra_dims is None:
+             return coords
+         return coords | {dim: list(model.coords[dim]) for dim in extra_dims}
+
+     def make_rv_dims(name):
+         dims = ["chain", "draw"]
+         extra_dims = model.named_vars_to_dims.get(name)
+         if extra_dims is None:
+             return dims
+         return dims + list(extra_dims)
+
+     names = [
+         x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
+     ]
+     idata = {
+         name: xr.DataArray(
+             data=draws,
+             coords=make_rv_coords(name),
+             dims=make_rv_dims(name),
+             name=name,
+         )
+         for name, draws in zip(names, posterior_draws)
+     }
+
+     coords, dims = coords_and_dims_for_inferencedata(model)
+     idata = az.convert_to_inference_data(idata, coords=coords, dims=dims)
+
+     return idata
+
+
+ def add_fit_to_inferencedata(
+     idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None
+ ) -> az.InferenceData:
+     """
+     Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object.
+
+
+     Parameters
+     ----------
+     idata: az.InferenceData
+         An InferenceData object containing the approximated posterior samples.
+     mu: RaveledVars
+         The MAP estimate of the model parameters.
+     H_inv: np.ndarray
+         The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
+     model: Model, optional
+         A PyMC model. If None, the model is taken from the current model context.
+
+     Returns
+     -------
+     idata: az.InferenceData
+         The provided InferenceData, with the mean vector and covariance matrix added to the "fit" group.
+     """
+     model = pm.modelcontext(model)
+     coords = model.coords
+
+     variable_names, *_ = zip(*mu.point_map_info)
+
+     def make_unpacked_variable_names(name):
+         value_to_dim = {
+             x.name: model.named_vars_to_dims.get(model.values_to_rvs[x].name, None)
+             for x in model.value_vars
+         }
+         value_to_dim = {k: v for k, v in value_to_dim.items() if v is not None}
+
+         rv_to_dim = model.named_vars_to_dims
+         dims_dict = rv_to_dim | value_to_dim
+
+         dims = dims_dict.get(name)
+         if dims is None:
+             return [name]
+         labels = product(*(coords[dim] for dim in dims))
+         return [f"{name}[{','.join(map(str, label))}]" for label in labels]
+
+     unpacked_variable_names = reduce(
+         lambda lst, name: lst + make_unpacked_variable_names(name), variable_names, []
+     )
+
+     mean_dataarray = xr.DataArray(mu.data, dims=["rows"], coords={"rows": unpacked_variable_names})
+     cov_dataarray = xr.DataArray(
+         H_inv,
+         dims=["rows", "columns"],
+         coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names},
+     )
+
+     dataset = xr.Dataset({"mean_vector": mean_dataarray, "covariance_matrix": cov_dataarray})
+     idata.add_groups(fit=dataset)
+
+     return idata
+
+
+ def add_data_to_inferencedata(
+     idata: az.InferenceData,
+     progressbar: bool = True,
+     model: pm.Model | None = None,
+     compile_kwargs: dict | None = None,
+ ) -> az.InferenceData:
+     """
+     Add observed and constant data to an InferenceData object.
+
+     Parameters
+     ----------
+     idata: az.InferenceData
+         An InferenceData object containing the approximated posterior samples.
+     progressbar: bool
+         Whether to display a progress bar during computations. Default is True.
+     model: Model, optional
+         A PyMC model. If None, the model is taken from the current model context.
+     compile_kwargs: dict, optional
+         Additional keyword arguments to pass to pytensor.function.
+
+     Returns
+     -------
+     idata: az.InferenceData
+         The provided InferenceData, with observed and constant data added.
+     """
+     model = pm.modelcontext(model)
+
+     if model.deterministics:
+         idata.posterior = pm.compute_deterministics(
+             idata.posterior,
+             model=model,
+             merge_dataset=True,
+             progressbar=progressbar,
+             compile_kwargs=compile_kwargs,
+         )
+
+     coords, dims = coords_and_dims_for_inferencedata(model)
+
+     observed_data = dict_to_dataset(
+         find_observations(model),
+         library=pm,
+         coords=coords,
+         dims=dims,
+         default_dims=[],
+     )
+
+     constant_data = dict_to_dataset(
+         find_constants(model),
+         library=pm,
+         coords=coords,
+         dims=dims,
+         default_dims=[],
+     )
+
+     idata.add_groups(
+         {"observed_data": observed_data, "constant_data": constant_data},
+         coords=coords,
+         dims=dims,
+     )
+
+     return idata
+
+
+ def fit_mvn_to_MAP(
+     optimized_point: dict[str, np.ndarray],
+     model: pm.Model | None = None,
+     on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
+     transform_samples: bool = False,
+     gradient_backend: GradientBackend = "pytensor",
+     zero_tol: float = 1e-8,
+     diag_jitter: float | None = 1e-8,
+     compile_kwargs: dict | None = None,
+ ) -> tuple[RaveledVars, np.ndarray]:
+     """
+     Create a multivariate normal distribution using the inverse of the negative Hessian matrix of the log-posterior
+     evaluated at the MAP estimate. This is the basis of the Laplace approximation.
+
+     Parameters
+     ----------
+     optimized_point : dict[str, np.ndarray]
+         Local maximum a posteriori (MAP) point returned from pymc.find_MAP or jax_tools.fit_map
+     model : Model, optional
+         A PyMC model. If None, the model is taken from the current model context.
+     on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
+         What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
+         If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
+         If 'error', an error will be raised.
+     transform_samples : bool
+         Whether to transform the samples back to the original parameter space. Default is False.
+     gradient_backend: str, default "pytensor"
+         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
+     zero_tol: float
+         Value below which an element of the Hessian matrix is counted as 0.
+         This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8.
+     diag_jitter: float | None
+         A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
+         If None, no jitter is added. Default is 1e-8.
+     compile_kwargs: dict, optional
+         Additional keyword arguments to pass to pytensor.function when compiling loss functions
+
+     Returns
+     -------
+     map_estimate: RaveledVars
+         The MAP estimate of the model parameters, raveled into a 1D array.
+
+     inverse_hessian: np.ndarray
+         The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
+     """
+     model = pm.modelcontext(model)
+     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+     frozen_model = freeze_dims_and_data(model)
+
+     if not transform_samples:
+         untransformed_model = remove_value_transforms(frozen_model)
+         logp = untransformed_model.logp(jacobian=False)
+         variables = untransformed_model.continuous_value_vars
+     else:
+         logp = frozen_model.logp(jacobian=True)
+         variables = frozen_model.continuous_value_vars
+
+     variable_names = {var.name for var in variables}
+     optimized_free_params = {k: v for k, v in optimized_point.items() if k in variable_names}
+     mu = DictToArrayBijection.map(optimized_free_params)
+
+     _, f_hess, _ = scipy_optimize_funcs_from_loss(
+         loss=-logp,
+         inputs=variables,
+         initial_point_dict=optimized_free_params,
+         use_grad=True,
+         use_hess=True,
+         use_hessp=False,
+         gradient_backend=gradient_backend,
+         compile_kwargs=compile_kwargs,
+     )
+
+     H = -f_hess(mu.data)
+     H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H))
+
+     def stabilize(x, jitter):
+         return x + np.eye(x.shape[0]) * jitter
+
+     H_inv = H_inv if diag_jitter is None else stabilize(H_inv, diag_jitter)
+
+     try:
+         np.linalg.cholesky(H_inv)
+     except np.linalg.LinAlgError:
+         if on_bad_cov == "error":
+             raise np.linalg.LinAlgError(
+                 "Inverse Hessian is not positive semi-definite at the provided point"
+             )
+         H_inv = get_nearest_psd(H_inv)
+         if on_bad_cov == "warn":
+             _log.warning(
+                 "Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD "
+                 "matrix in L1-norm instead"
+             )
+
+     return mu, H_inv
+
+
+ def sample_laplace_posterior(
+     mu: RaveledVars,
+     H_inv: np.ndarray,
+     model: pm.Model | None = None,
+     chains: int = 2,
+     draws: int = 500,
+     transform_samples: bool = False,
+     progressbar: bool = True,
+     random_seed: int | np.random.Generator | None = None,
+     compile_kwargs: dict | None = None,
+ ) -> az.InferenceData:
+     """
+     Generate samples from a multivariate normal distribution with mean `mu` and covariance matrix `H_inv`.
+
+     Parameters
+     ----------
+     mu : RaveledVars
+         The MAP estimate of the model parameters, raveled into a 1D array.
+     H_inv : np.ndarray
+         The covariance matrix of the approximated posterior (the inverse Hessian returned by fit_mvn_to_MAP).
+     model : Model, optional
+         A PyMC model. If None, the model is taken from the current model context.
+     chains : int
+         The number of sampling chains running in parallel. Default is 2.
+     draws : int
+         The number of samples to draw from the approximated posterior. Default is 500.
+     transform_samples : bool
+         Whether to transform the samples back to the original parameter space. Default is False.
+     progressbar : bool
+         Whether to display a progress bar during computations. Default is True.
+     random_seed: int | np.random.Generator | None
+         Seed for the random number generator or a numpy Generator for reproducibility
+     compile_kwargs: dict, optional
+         Additional keyword arguments to pass to pytensor.function.
+
+     Returns
+     -------
+     idata: az.InferenceData
+         An InferenceData object containing the approximated posterior samples.
+     """
+     model = pm.modelcontext(model)
+     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+     rng = np.random.default_rng(random_seed)
+
+     posterior_dist = stats.multivariate_normal(
+         mean=mu.data, cov=H_inv, allow_singular=True, seed=rng
+     )
+     posterior_draws = posterior_dist.rvs(size=(chains, draws))
+
+     if transform_samples:
+         constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model)
+         batched_values = pt.tensor(
+             "batched_values",
+             shape=(chains, draws, *unconstrained_vector.type.shape),
+             dtype=unconstrained_vector.type.dtype,
+         )
+         batched_rvs = pytensor.graph.vectorize_graph(
+             constrained_rvs, replace={unconstrained_vector: batched_values}
+         )
+
+         f_constrain = pm.compile_pymc(
+             inputs=[batched_values], outputs=batched_rvs, **compile_kwargs
+         )
+         posterior_draws = f_constrain(posterior_draws)
+
+     else:
+         info = mu.point_map_info
+         flat_shapes = [size for _, _, size, _ in info]
+         slices = [
+             slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes))
+         ]
+
+         posterior_draws = [
+             posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype)
+             for idx, (name, shape, _, dtype) in zip(slices, info)
+         ]
+
+     idata = laplace_draws_to_inferencedata(posterior_draws, model)
+     idata = add_fit_to_inferencedata(idata, mu, H_inv)
+     idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)
+
+     return idata
+
+
+ def fit_laplace(
+     optimize_method: minimize_method = "BFGS",
+     *,
+     model: pm.Model | None = None,
+     use_grad: bool | None = None,
+     use_hessp: bool | None = None,
+     use_hess: bool | None = None,
+     initvals: dict | None = None,
+     random_seed: int | np.random.Generator | None = None,
+     return_raw: bool = False,
+     jitter_rvs: list[pt.TensorVariable] | None = None,
+     progressbar: bool = True,
+     include_transformed: bool = True,
+     gradient_backend: GradientBackend = "pytensor",
+     chains: int = 2,
+     draws: int = 500,
+     on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
+     fit_in_unconstrained_space: bool = False,
+     zero_tol: float = 1e-8,
+     diag_jitter: float | None = 1e-8,
+     optimizer_kwargs: dict | None = None,
+     compile_kwargs: dict | None = None,
+ ) -> az.InferenceData:
+     """
+     Create a Laplace (quadratic) approximation for a posterior distribution.
+
+     This function generates a Laplace approximation for a given posterior distribution using a specified
+     number of draws. This is useful for obtaining a parametric approximation to the posterior distribution
+     that can be used for further analysis.
+
+     Parameters
+     ----------
+     model : pm.Model
+         The PyMC model to be fit. If None, the current model context is used.
+     optimize_method : str
+         The optimization method to use. See scipy.optimize.minimize documentation for details.
+     use_grad : bool | None, optional
+         Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
+         the ``optimize_method``.
+     use_hessp : bool | None, optional
+         Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this
+         automatically based on the ``optimize_method``.
+     use_hess : bool | None, optional
+         Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically
+         based on the ``optimize_method``.
+     initvals : None | dict, optional
+         Initial values for the model parameters, as str:ndarray key-value pairs. Partial initialization is permitted.
+         If None, the model's default initial values are used.
+     random_seed : None | int | np.random.Generator, optional
+         Seed for the random number generator or a numpy Generator for reproducibility
+     return_raw : bool, default False
+         Whether to also return the full output of `scipy.optimize.minimize`.
+     jitter_rvs : list of TensorVariables, optional
+         Variables whose initial values should be jittered. If None, all variables are jittered.
+     progressbar : bool, optional
+         Whether to display a progress bar during optimization. Defaults to True.
+     include_transformed : bool, default True
+         Whether to include transformed variable values in the underlying MAP estimate.
+     fit_in_unconstrained_space: bool, default False
+         Whether to fit the Laplace approximation in the unconstrained parameter space. If True, samples will be drawn
+         from a mean and covariance matrix computed at a point in the **unconstrained** parameter space. Samples will
+         then be transformed back to the original parameter space. This will guarantee that the samples will respect
+         the domain of prior distributions (for example, samples from a Beta distribution will be strictly between 0
+         and 1).
+
+         .. warning::
+             This argument should be considered highly experimental. It has not been verified whether this method
+             produces valid draws from the posterior. **Use at your own risk**.
+
+     gradient_backend: str, default "pytensor"
+         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
+     chains: int, default: 2
+         The number of sampling chains running in parallel.
+     draws: int, default: 500
+         The number of samples to draw from the approximated posterior.
+     on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
+         What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
+         If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
+         If 'error', an error will be raised.
+     zero_tol: float
+         Value below which an element of the Hessian matrix is counted as 0.
+         This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8.
+     diag_jitter: float | None
+         A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
+         If None, no jitter is added. Default is 1e-8.
+     optimizer_kwargs: dict, optional
+         Additional keyword arguments to pass to scipy.optimize.minimize. See the documentation of
+         scipy.optimize.minimize for details. Arguments that are typically passed via ``options`` will be
+         automatically extracted without the need to use a nested dictionary.
+     compile_kwargs: dict, optional
+         Additional keyword arguments to pass to pytensor.function.
+
+     Returns
+     -------
+     idata: az.InferenceData
+         An InferenceData object containing the approximated posterior samples.
+
+     Examples
+     --------
+     >>> from pymc_extras.inference.laplace import fit_laplace
+     >>> import numpy as np
+     >>> import pymc as pm
+     >>> import arviz as az
+     >>> y = np.array([2642, 3503, 4358]*10)
+     >>> with pm.Model() as m:
+     ...     logsigma = pm.Uniform("logsigma", 1, 100)
+     ...     mu = pm.Uniform("mu", -10000, 10000)
+     ...     yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
+     ...     idata = fit_laplace()
+
+     Notes
+     -----
+     This method of approximation may not be suitable for all types of posterior distributions,
+     especially those with significant skewness or multimodality.
+
+     See Also
+     --------
+     fit : Calling ``pmx.fit(method="laplace", model=m)`` forwards the call to ``fit_laplace``.
+
+     """
+     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+     optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
+
+     optimized_point = find_MAP(
+         method=optimize_method,
+         model=model,
+         use_grad=use_grad,
+         use_hessp=use_hessp,
+         use_hess=use_hess,
+         initvals=initvals,
+         random_seed=random_seed,
+         return_raw=return_raw,
+         jitter_rvs=jitter_rvs,
+         progressbar=progressbar,
+         include_transformed=include_transformed,
+         gradient_backend=gradient_backend,
+         compile_kwargs=compile_kwargs,
+         **optimizer_kwargs,
+     )
+
+     mu, H_inv = fit_mvn_to_MAP(
+         optimized_point=optimized_point,
+         model=model,
+         on_bad_cov=on_bad_cov,
+         transform_samples=fit_in_unconstrained_space,
+         zero_tol=zero_tol,
+         diag_jitter=diag_jitter,
+         compile_kwargs=compile_kwargs,
+     )
+
+     return sample_laplace_posterior(
+         mu=mu,
+         H_inv=H_inv,
+         model=model,
+         chains=chains,
+         draws=draws,
+         transform_samples=fit_in_unconstrained_space,
+         progressbar=progressbar,
+         random_seed=random_seed,
+         compile_kwargs=compile_kwargs,
+     )
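
The workflow above approximates the posterior by a multivariate normal q(theta) = N(theta_MAP, H^-1), where H is the negative Hessian of the log-posterior at the MAP estimate (computed in fit_mvn_to_MAP). A minimal usage sketch, adapted from the fit_laplace docstring example; the model is picked up from the context manager:

    import numpy as np
    import pymc as pm

    from pymc_extras.inference.laplace import fit_laplace

    y = np.array([2642, 3503, 4358] * 10)
    with pm.Model():
        logsigma = pm.Uniform("logsigma", 1, 100)
        mu = pm.Uniform("mu", -10000, 10000)
        pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
        # 2 chains x 500 draws from the Gaussian approximation (the defaults)
        idata = fit_laplace(chains=2, draws=500)

Besides the posterior draws, the returned InferenceData carries a "fit" group holding the mean vector and covariance matrix of the approximation (see add_fit_to_inferencedata above).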
pymc_extras/inference/pathfinder.py
@@ -0,0 +1,134 @@
+ # Copyright 2022 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import collections
+ import sys
+
+ import arviz as az
+ import blackjax
+ import jax
+ import numpy as np
+ import pymc as pm
+
+ from packaging import version
+ from pymc.backends.arviz import coords_and_dims_for_inferencedata
+ from pymc.blocking import DictToArrayBijection, RaveledVars
+ from pymc.model import modelcontext
+ from pymc.sampling.jax import get_jaxified_graph
+ from pymc.util import RandomSeed, _get_seeds_per_chain, get_default_varnames
+
+
+ def convert_flat_trace_to_idata(
+     samples,
+     include_transformed=False,
+     postprocessing_backend="cpu",
+     model=None,
+ ):
+     """Convert an array of flat (raveled) posterior samples to an InferenceData object."""
+     model = modelcontext(model)
+     ip = model.initial_point()
+     ip_point_map_info = pm.blocking.DictToArrayBijection.map(ip).point_map_info
+     trace = collections.defaultdict(list)
+     for sample in samples:
+         raveled_vars = RaveledVars(sample, ip_point_map_info)
+         point = DictToArrayBijection.rmap(raveled_vars, ip)
+         for p, v in point.items():
+             trace[p].append(v.tolist())
+
+     trace = {k: np.asarray(v)[None, ...] for k, v in trace.items()}
+
+     var_names = model.unobserved_value_vars
+     vars_to_sample = list(get_default_varnames(var_names, include_transformed=include_transformed))
+     print("Transforming variables...", file=sys.stdout)
+     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
+     result = jax.vmap(jax.vmap(jax_fn))(
+         *jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
+     )
+     trace = {v.name: r for v, r in zip(vars_to_sample, result)}
+     coords, dims = coords_and_dims_for_inferencedata(model)
+     idata = az.from_dict(trace, dims=dims, coords=coords)
+
+     return idata
+
+
+ def fit_pathfinder(
+     samples=1000,
+     random_seed: RandomSeed | None = None,
+     postprocessing_backend="cpu",
+     model=None,
+     **pathfinder_kwargs,
+ ):
+     """
+     Fit a model using the pathfinder algorithm as implemented in blackjax.
+
+     Requires the JAX backend.
+
+     Parameters
+     ----------
+     samples : int
+         Number of samples to draw from the fitted approximation.
+     random_seed : int
+         Random seed to set.
+     postprocessing_backend : str
+         Where to compute transformations of the trace.
+         "cpu" or "gpu".
+     pathfinder_kwargs:
+         kwargs for blackjax.vi.pathfinder.approximate
+
+     Returns
+     -------
+     arviz.InferenceData
+
+     References
+     ----------
+     https://arxiv.org/abs/2108.03782
+     """
+     # Temporary helper
+     if version.parse(blackjax.__version__).major < 1:
+         raise ImportError("fit_pathfinder requires blackjax 1.0 or above")
+
+     model = modelcontext(model)
+
+     ip = model.initial_point()
+     ip_map = DictToArrayBijection.map(ip)
+
+     new_logprob, new_input = pm.pytensorf.join_nonshared_inputs(
+         ip, (model.logp(),), model.value_vars, ()
+     )
+
+     logprob_fn_list = get_jaxified_graph([new_input], new_logprob)
+
+     def logprob_fn(x):
+         return logprob_fn_list(x)[0]
+
+     [pathfinder_seed, sample_seed] = _get_seeds_per_chain(random_seed, 2)
+
+     print("Running pathfinder...", file=sys.stdout)
+     pathfinder_state, _ = blackjax.vi.pathfinder.approximate(
+         rng_key=jax.random.key(pathfinder_seed),
+         logdensity_fn=logprob_fn,
+         initial_position=ip_map.data,
+         **pathfinder_kwargs,
+     )
+     samples, _ = blackjax.vi.pathfinder.sample(
+         rng_key=jax.random.key(sample_seed),
+         state=pathfinder_state,
+         num_samples=samples,
+     )
+
+     idata = convert_flat_trace_to_idata(
+         samples,
+         postprocessing_backend=postprocessing_backend,
+         model=model,
+     )
+     return idata
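
A minimal usage sketch for fit_pathfinder on a hypothetical toy model (requires JAX and blackjax >= 1.0; the model is taken from the context manager, and extra keyword arguments are forwarded to blackjax.vi.pathfinder.approximate):

    import numpy as np
    import pymc as pm

    from pymc_extras.inference.pathfinder import fit_pathfinder

    rng = np.random.default_rng(0)
    observed = rng.normal(loc=1.0, scale=2.0, size=100)

    with pm.Model():
        mu = pm.Normal("mu", 0.0, 10.0)
        sigma = pm.HalfNormal("sigma", 5.0)
        pm.Normal("obs", mu=mu, sigma=sigma, observed=observed)
        # Draw 1000 samples from the fitted pathfinder approximation
        idata = fit_pathfinder(samples=1000, random_seed=0)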