pymc-extras 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. pymc_extras/__init__.py +6 -4
  2. pymc_extras/distributions/__init__.py +2 -0
  3. pymc_extras/distributions/continuous.py +3 -2
  4. pymc_extras/distributions/discrete.py +3 -1
  5. pymc_extras/distributions/transforms/__init__.py +3 -0
  6. pymc_extras/distributions/transforms/partial_order.py +227 -0
  7. pymc_extras/inference/__init__.py +4 -2
  8. pymc_extras/inference/find_map.py +62 -17
  9. pymc_extras/inference/fit.py +6 -4
  10. pymc_extras/inference/laplace.py +14 -8
  11. pymc_extras/inference/pathfinder/lbfgs.py +49 -13
  12. pymc_extras/inference/pathfinder/pathfinder.py +89 -103
  13. pymc_extras/statespace/core/statespace.py +191 -52
  14. pymc_extras/statespace/filters/distributions.py +15 -16
  15. pymc_extras/statespace/filters/kalman_filter.py +1 -18
  16. pymc_extras/statespace/filters/kalman_smoother.py +2 -6
  17. pymc_extras/statespace/models/ETS.py +10 -0
  18. pymc_extras/statespace/models/SARIMAX.py +26 -5
  19. pymc_extras/statespace/models/VARMAX.py +12 -2
  20. pymc_extras/statespace/models/structural.py +18 -5
  21. pymc_extras/statespace/utils/data_tools.py +24 -9
  22. pymc_extras-0.2.6.dist-info/METADATA +318 -0
  23. pymc_extras-0.2.6.dist-info/RECORD +65 -0
  24. {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info}/WHEEL +1 -2
  25. pymc_extras/version.py +0 -11
  26. pymc_extras/version.txt +0 -1
  27. pymc_extras-0.2.4.dist-info/METADATA +0 -110
  28. pymc_extras-0.2.4.dist-info/RECORD +0 -105
  29. pymc_extras-0.2.4.dist-info/top_level.txt +0 -2
  30. tests/__init__.py +0 -13
  31. tests/distributions/__init__.py +0 -19
  32. tests/distributions/test_continuous.py +0 -185
  33. tests/distributions/test_discrete.py +0 -210
  34. tests/distributions/test_discrete_markov_chain.py +0 -258
  35. tests/distributions/test_multivariate.py +0 -304
  36. tests/model/__init__.py +0 -0
  37. tests/model/marginal/__init__.py +0 -0
  38. tests/model/marginal/test_distributions.py +0 -132
  39. tests/model/marginal/test_graph_analysis.py +0 -182
  40. tests/model/marginal/test_marginal_model.py +0 -967
  41. tests/model/test_model_api.py +0 -38
  42. tests/statespace/__init__.py +0 -0
  43. tests/statespace/test_ETS.py +0 -411
  44. tests/statespace/test_SARIMAX.py +0 -405
  45. tests/statespace/test_VARMAX.py +0 -184
  46. tests/statespace/test_coord_assignment.py +0 -116
  47. tests/statespace/test_distributions.py +0 -270
  48. tests/statespace/test_kalman_filter.py +0 -326
  49. tests/statespace/test_representation.py +0 -175
  50. tests/statespace/test_statespace.py +0 -872
  51. tests/statespace/test_statespace_JAX.py +0 -156
  52. tests/statespace/test_structural.py +0 -836
  53. tests/statespace/utilities/__init__.py +0 -0
  54. tests/statespace/utilities/shared_fixtures.py +0 -9
  55. tests/statespace/utilities/statsmodel_local_level.py +0 -42
  56. tests/statespace/utilities/test_helpers.py +0 -310
  57. tests/test_blackjax_smc.py +0 -222
  58. tests/test_find_map.py +0 -103
  59. tests/test_histogram_approximation.py +0 -109
  60. tests/test_laplace.py +0 -265
  61. tests/test_linearmodel.py +0 -208
  62. tests/test_model_builder.py +0 -306
  63. tests/test_pathfinder.py +0 -203
  64. tests/test_pivoted_cholesky.py +0 -24
  65. tests/test_printing.py +0 -98
  66. tests/test_prior_from_trace.py +0 -172
  67. tests/test_splines.py +0 -77
  68. tests/utils.py +0 -0
  69. {pymc_extras-0.2.4.dist-info → pymc_extras-0.2.6.dist-info/licenses}/LICENSE +0 -0
pymc_extras/__init__.py CHANGED
@@ -13,18 +13,17 @@
  # limitations under the License.
  import logging

+ from importlib.metadata import version
+
  from pymc_extras import gp, statespace, utils
  from pymc_extras.distributions import *
- from pymc_extras.inference.find_map import find_MAP
- from pymc_extras.inference.fit import fit
- from pymc_extras.inference.laplace import fit_laplace
+ from pymc_extras.inference import find_MAP, fit, fit_laplace, fit_pathfinder
  from pymc_extras.model.marginal.marginal_model import (
      MarginalModel,
      marginalize,
      recover_marginals,
  )
  from pymc_extras.model.model_api import as_model
- from pymc_extras.version import __version__

  _log = logging.getLogger("pmx")

@@ -33,3 +32,6 @@ if not logging.root.handlers:
      if len(_log.handlers) == 0:
          handler = logging.StreamHandler()
          _log.addHandler(handler)
+
+
+ __version__ = version("pymc-extras")
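The version string is now resolved from installed package metadata instead of the removed pymc_extras/version.py. A minimal sketch of what this implies for downstream code, assuming pymc-extras 0.2.6 is installed under its published distribution name:

    from importlib.metadata import version

    import pymc_extras as pmx

    # __version__ is read from the installed distribution's metadata at import
    # time, so it should agree with importlib.metadata in the same environment.
    assert pmx.__version__ == version("pymc-extras")
    print(pmx.__version__)  # e.g. "0.2.6"
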
pymc_extras/distributions/__init__.py CHANGED
@@ -26,6 +26,7 @@ from pymc_extras.distributions.discrete import (
  from pymc_extras.distributions.histogram_utils import histogram_approximation
  from pymc_extras.distributions.multivariate import R2D2M2CP
  from pymc_extras.distributions.timeseries import DiscreteMarkovChain
+ from pymc_extras.distributions.transforms import PartialOrder

  __all__ = [
      "Chi",
@@ -37,4 +38,5 @@ __all__ = [
      "R2D2M2CP",
      "Skellam",
      "histogram_approximation",
+     "PartialOrder",
  ]
pymc_extras/distributions/continuous.py CHANGED
@@ -81,7 +81,7 @@ class GenExtreme(Continuous):

          \left\{x: 1 + \xi\left(\frac{x-\mu}{\sigma}\right) > 0 \right\}.

-     Note that this parametrization is per Coles (2001), and differs from that of
+     Note that this parametrization is per Coles (2001) [1]_, and differs from that of
      Scipy in the sign of the shape parameter, :math:`\xi`.

      .. plot::
@@ -132,7 +132,7 @@ class GenExtreme(Continuous):

      References
      ----------
-     .. [Coles2001] Coles, S.G. (2001).
+     .. [1] Coles, S.G. (2001).
          An Introduction to the Statistical Modeling of Extreme Values
          Springer-Verlag, London

@@ -260,6 +260,7 @@ class Chi:
      Examples
      --------
      .. code-block:: python
+
          import pymc as pm
          from pymc_extras.distributions import Chi

pymc_extras/distributions/discrete.py CHANGED
@@ -116,6 +116,7 @@ class GeneralizedPoisson(pm.distributions.Discrete):

      .. math:: f(x \mid \mu, \lambda) =
          \frac{\mu (\mu + \lambda x)^{x-1} e^{-\mu - \lambda x}}{x!}
+
      ======== ======================================
      Support :math:`x \in \mathbb{N}_0`
      Mean :math:`\frac{\mu}{1 - \lambda}`
@@ -135,9 +136,10 @@ class GeneralizedPoisson(pm.distributions.Discrete):
      When lam < 0, the mean is greater than the variance (underdispersion).
      When lam > 0, the mean is less than the variance (overdispersion).

+     The PMF is taken from [1]_ and the random generator function is adapted from [2]_.
+
      References
      ----------
-     The PMF is taken from [1] and the random generator function is adapted from [2].
      .. [1] Consul, PoC, and Felix Famoye. "Generalized Poisson regression model."
          Communications in Statistics-Theory and Methods 21.1 (1992): 89-109.
      .. [2] Famoye, Felix. "Generalized Poisson random variate generation." American
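For reference, the PMF shown in the hunk above can be evaluated directly with numpy and scipy. This is an illustration only; generalized_poisson_pmf is a hypothetical helper, not a function shipped by the package:

    import numpy as np
    from scipy.special import gammaln

    def generalized_poisson_pmf(x, mu, lam):
        # f(x | mu, lam) = mu * (mu + lam * x)**(x - 1) * exp(-(mu + lam * x)) / x!
        x = np.asarray(x, dtype=float)
        logp = np.log(mu) + (x - 1) * np.log(mu + lam * x) - (mu + lam * x) - gammaln(x + 1)
        return np.exp(logp)

    # With lam = 0 this reduces to an ordinary Poisson(mu) PMF.
    print(generalized_poisson_pmf(np.arange(5), mu=2.0, lam=0.0))
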
pymc_extras/distributions/transforms/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from pymc_extras.distributions.transforms.partial_order import PartialOrder
+
+ __all__ = ["PartialOrder"]
pymc_extras/distributions/transforms/partial_order.py ADDED
@@ -0,0 +1,227 @@
+ # Copyright 2025 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import numpy as np
+ import pytensor.tensor as pt
+
+ from pymc.logprob.transforms import Transform
+
+ __all__ = ["PartialOrder"]
+
+
+ def dtype_minval(dtype):
+     """Find the minimum value for a given dtype"""
+     return np.iinfo(dtype).min if np.issubdtype(dtype, np.integer) else np.finfo(dtype).min
+
+
+ def padded_where(x, to_len, padval=-1):
+     """A padded version of np.where"""
+     w = np.where(x)
+     return np.concatenate([w[0], np.full(to_len - len(w[0]), padval)])
+
+
+ class PartialOrder(Transform):
+     """Create a PartialOrder transform
+
+     A more flexible version of the pymc ordered transform that
+     allows specifying a (strict) partial order on the elements.
+
+     Examples
+     --------
+     .. code:: python
+
+         import numpy as np
+         import pymc as pm
+         import pymc_extras as pmx
+
+         # Define two partial orders on 4 elements
+         # am[i,j] = 1 means i < j
+         adj_mats = np.array([
+             # 0 < {1, 2} < 3
+             [[0, 1, 1, 0],
+              [0, 0, 0, 1],
+              [0, 0, 0, 1],
+              [0, 0, 0, 0]],
+
+             # 1 < 0 < 3 < 2
+             [[0, 0, 0, 1],
+              [1, 0, 0, 0],
+              [0, 0, 0, 0],
+              [0, 0, 1, 0]],
+         ])
+
+         # Create the partial order from the adjacency matrices
+         po = pmx.PartialOrder(adj_mats)
+
+         with pm.Model() as model:
+             # Generate 3 samples from both partial orders
+             pm.Normal("po_vals", shape=(3,2,4), transform=po,
+                       initval=po.initvals((3,2,4)))
+
+             idata = pm.sample()
+
+         # Verify that for first po, the zeroth element is always the smallest
+         assert (idata.posterior['po_vals'][:,:,:,0,0] <
+                 idata.posterior['po_vals'][:,:,:,0,1:]).all()
+
+         # Verify that for second po, the second element is always the largest
+         assert (idata.posterior['po_vals'][:,:,:,1,2] >=
+                 idata.posterior['po_vals'][:,:,:,1,:]).all()
+
+     Technical notes
+     ----------------
+     Partial order needs to be strict, i.e. without equalities.
+     A DAG defining the partial order is sufficient, as transitive closure is automatically computed.
+     Code works in O(N*D) in runtime, but takes O(N^3) in initialization,
+     where N is the number of nodes in the dag and D is the maximum
+     in-degree of a node in the transitive reduction.
+     """
+
+     name = "partial_order"
+
+     def __init__(self, adj_mat):
+         """
+         Initialize the PartialOrder transform
+
+         Parameters
+         ----------
+         adj_mat: ndarray
+             adjacency matrix for the DAG that generates the partial order,
+             where ``adj_mat[i][j] = 1`` denotes ``i < j``.
+             Note this also accepts multiple DAGs if RV is multidimensional
+         """
+
+         # Basic input checks
+         if adj_mat.ndim < 2:
+             raise ValueError("Adjacency matrix must have at least 2 dimensions")
+         if adj_mat.shape[-2] != adj_mat.shape[-1]:
+             raise ValueError("Adjacency matrix is not square")
+         if adj_mat.min() != 0 or adj_mat.max() != 1:
+             raise ValueError("Adjacency matrix must contain only 0s and 1s")
+
+         # Create index over the first ellipsis dimensions
+         idx = np.ix_(*[np.arange(s) for s in adj_mat.shape[:-2]])
+
+         # Transitive closure using Floyd-Warshall
+         tc = adj_mat.astype(bool)
+         for k in range(tc.shape[-1]):
+             tc |= np.logical_and(tc[..., :, k, None], tc[..., None, k, :])
+
+         # Check if the dag is acyclic
+         if np.any(tc.diagonal(axis1=-2, axis2=-1)):
+             raise ValueError("Partial order contains equalities")
+
+         # Transitive reduction using the closure
+         # This gives the minimum description of the partial order
+         # This is to minmax the input degree
+         adj_mat = tc * (1 - np.matmul(tc, tc))
+
+         # Find the maximum in-degree of the reduced dag
+         dag_idim = adj_mat.sum(axis=-2).max()
+
+         # Topological sort
+         ts_inds = np.zeros(adj_mat.shape[:-1], dtype=int)
+         dm = adj_mat.copy()
+         for i in range(adj_mat.shape[1]):
+             assert dm.sum(axis=-2).min() == 0  # DAG is acyclic
+             nind = np.argmin(dm.sum(axis=-2), axis=-1)
+             dm[(*idx, slice(None), nind)] = 1  # Make nind not show up again
+             dm[(*idx, nind, slice(None))] = 0  # Allow it's children to show
+             ts_inds[(*idx, i)] = nind
+         self.ts_inds = ts_inds
+
+         # Change the dag to adjacency lists (with -1 for NA)
+         dag_T = np.apply_along_axis(padded_where, axis=-2, arr=adj_mat, padval=-1, to_len=dag_idim)
+         self.dag = np.swapaxes(dag_T, -2, -1)
+         self.is_start = np.all(self.dag[..., :, :] == -1, axis=-1)
+
+     def initvals(self, shape=None, lower=-1, upper=1):
+         """
+         Create a set of appropriate initial values for the variable.
+         NB! It is important that proper initial values are used,
+         as only properly ordered values are in the range of the transform.
+
+         Parameters
+         ----------
+         shape: tuple, default None
+             shape of the initial values. If None, adj_mat[:-1] is used
+         lower: float, default -1
+             lower bound for the initial values
+         upper: float, default 1
+             upper bound for the initial values
+
+         Returns
+         -------
+         vals: ndarray
+             initial values for the transformed variable
+         """
+
+         if shape is None:
+             shape = self.dag.shape[:-1]
+
+         if shape[-len(self.dag.shape[:-1]) :] != self.dag.shape[:-1]:
+             raise ValueError("Shape must match the shape of the adjacency matrix")
+
+         # Create the initial values
+         vals = np.linspace(lower, upper, self.dag.shape[-2])
+         inds = np.argsort(self.ts_inds, axis=-1)
+         ivals = vals[inds]
+
+         # Expand the initial values to the extra dimensions
+         extra_dims = shape[: -len(self.dag.shape[:-1])]
+         ivals = np.tile(ivals, extra_dims + tuple([1] * len(self.dag.shape[:-1])))
+
+         return ivals
+
+     def backward(self, value, *inputs):
+         minv = dtype_minval(value.dtype)
+         x = pt.concatenate(
+             [pt.zeros_like(value), pt.full(value.shape[:-1], minv)[..., None]], axis=-1
+         )
+
+         # Indices to allow broadcasting the max over the last dimension
+         idx = np.ix_(*[np.arange(s) for s in self.dag.shape[:-2]])
+         idx2 = tuple(np.tile(i[:, None], self.dag.shape[-1]) for i in idx)
+
+         # Has to be done stepwise as next steps depend on previous values
+         # Also has to be done in topological order, hence the ts_inds
+         for i in range(self.dag.shape[-2]):
+             tsi = self.ts_inds[..., i]
+             if len(tsi.shape) == 0:
+                 tsi = int(tsi)  # if shape 0, it's a scalar
+             ni = (*idx, tsi)  # i-th node in topological order
+             eni = (Ellipsis, *ni)
+             ist = self.is_start[ni]
+
+             mval = pt.max(x[(Ellipsis, *idx2, self.dag[ni])], axis=-1)
+             x = pt.set_subtensor(x[eni], ist * value[eni] + (1 - ist) * (mval + pt.exp(value[eni])))
+         return x[..., :-1]
+
+     def forward(self, value, *inputs):
+         y = pt.zeros_like(value)
+
+         minv = dtype_minval(value.dtype)
+         vx = pt.concatenate([value, pt.full(value.shape[:-1], minv)[..., None]], axis=-1)
+
+         # Indices to allow broadcasting the max over the last dimension
+         idx = np.ix_(*[np.arange(s) for s in self.dag.shape[:-2]])
+         idx = tuple(np.tile(i[:, None, None], self.dag.shape[-2:]) for i in idx)
+
+         y = self.is_start * value + (1 - self.is_start) * (
+             pt.log(value - pt.max(vx[(Ellipsis, *idx, self.dag[..., :])], axis=-1))
+         )
+
+         return y
+
+     def log_jac_det(self, value, *inputs):
+         return pt.sum(value * (1 - self.is_start), axis=-1)
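The preprocessing described in the Technical notes (transitive closure, acyclicity check, transitive reduction) can be reproduced standalone with numpy. A small sketch using the first adjacency matrix from the docstring plus one redundant edge; the names two_step and reduced are local to the example, not part of the class:

    import numpy as np

    # 0 < {1, 2} < 3, with a redundant direct edge 0 -> 3 added
    adj = np.array([
        [0, 1, 1, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 1],
        [0, 0, 0, 0],
    ])

    # Transitive closure (same Floyd-Warshall style loop as in __init__ above)
    tc = adj.astype(bool)
    for k in range(tc.shape[-1]):
        tc |= np.logical_and(tc[..., :, k, None], tc[..., None, k, :])

    # A strict partial order has no self-loops, even after taking the closure
    assert not tc.diagonal().any()

    # Transitive reduction: drop any edge implied by a two-step path
    two_step = (tc.astype(int) @ tc.astype(int)) > 0
    reduced = tc & ~two_step
    print(reduced.astype(int))  # the redundant 0 -> 3 edge is gone
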
pymc_extras/inference/__init__.py CHANGED
@@ -12,7 +12,9 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

-
+ from pymc_extras.inference.find_map import find_MAP
  from pymc_extras.inference.fit import fit
+ from pymc_extras.inference.laplace import fit_laplace
+ from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder

- __all__ = ["fit"]
+ __all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]
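The inference namespace now re-exports all four fitting entry points, and pymc_extras/__init__.py (above) pulls them from here. A quick check of the new import surface:

    # Previously only `fit` was importable from this module.
    from pymc_extras.inference import find_MAP, fit, fit_laplace, fit_pathfinder

    import pymc_extras as pmx

    # The same callables are re-exported at the package top level.
    assert pmx.find_MAP is find_MAP and pmx.fit_laplace is fit_laplace
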
pymc_extras/inference/find_map.py CHANGED
@@ -9,7 +9,7 @@ import pymc as pm
  import pytensor
  import pytensor.tensor as pt

- from better_optimize import minimize
+ from better_optimize import basinhopping, minimize
  from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
  from pymc.blocking import DictToArrayBijection, RaveledVars
  from pymc.initial_point import make_initial_point_fn
@@ -146,7 +146,7 @@ def _compile_grad_and_hess_to_jax(
      orig_loss_fn = f_loss.vm.jit_fn

      @jax.jit
-     def loss_fn_jax_grad(x, *shared):
+     def loss_fn_jax_grad(x):
          return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)

      f_loss_and_grad = loss_fn_jax_grad
@@ -301,6 +301,14 @@ def scipy_optimize_funcs_from_loss(
          point=initial_point_dict, outputs=[loss], inputs=inputs
      )

+     # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When
+     # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them
+     # away.
+     if use_jax_gradients:
+         from pymc.sampling.jax import _replace_shared_variables
+
+         [loss] = _replace_shared_variables([loss])
+
      compute_grad = use_grad and not use_jax_gradients
      compute_hess = use_hess and not use_jax_gradients
      compute_hessp = use_hessp and not use_jax_gradients
@@ -327,7 +335,7 @@


  def find_MAP(
-     method: minimize_method,
+     method: minimize_method | Literal["basinhopping"],
      *,
      model: pm.Model | None = None,
      use_grad: bool | None = None,
@@ -344,14 +352,17 @@
      **optimizer_kwargs,
  ) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]:
      """
-     Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.minimize.
+     Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.optimize.

      Parameters
      ----------
      model : pm.Model
          The PyMC model to be fit. If None, the current model context is used.
      method : str
-         The optimization method to use. See scipy.optimize.minimize documentation for details.
+         The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
+         trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
+
+         See scipy.optimize.minimize documentation for details.
      use_grad : bool | None, optional
          Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
          the ``method``.
@@ -379,7 +390,9 @@
      compile_kwargs: dict, optional
          Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
      **optimizer_kwargs
-         Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function.
+         Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
+         ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
+         ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.

      Returns
      -------
@@ -405,6 +418,18 @@
      initial_params = DictToArrayBijection.map(
          {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
      )
+
+     do_basinhopping = method == "basinhopping"
+     minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+     if do_basinhopping:
+         # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+         # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+         # if one isn't provided.
+
+         method = minimizer_kwargs.pop("method", "L-BFGS-B")
+         minimizer_kwargs["method"] = method
+
      use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
          method, use_grad, use_hess, use_hessp
      )
@@ -423,17 +448,37 @@
      args = optimizer_kwargs.pop("args", None)

      # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument
-     # if so. That is why it is not set here, regardless of user settings.
-     optimizer_result = minimize(
-         f=f_logp,
-         x0=cast(np.ndarray[float], initial_params.data),
-         args=args,
-         hess=f_hess,
-         hessp=f_hessp,
-         progressbar=progressbar,
-         method=method,
-         **optimizer_kwargs,
-     )
+     # if so. That is why the jac argument is not passed here in either branch.
+
+     if do_basinhopping:
+         if "args" not in minimizer_kwargs:
+             minimizer_kwargs["args"] = args
+         if "hess" not in minimizer_kwargs:
+             minimizer_kwargs["hess"] = f_hess
+         if "hessp" not in minimizer_kwargs:
+             minimizer_kwargs["hessp"] = f_hessp
+         if "method" not in minimizer_kwargs:
+             minimizer_kwargs["method"] = method
+
+         optimizer_result = basinhopping(
+             func=f_logp,
+             x0=cast(np.ndarray[float], initial_params.data),
+             progressbar=progressbar,
+             minimizer_kwargs=minimizer_kwargs,
+             **optimizer_kwargs,
+         )
+
+     else:
+         optimizer_result = minimize(
+             f=f_logp,
+             x0=cast(np.ndarray[float], initial_params.data),
+             args=args,
+             hess=f_hess,
+             hessp=f_hessp,
+             progressbar=progressbar,
+             method=method,
+             **optimizer_kwargs,
+         )

      raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
      unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
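Putting the find_map.py changes together: method="basinhopping" now routes through better_optimize.basinhopping, with the inner local optimizer taken from minimizer_kwargs (defaulting to L-BFGS-B). A hedged usage sketch; the toy model and the niter value are illustrative, and remaining keyword arguments are assumed to be forwarded to scipy's basinhopping:

    import numpy as np
    import pymc as pm
    from pymc_extras.inference import find_MAP

    rng = np.random.default_rng(0)

    with pm.Model():
        mu = pm.Normal("mu", 0, 10)
        sigma = pm.HalfNormal("sigma", 10)
        pm.Normal("y", mu, sigma, observed=rng.normal(loc=3.0, size=100))

        # The outer global search is basinhopping; the inner optimizer defaults
        # to L-BFGS-B and can be overridden via minimizer_kwargs.
        result = find_MAP(
            method="basinhopping",
            minimizer_kwargs={"method": "L-BFGS-B"},
            niter=5,
        )
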
pymc_extras/inference/fit.py CHANGED
@@ -11,11 +11,13 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ import arviz as az


- def fit(method, **kwargs):
+ def fit(method: str, **kwargs) -> az.InferenceData:
      """
-     Fit a model with an inference algorithm
+     Fit a model with an inference algorithm.
+     See :func:`fit_pathfinder` and :func:`fit_laplace` for more details.

      Parameters
      ----------
@@ -23,11 +25,11 @@ def fit(method, **kwargs):
          Which inference method to run.
          Supported: pathfinder or laplace

-     kwargs are passed on.
+     kwargs: keyword arguments are passed on to the inference method.

      Returns
      -------
-     arviz.InferenceData
+     :class:`~arviz.InferenceData`
      """
      if method == "pathfinder":
          from pymc_extras.inference.pathfinder import fit_pathfinder
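fit remains a thin dispatcher over the two methods, now with a type-annotated signature. A short sketch, assuming fit_pathfinder accepts the model keyword like the other entry points:

    import numpy as np
    import pymc as pm
    import pymc_extras as pmx

    with pm.Model() as model:
        x = pm.Normal("x", 0, 1)
        pm.Normal("y", x, 1, observed=np.zeros(10))

    # Dispatches to fit_pathfinder; method="laplace" would dispatch to fit_laplace.
    idata = pmx.fit(method="pathfinder", model=model)
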
pymc_extras/inference/laplace.py CHANGED
@@ -377,7 +377,10 @@ def sample_laplace_posterior(
      posterior_dist = stats.multivariate_normal(
          mean=mu.data, cov=H_inv, allow_singular=True, seed=rng
      )
+
      posterior_draws = posterior_dist.rvs(size=(chains, draws))
+     if mu.data.shape == (1,):
+         posterior_draws = np.expand_dims(posterior_draws, -1)

      if transform_samples:
          constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model)
@@ -413,7 +416,7 @@


  def fit_laplace(
-     optimize_method: minimize_method = "BFGS",
+     optimize_method: minimize_method | Literal["basinhopping"] = "BFGS",
      *,
      model: pm.Model | None = None,
      use_grad: bool | None = None,
@@ -446,8 +449,11 @@
      ----------
      model : pm.Model
          The PyMC model to be fit. If None, the current model context is used.
-     optimize_method : str
-         The optimization method to use. See scipy.optimize.minimize documentation for details.
+     method : str
+         The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
+         trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
+
+         See scipy.optimize.minimize documentation for details.
      use_grad : bool | None, optional
          Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
          the ``method``.
@@ -497,16 +503,16 @@
      diag_jitter: float | None
          A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite.
          If None, no jitter is added. Default is 1e-8.
-     optimizer_kwargs: dict, optional
-         Additional keyword arguments to pass to scipy.minimize. See the documentation for scipy.optimize.minimize for
-         details. Arguments that are typically passed via ``options`` will be automatically extracted without the need
-         to use a nested dictionary.
+     optimizer_kwargs
+         Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
+         ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
+         ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
      compile_kwargs: dict, optional
          Additional keyword arguments to pass to pytensor.function.

      Returns
      -------
-     idata: az.InferenceData
+     :class:`~arviz.InferenceData`
          An InferenceData object containing the approximated posterior samples.

      Examples
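fit_laplace gains the same basinhopping option as find_MAP, and its optimizer_kwargs are forwarded to whichever scipy.optimize routine ends up being called. A sketch under those assumptions; the model and the niter value are illustrative only:

    import numpy as np
    import pymc as pm
    from pymc_extras.inference import fit_laplace

    rng = np.random.default_rng(0)

    with pm.Model():
        mu = pm.Normal("mu", 0, 10)
        sigma = pm.HalfNormal("sigma", 10)
        pm.Normal("y", mu, sigma, observed=rng.normal(loc=1.0, scale=2.0, size=200))

        # basinhopping as the outer optimizer; the dict is assumed to be passed
        # through to scipy.optimize.basinhopping via the MAP step.
        idata = fit_laplace(
            optimize_method="basinhopping",
            optimizer_kwargs={"niter": 10},
        )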