pymc_extras-0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (101)
  1. pymc_extras/__init__.py +29 -0
  2. pymc_extras/distributions/__init__.py +40 -0
  3. pymc_extras/distributions/continuous.py +351 -0
  4. pymc_extras/distributions/discrete.py +399 -0
  5. pymc_extras/distributions/histogram_utils.py +163 -0
  6. pymc_extras/distributions/multivariate/__init__.py +3 -0
  7. pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
  8. pymc_extras/distributions/timeseries.py +356 -0
  9. pymc_extras/gp/__init__.py +18 -0
  10. pymc_extras/gp/latent_approx.py +183 -0
  11. pymc_extras/inference/__init__.py +18 -0
  12. pymc_extras/inference/find_map.py +431 -0
  13. pymc_extras/inference/fit.py +44 -0
  14. pymc_extras/inference/laplace.py +570 -0
  15. pymc_extras/inference/pathfinder.py +134 -0
  16. pymc_extras/inference/smc/__init__.py +13 -0
  17. pymc_extras/inference/smc/sampling.py +451 -0
  18. pymc_extras/linearmodel.py +130 -0
  19. pymc_extras/model/__init__.py +0 -0
  20. pymc_extras/model/marginal/__init__.py +0 -0
  21. pymc_extras/model/marginal/distributions.py +276 -0
  22. pymc_extras/model/marginal/graph_analysis.py +372 -0
  23. pymc_extras/model/marginal/marginal_model.py +595 -0
  24. pymc_extras/model/model_api.py +56 -0
  25. pymc_extras/model/transforms/__init__.py +0 -0
  26. pymc_extras/model/transforms/autoreparam.py +434 -0
  27. pymc_extras/model_builder.py +759 -0
  28. pymc_extras/preprocessing/__init__.py +0 -0
  29. pymc_extras/preprocessing/standard_scaler.py +17 -0
  30. pymc_extras/printing.py +182 -0
  31. pymc_extras/statespace/__init__.py +13 -0
  32. pymc_extras/statespace/core/__init__.py +7 -0
  33. pymc_extras/statespace/core/compile.py +48 -0
  34. pymc_extras/statespace/core/representation.py +438 -0
  35. pymc_extras/statespace/core/statespace.py +2268 -0
  36. pymc_extras/statespace/filters/__init__.py +15 -0
  37. pymc_extras/statespace/filters/distributions.py +453 -0
  38. pymc_extras/statespace/filters/kalman_filter.py +820 -0
  39. pymc_extras/statespace/filters/kalman_smoother.py +126 -0
  40. pymc_extras/statespace/filters/utilities.py +59 -0
  41. pymc_extras/statespace/models/ETS.py +670 -0
  42. pymc_extras/statespace/models/SARIMAX.py +536 -0
  43. pymc_extras/statespace/models/VARMAX.py +393 -0
  44. pymc_extras/statespace/models/__init__.py +6 -0
  45. pymc_extras/statespace/models/structural.py +1651 -0
  46. pymc_extras/statespace/models/utilities.py +387 -0
  47. pymc_extras/statespace/utils/__init__.py +0 -0
  48. pymc_extras/statespace/utils/constants.py +74 -0
  49. pymc_extras/statespace/utils/coord_tools.py +0 -0
  50. pymc_extras/statespace/utils/data_tools.py +182 -0
  51. pymc_extras/utils/__init__.py +23 -0
  52. pymc_extras/utils/linear_cg.py +290 -0
  53. pymc_extras/utils/pivoted_cholesky.py +69 -0
  54. pymc_extras/utils/prior.py +200 -0
  55. pymc_extras/utils/spline.py +131 -0
  56. pymc_extras/version.py +11 -0
  57. pymc_extras/version.txt +1 -0
  58. pymc_extras-0.2.0.dist-info/LICENSE +212 -0
  59. pymc_extras-0.2.0.dist-info/METADATA +99 -0
  60. pymc_extras-0.2.0.dist-info/RECORD +101 -0
  61. pymc_extras-0.2.0.dist-info/WHEEL +5 -0
  62. pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
  63. tests/__init__.py +13 -0
  64. tests/distributions/__init__.py +19 -0
  65. tests/distributions/test_continuous.py +185 -0
  66. tests/distributions/test_discrete.py +210 -0
  67. tests/distributions/test_discrete_markov_chain.py +258 -0
  68. tests/distributions/test_multivariate.py +304 -0
  69. tests/model/__init__.py +0 -0
  70. tests/model/marginal/__init__.py +0 -0
  71. tests/model/marginal/test_distributions.py +131 -0
  72. tests/model/marginal/test_graph_analysis.py +182 -0
  73. tests/model/marginal/test_marginal_model.py +867 -0
  74. tests/model/test_model_api.py +29 -0
  75. tests/statespace/__init__.py +0 -0
  76. tests/statespace/test_ETS.py +411 -0
  77. tests/statespace/test_SARIMAX.py +405 -0
  78. tests/statespace/test_VARMAX.py +184 -0
  79. tests/statespace/test_coord_assignment.py +116 -0
  80. tests/statespace/test_distributions.py +270 -0
  81. tests/statespace/test_kalman_filter.py +326 -0
  82. tests/statespace/test_representation.py +175 -0
  83. tests/statespace/test_statespace.py +818 -0
  84. tests/statespace/test_statespace_JAX.py +156 -0
  85. tests/statespace/test_structural.py +829 -0
  86. tests/statespace/utilities/__init__.py +0 -0
  87. tests/statespace/utilities/shared_fixtures.py +9 -0
  88. tests/statespace/utilities/statsmodel_local_level.py +42 -0
  89. tests/statespace/utilities/test_helpers.py +310 -0
  90. tests/test_blackjax_smc.py +222 -0
  91. tests/test_find_map.py +98 -0
  92. tests/test_histogram_approximation.py +109 -0
  93. tests/test_laplace.py +238 -0
  94. tests/test_linearmodel.py +208 -0
  95. tests/test_model_builder.py +306 -0
  96. tests/test_pathfinder.py +45 -0
  97. tests/test_pivoted_cholesky.py +24 -0
  98. tests/test_printing.py +98 -0
  99. tests/test_prior_from_trace.py +172 -0
  100. tests/test_splines.py +77 -0
  101. tests/utils.py +31 -0
pymc_extras/inference/find_map.py
@@ -0,0 +1,431 @@
+ import logging
+
+ from collections.abc import Callable
+ from typing import Literal, cast, get_args
+
+ import jax
+ import numpy as np
+ import pymc as pm
+ import pytensor
+ import pytensor.tensor as pt
+
+ from better_optimize import minimize
+ from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
+ from pymc.blocking import DictToArrayBijection, RaveledVars
+ from pymc.initial_point import make_initial_point_fn
+ from pymc.model.transform.optimization import freeze_dims_and_data
+ from pymc.pytensorf import join_nonshared_inputs
+ from pymc.util import get_default_varnames
+ from pytensor.compile import Function
+ from pytensor.compile.mode import Mode
+ from pytensor.tensor import TensorVariable
+ from scipy.optimize import OptimizeResult
+
+ _log = logging.getLogger(__name__)
+
+ GradientBackend = Literal["pytensor", "jax"]
+ VALID_BACKENDS = get_args(GradientBackend)
+
+
+ def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
+     method_info = MINIMIZE_MODE_KWARGS[method].copy()
+
+     use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+     use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
+     use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]
+
+     if use_hess and use_hessp:
+         use_hess = False
+
+     return use_grad, use_hess, use_hessp
+
+
+ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
+     """
+     Compute the nearest positive semi-definite matrix to a given matrix.
+
+     This function takes a square matrix and returns the nearest positive semi-definite matrix using
+     eigenvalue decomposition. It ensures all eigenvalues are non-negative. The "nearest" matrix is defined in terms
+     of the Frobenius norm.
+
+     Parameters
+     ----------
+     A : np.ndarray
+         Input square matrix.
+
+     Returns
+     -------
+     np.ndarray
+         The nearest positive semi-definite matrix to the input matrix.
+     """
+     C = (A + A.T) / 2
+     eigval, eigvec = np.linalg.eig(C)
+     eigval[eigval < 0] = 0
+
+     return eigvec @ np.diag(eigval) @ eigvec.T
+
+
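For intuition, here is a quick numerical check of get_nearest_psd. The snippet is an illustrative aside, not part of the packaged file; it assumes the function is imported from pymc_extras.inference.find_map.

import numpy as np

from pymc_extras.inference.find_map import get_nearest_psd

A = np.array([[1.0, 2.0], [2.0, 1.0]])  # symmetric but indefinite: eigenvalues are 3 and -1
A_psd = get_nearest_psd(A)  # the negative eigenvalue is clipped to 0
print(A_psd)  # approximately [[1.5, 1.5], [1.5, 1.5]], which is PSD
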
+ def _unconstrained_vector_to_constrained_rvs(model):
+     constrained_rvs, unconstrained_vector = join_nonshared_inputs(
+         model.initial_point(),
+         inputs=model.value_vars,
+         outputs=get_default_varnames(model.unobserved_value_vars, include_transformed=False),
+     )
+
+     unconstrained_vector.name = "unconstrained_vector"
+     return constrained_rvs, unconstrained_vector
+
+
+ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model, chains, draws):
+     X = pt.tensor("transformed_draws", shape=(chains, draws, H_inv.shape[0]))
+     out = []
+     for rv, idx in slices.items():
+         f = model.rvs_to_transforms[rv]
+         untransformed_X = f.backward(X[..., idx]) if f is not None else X[..., idx]
+
+         if rv in out_shapes:
+             new_shape = (chains, draws) + out_shapes[rv]
+             untransformed_X = untransformed_X.reshape(new_shape)
+
+         out.append(untransformed_X)
+
+     f_untransform = pytensor.function(
+         inputs=[pytensor.In(X, borrow=True)],
+         outputs=pytensor.Out(out, borrow=True),
+         mode=Mode(linker="py", optimizer="FAST_COMPILE"),
+     )
+     return f_untransform(posterior_draws)
+
+
+ def _compile_jax_gradients(
+     f_loss: Function, use_hess: bool, use_hessp: bool
+ ) -> tuple[Callable, Callable | None, Callable | None]:
+     """
+     Compile loss function gradients using JAX.
+
+     Parameters
+     ----------
+     f_loss: Function
+         The loss function to compile gradients for. Expected to be a pytensor function that returns a scalar loss,
+         compiled with mode="JAX".
+     use_hess: bool
+         Whether to compile a function to compute the hessian of the loss function.
+     use_hessp: bool
+         Whether to compile a function to compute the hessian-vector product of the loss function.
+
+     Returns
+     -------
+     f_loss_and_grad: Callable
+         The compiled loss function and gradient function.
+     f_hess: Callable | None
+         The compiled hessian function, or None if use_hess is False.
+     f_hessp: Callable | None
+         The compiled hessian-vector product function, or None if use_hessp is False.
+     """
+     f_hess = None
+     f_hessp = None
+
+     orig_loss_fn = f_loss.vm.jit_fn
+
+     @jax.jit
+     def loss_fn_jax_grad(x, *shared):
+         return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
+
+     f_loss_and_grad = loss_fn_jax_grad
+
+     if use_hessp:
+
+         def f_hessp_jax(x, p):
+             y, u = jax.jvp(lambda x: f_loss_and_grad(x)[1], (x,), (p,))
+             return jax.numpy.stack(u)
+
+         f_hessp = jax.jit(f_hessp_jax)
+
+     if use_hess:
+         _f_hess_jax = jax.jacfwd(lambda x: f_loss_and_grad(x)[1])
+
+         def f_hess_jax(x):
+             return jax.numpy.stack(_f_hess_jax(x))
+
+         f_hess = jax.jit(f_hess_jax)
+
+     return f_loss_and_grad, f_hess, f_hessp
+
+
+ def _compile_functions(
+     loss: TensorVariable,
+     inputs: list[TensorVariable],
+     compute_grad: bool,
+     compute_hess: bool,
+     compute_hessp: bool,
+     compile_kwargs: dict | None = None,
+ ) -> list[Function] | list[Function, Function | None, Function | None]:
+     """
+     Compile loss functions for use with scipy.optimize.minimize.
+
+     Parameters
+     ----------
+     loss: TensorVariable
+         The loss function to compile.
+     inputs: list[TensorVariable]
+         A single flat vector input variable, collecting all inputs to the loss function. Scipy optimize routines
+         expect the function signature to be f(x, *args), where x is a 1D array of parameters.
+     compute_grad: bool
+         Whether to compile a function that computes the gradients of the loss function.
+     compute_hess: bool
+         Whether to compile a function that computes the Hessian of the loss function.
+     compute_hessp: bool
+         Whether to compile a function that computes the Hessian-vector product of the loss function.
+     compile_kwargs: dict, optional
+         Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+
+     Returns
+     -------
+     f_loss: Function
+
+     f_hess: Function | None
+     f_hessp: Function | None
+     """
+     loss = pm.pytensorf.rewrite_pregrad(loss)
+     f_hess = None
+     f_hessp = None
+
+     if compute_grad:
+         grads = pytensor.gradient.grad(loss, inputs)
+         grad = pt.concatenate([grad.ravel() for grad in grads])
+         f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs)
+     else:
+         f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs)
+         return [f_loss]
+
+     if compute_hess:
+         hess = pytensor.gradient.jacobian(grad, inputs)[0]
+         f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs)
+
+     if compute_hessp:
+         p = pt.tensor("p", shape=inputs[0].type.shape)
+         hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
+         f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs)
+
+     return [f_loss_and_grad, f_hess, f_hessp]
+
+
+ def scipy_optimize_funcs_from_loss(
+     loss: TensorVariable,
+     inputs: list[TensorVariable],
+     initial_point_dict: dict[str, np.ndarray | float | int],
+     use_grad: bool,
+     use_hess: bool,
+     use_hessp: bool,
+     gradient_backend: GradientBackend = "pytensor",
+     compile_kwargs: dict | None = None,
+ ) -> tuple[Callable, ...]:
+     """
+     Compile loss functions for use with scipy.optimize.minimize.
+
+     Parameters
+     ----------
+     loss: TensorVariable
+         The loss function to compile.
+     inputs: list[TensorVariable]
+         The input variables to the loss function.
+     initial_point_dict: dict[str, np.ndarray | float | int]
+         Dictionary mapping variable names to initial values. Used to determine the shapes of the input variables.
+     use_grad: bool
+         Whether to compile a function that computes the gradients of the loss function.
+     use_hess: bool
+         Whether to compile a function that computes the Hessian of the loss function.
+     use_hessp: bool
+         Whether to compile a function that computes the Hessian-vector product of the loss function.
+     gradient_backend: str, default "pytensor"
+         Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
+     compile_kwargs:
+         Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+
+     Returns
+     -------
+     f_loss: Callable
+         The compiled loss function.
+     f_hess: Callable | None
+         The compiled hessian function, or None if use_hess is False.
+     f_hessp: Callable | None
+         The compiled hessian-vector product function, or None if use_hessp is False.
+     """
+
+     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+
+     if (use_hess or use_hessp) and not use_grad:
+         raise ValueError(
+             "Cannot compute hessian or hessian-vector product without also computing the gradient"
+         )
+
+     if gradient_backend not in VALID_BACKENDS:
+         raise ValueError(
+             f"Invalid gradient backend: {gradient_backend}. Must be one of {VALID_BACKENDS}"
+         )
+
+     use_jax_gradients = (gradient_backend == "jax") and use_grad
+
+     mode = compile_kwargs.get("mode", None)
+     if mode is None and use_jax_gradients:
+         compile_kwargs["mode"] = "JAX"
+     elif mode != "JAX" and use_jax_gradients:
+         raise ValueError(
+             'jax gradients can only be used when ``compile_kwargs["mode"]`` is set to "JAX"'
+         )
+
+     if not isinstance(inputs, list):
+         inputs = [inputs]
+
+     [loss], flat_input = join_nonshared_inputs(
+         point=initial_point_dict, outputs=[loss], inputs=inputs
+     )
+
+     compute_grad = use_grad and not use_jax_gradients
+     compute_hess = use_hess and not use_jax_gradients
+     compute_hessp = use_hessp and not use_jax_gradients
+
+     funcs = _compile_functions(
+         loss=loss,
+         inputs=[flat_input],
+         compute_grad=compute_grad,
+         compute_hess=compute_hess,
+         compute_hessp=compute_hessp,
+         compile_kwargs=compile_kwargs,
+     )
+
+     # f_loss here is f_loss_and_grad if compute_grad = True. The name is unchanged to simplify the return values
+     f_loss = funcs.pop(0)
+     f_hess = funcs.pop(0) if compute_grad else None
+     f_hessp = funcs.pop(0) if compute_grad else None
+
+     if use_jax_gradients:
+         # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
+         f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)
+
+     return f_loss, f_hess, f_hessp
+
+
+ def find_MAP(
+     method: minimize_method,
+     *,
+     model: pm.Model | None = None,
+     use_grad: bool | None = None,
+     use_hessp: bool | None = None,
+     use_hess: bool | None = None,
+     initvals: dict | None = None,
+     random_seed: int | np.random.Generator | None = None,
+     return_raw: bool = False,
+     jitter_rvs: list[TensorVariable] | None = None,
+     progressbar: bool = True,
+     include_transformed: bool = True,
+     gradient_backend: GradientBackend = "pytensor",
+     compile_kwargs: dict | None = None,
+     **optimizer_kwargs,
+ ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]:
+     """
+     Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.minimize.
+
+     Parameters
+     ----------
+     method : str
+         The optimization method to use. See scipy.optimize.minimize documentation for details.
+     model : pm.Model
+         The PyMC model to be fit. If None, the current model context is used.
+     use_grad : bool | None, optional
+         Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
+         the ``method``.
+     use_hessp : bool | None, optional
+         Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this automatically based on
+         the ``method``.
+     use_hess : bool | None, optional
+         Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically based on
+         the ``method``.
+     initvals : None | dict, optional
+         Initial values for the model parameters, as str:ndarray key-value pairs. Partial initialization is permitted.
+         If None, the model's default initial values are used.
+     random_seed : None | int | np.random.Generator, optional
+         Seed for the random number generator or a numpy Generator for reproducibility.
+     return_raw : bool, default False
+         Whether to also return the full output of `scipy.optimize.minimize`.
+     jitter_rvs : list of TensorVariables, optional
+         Variables whose initial values should be jittered. If None, all variables are jittered.
+     progressbar : bool, optional
+         Whether to display a progress bar during optimization. Defaults to True.
+     include_transformed: bool, optional
+         Whether to include transformed variable values in the returned dictionary. Defaults to True.
+     gradient_backend: str, default "pytensor"
+         Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
+     compile_kwargs: dict, optional
+         Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
+     **optimizer_kwargs
+         Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function.
+
+     Returns
+     -------
+     optimizer_result: dict[str, np.ndarray] or tuple[dict[str, np.ndarray], OptimizeResult]
+         Dictionary with names of random variables as keys, and optimization results as values. If return_raw is True,
+         also returns the object returned by ``scipy.optimize.minimize``.
+     """
+     model = pm.modelcontext(model)
+     frozen_model = freeze_dims_and_data(model)
+
+     jitter_rvs = [] if jitter_rvs is None else jitter_rvs
+     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+
+     ipfn = make_initial_point_fn(
+         model=frozen_model,
+         jitter_rvs=set(jitter_rvs),
+         return_transformed=True,
+         overrides=initvals,
+     )
+
+     start_dict = ipfn(random_seed)
+     vars_dict = {var.name: var for var in frozen_model.continuous_value_vars}
+     initial_params = DictToArrayBijection.map(
+         {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
+     )
+     use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+         method, use_grad, use_hess, use_hessp
+     )
+
+     f_logp, f_hess, f_hessp = scipy_optimize_funcs_from_loss(
+         loss=-frozen_model.logp(jacobian=False),
+         inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
+         initial_point_dict=start_dict,
+         use_grad=use_grad,
+         use_hess=use_hess,
+         use_hessp=use_hessp,
+         gradient_backend=gradient_backend,
+         compile_kwargs=compile_kwargs,
+     )
+
+     args = optimizer_kwargs.pop("args", None)
+
+     # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument
+     # if so. That is why it is not set here, regardless of user settings.
+     optimizer_result = minimize(
+         f=f_logp,
+         x0=cast(np.ndarray[float], initial_params.data),
+         args=args,
+         hess=f_hess,
+         hessp=f_hessp,
+         progressbar=progressbar,
+         method=method,
+         **optimizer_kwargs,
+     )
+
+     raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
+     unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
+     unobserved_vars_values = model.compile_fn(unobserved_vars, mode="FAST_COMPILE")(
+         DictToArrayBijection.rmap(raveled_optimized)
+     )
+
+     optimized_point = {
+         var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)
+     }
+
+     if return_raw:
+         return optimized_point, optimizer_result
+
+     return optimized_point
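
For orientation, a minimal usage sketch of find_MAP. The toy model and optimizer settings below are illustrative only and not part of the package; they assume the import path pymc_extras.inference.find_map shown in the file list above.

import numpy as np
import pymc as pm

from pymc_extras.inference.find_map import find_MAP

rng = np.random.default_rng(42)
observed = rng.normal(loc=3.0, scale=1.5, size=100)

with pm.Model():
    mu = pm.Normal("mu", 0, 10)
    sigma = pm.HalfNormal("sigma", 5)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=observed)

    # method names follow scipy.optimize.minimize; loss and gradient functions are
    # compiled according to gradient_backend ("pytensor" by default)
    map_point, raw_result = find_MAP(method="L-BFGS-B", use_grad=True, return_raw=True)

print(map_point["mu"], map_point["sigma"])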
pymc_extras/inference/fit.py
@@ -0,0 +1,44 @@
+ # Copyright 2022 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from importlib.util import find_spec
+
+
+ def fit(method, **kwargs):
+     """
+     Fit a model with an inference algorithm.
+
+     Parameters
+     ----------
+     method : str
+         Which inference method to run.
+         Supported: pathfinder or laplace
+
+     kwargs are passed on to the chosen inference method.
+
+     Returns
+     -------
+     arviz.InferenceData
+     """
+     if method == "pathfinder":
+         if find_spec("blackjax") is None:
+             raise RuntimeError("Need BlackJAX to use `pathfinder`")
+
+         from pymc_extras.inference.pathfinder import fit_pathfinder
+
+         return fit_pathfinder(**kwargs)
+
+     if method == "laplace":
+         from pymc_extras.inference.laplace import fit_laplace
+
+         return fit_laplace(**kwargs)
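
Finally, a hedged usage sketch of the fit dispatcher above. The toy model is hypothetical, and the call assumes fit_laplace picks up the model from the active context and that any extra keyword arguments are simply forwarded, as the code above suggests.

import numpy as np
import pymc as pm

from pymc_extras.inference.fit import fit

with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu=mu, sigma=1, observed=np.random.normal(size=50))

    # method="laplace" dispatches to fit_laplace; method="pathfinder" requires blackjax
    idata = fit(method="laplace")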