pymc-extras 0.2.6__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pymc_extras/deserialize.py
@@ -0,0 +1,224 @@
+ """Deserialize dictionaries into Python objects.
+
+ This is a two-step process:
+
+ 1. Determine if the data is of the correct type.
+ 2. Deserialize the data into a Python object.
+
+ Examples
+ --------
+ Make use of the already registered deserializers:
+
+ .. code-block:: python
+
+     from pymc_extras.deserialize import deserialize
+
+     prior_class_data = {
+         "dist": "Normal",
+         "kwargs": {"mu": 0, "sigma": 1}
+     }
+     prior = deserialize(prior_class_data)
+     # Prior("Normal", mu=0, sigma=1)
+
+ Register custom class deserialization:
+
+ .. code-block:: python
+
+     from pymc_extras.deserialize import register_deserialization
+
+     class MyClass:
+         def __init__(self, value: int):
+             self.value = value
+
+         def to_dict(self) -> dict:
+             # Example of what the to_dict method might look like.
+             return {"value": self.value}
+
+     register_deserialization(
+         is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
+         deserialize=lambda data: MyClass(value=data["value"]),
+     )
+
+ Deserialize data into that custom class:
+
+ .. code-block:: python
+
+     from pymc_extras.deserialize import deserialize
+
+     data = {"value": 42}
+     obj = deserialize(data)
+     assert isinstance(obj, MyClass)
+
+ """
+
+ from collections.abc import Callable
+ from dataclasses import dataclass
+ from typing import Any
+
+ IsType = Callable[[Any], bool]
+ Deserialize = Callable[[Any], Any]
+
+
+ @dataclass
+ class Deserializer:
+     """Object to store information required for deserialization.
+
+     All deserializers should be stored via the :func:`register_deserialization` function
+     instead of creating this object directly.
+
+     Attributes
+     ----------
+     is_type : IsType
+         Function to determine if the data is of the correct type.
+     deserialize : Deserialize
+         Function to deserialize the data.
+
+     Examples
+     --------
+     .. code-block:: python
+
+         from typing import Any
+
+         class MyClass:
+             def __init__(self, value: int):
+                 self.value = value
+
+         from pymc_extras.deserialize import Deserializer
+
+         def is_type(data: Any) -> bool:
+             return data.keys() == {"value"} and isinstance(data["value"], int)
+
+         def deserialize(data: dict) -> MyClass:
+             return MyClass(value=data["value"])
+
+         deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize)
+
+     """
+
+     is_type: IsType
+     deserialize: Deserialize
+
+
+ DESERIALIZERS: list[Deserializer] = []
+
+
+ class DeserializableError(Exception):
+     """Error raised when data cannot be deserialized."""
+
+     def __init__(self, data: Any):
+         self.data = data
+         super().__init__(
+             f"Couldn't deserialize {data}. Use register_deserialization to add a deserialization mapping."
+         )
+
+
+ def deserialize(data: Any) -> Any:
+     """Deserialize a dictionary into a Python object.
+
+     Use the :func:`register_deserialization` function to add custom deserializations.
+
+     Deserialization is a two-step process due to the dynamic nature of the data:
+
+     1. Determine if the data is of the correct type.
+     2. Deserialize the data into a Python object.
+
+     Each registered deserializer is checked in order until one is found whose
+     ``is_type`` check passes. A :class:`DeserializableError` is raised if no
+     deserializer matches the data, or if the matching deserializer fails.
+
+     Parameters
+     ----------
+     data : Any
+         The data to deserialize.
+
+     Returns
+     -------
+     Any
+         The deserialized object.
+
+     Raises
+     ------
+     DeserializableError
+         Raised when the data doesn't match any registered deserializations
+         or fails to be deserialized.
+
+     Examples
+     --------
+     Deserialize a :class:`pymc_extras.prior.Prior` object:
+
+     .. code-block:: python
+
+         from pymc_extras.deserialize import deserialize
+
+         data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
+         prior = deserialize(data)
+         # Prior("Normal", mu=0, sigma=1)
+
+     """
+     for mapping in DESERIALIZERS:
+         try:
+             is_type = mapping.is_type(data)
+         except Exception:
+             is_type = False
+
+         if not is_type:
+             continue
+
+         try:
+             return mapping.deserialize(data)
+         except Exception as e:
+             raise DeserializableError(data) from e
+     else:
+         raise DeserializableError(data)
+
+
+ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
+     """Register an arbitrary deserialization.
+
+     Use the :func:`deserialize` function to then deserialize data using all registered
+     deserialize functions.
+
+     Parameters
+     ----------
+     is_type : Callable[[Any], bool]
+         Function to determine if the data is of the correct type.
+     deserialize : Callable[[dict], Any]
+         Function to deserialize the data of that type.
+
+     Examples
+     --------
+     Register a custom class deserialization:
+
+     .. code-block:: python
+
+         from pymc_extras.deserialize import register_deserialization
+
+         class MyClass:
+             def __init__(self, value: int):
+                 self.value = value
+
+             def to_dict(self) -> dict:
+                 # Example of what the to_dict method might look like.
+                 return {"value": self.value}
+
+         register_deserialization(
+             is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
+             deserialize=lambda data: MyClass(value=data["value"]),
+         )
+
+     Use that custom class deserialization:
+
+     .. code-block:: python
+
+         from pymc_extras.deserialize import deserialize
+
+         data = {"value": 42}
+         obj = deserialize(data)
+         assert isinstance(obj, MyClass)
+
+     """
+     mapping = Deserializer(is_type=is_type, deserialize=deserialize)
+     DESERIALIZERS.append(mapping)
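
Because registered deserializers are checked in registration order, the first matching is_type wins. A minimal sketch of that behavior (the Meters class and both callables here are invented for illustration):

    from pymc_extras.deserialize import deserialize, register_deserialization

    class Meters:
        def __init__(self, value: float):
            self.value = value

    # Registered first, so its is_type check runs first.
    register_deserialization(
        is_type=lambda data: data.keys() == {"unit", "value"} and data["unit"] == "m",
        deserialize=lambda data: Meters(value=data["value"]),
    )

    obj = deserialize({"unit": "m", "value": 3.0})
    assert isinstance(obj, Meters)

    # Any broader deserializer registered later only sees data this one rejects.
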
pymc_extras/inference/__init__.py
@@ -12,9 +12,9 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from pymc_extras.inference.find_map import find_MAP
  from pymc_extras.inference.fit import fit
- from pymc_extras.inference.laplace import fit_laplace
+ from pymc_extras.inference.laplace_approx.find_map import find_MAP
+ from pymc_extras.inference.laplace_approx.laplace import fit_laplace
  from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder

  __all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]

pymc_extras/inference/fit.py
@@ -37,6 +37,6 @@ def fit(method: str, **kwargs) -> az.InferenceData:
          return fit_pathfinder(**kwargs)

      if method == "laplace":
-         from pymc_extras.inference.laplace import fit_laplace
+         from pymc_extras.inference import fit_laplace

          return fit_laplace(**kwargs)
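
Both entry points remain importable from the package-level namespace after the laplace_approx reorganization, so downstream code does not need the new internal paths. A quick sketch:

    from pymc_extras.inference import find_MAP, fit, fit_laplace

    # fit() dispatches on the method name; fit(method="laplace", **kwargs)
    # forwards its keyword arguments to fit_laplace(**kwargs).
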
pymc_extras/inference/laplace_approx/find_map.py
@@ -0,0 +1,347 @@
+ import logging
+
+ from collections.abc import Callable
+ from typing import Literal, cast
+
+ import arviz as az
+ import numpy as np
+ import pymc as pm
+
+ from better_optimize import basinhopping, minimize
+ from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
+ from pymc.blocking import DictToArrayBijection, RaveledVars
+ from pymc.initial_point import make_initial_point_fn
+ from pymc.model.transform.optimization import freeze_dims_and_data
+ from pymc.util import get_default_varnames
+ from pytensor.tensor import TensorVariable
+ from scipy.optimize import OptimizeResult
+
+ from pymc_extras.inference.laplace_approx.idata import (
+     add_data_to_inference_data,
+     add_fit_to_inference_data,
+     add_optimizer_result_to_inference_data,
+     map_results_to_inference_data,
+ )
+ from pymc_extras.inference.laplace_approx.scipy_interface import (
+     GradientBackend,
+     scipy_optimize_funcs_from_loss,
+ )
+
+ _log = logging.getLogger(__name__)
+
+
+ def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
+     method_info = MINIMIZE_MODE_KWARGS[method].copy()
+
+     if use_hess and use_hessp:
+         _log.warning(
+             'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
+             'same time. When possible, "use_hessp" is preferred because it is computationally more efficient. '
+             'Setting "use_hess" to False.'
+         )
+         use_hess = False
+
+     use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+
+     if use_hessp is not None and use_hess is None:
+         use_hess = not use_hessp
+
+     elif use_hess is not None and use_hessp is None:
+         use_hessp = not use_hess
+
+     elif use_hessp is None and use_hess is None:
+         use_hessp = method_info["uses_hessp"]
+         use_hess = method_info["uses_hess"]
+         if use_hessp and use_hess:
+             # If a method could use either hess or hessp, we default to using hessp
+             use_hess = False
+
+     return use_grad, use_hess, use_hessp
+
+
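+ # For intuition, one hypothetical resolution (assuming MINIMIZE_MODE_KWARGS
+ # mirrors scipy's capabilities, where trust-ncg takes a gradient plus either
+ # hess or hessp):
+ #
+ #     use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+ #         "trust-ncg", use_grad=None, use_hess=None, use_hessp=None
+ #     )
+ #     # -> (True, False, True); hessp is preferred when both are possible
+
+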
+ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
+     """
+     Compute the nearest positive semi-definite matrix to a given matrix.
+
+     This function takes a square matrix and returns the nearest positive semi-definite matrix using
+     eigenvalue decomposition. It ensures all eigenvalues are non-negative. The "nearest" matrix is defined in terms
+     of the Frobenius norm.
+
+     Parameters
+     ----------
+     A : np.ndarray
+         Input square matrix.
+
+     Returns
+     -------
+     np.ndarray
+         The nearest positive semi-definite matrix to the input matrix.
+     """
+     C = (A + A.T) / 2
+     eigval, eigvec = np.linalg.eigh(C)
+     eigval[eigval < 0] = 0
+
+     return eigvec @ np.diag(eigval) @ eigvec.T
+
+
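+ # A quick numeric sanity check (hypothetical values): the symmetric but
+ # indefinite matrix [[1, 2], [2, -1]] has eigenvalues +sqrt(5) and -sqrt(5);
+ # clipping the negative eigenvalue to zero gives the Frobenius-nearest PSD
+ # matrix.
+ #
+ #     A = np.array([[1.0, 2.0], [2.0, -1.0]])
+ #     A_psd = get_nearest_psd(A)
+ #     assert np.all(np.linalg.eigvalsh(A_psd) >= -1e-12)
+
+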
+ def _make_initial_point(model, initvals=None, random_seed=None, jitter_rvs=None):
+     jitter_rvs = [] if jitter_rvs is None else jitter_rvs
+
+     ipfn = make_initial_point_fn(
+         model=model,
+         jitter_rvs=set(jitter_rvs),
+         return_transformed=True,
+         overrides=initvals,
+     )
+
+     start_dict = ipfn(random_seed)
+     vars_dict = {var.name: var for var in model.continuous_value_vars}
+     initial_params = DictToArrayBijection.map(
+         {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
+     )
+
+     return initial_params
+
+
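+ # For intuition (hypothetical variable names): DictToArrayBijection.map turns
+ # {"mu": np.array(0.5), "sigma_log__": np.array(0.0)} into a RaveledVars whose
+ # .data is the flat vector [0.5, 0.0] and whose .point_map_info records names
+ # and shapes, so DictToArrayBijection.rmap can rebuild the dict later.
+
+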
+ def _compute_inverse_hessian(
+     optimizer_result: OptimizeResult | None,
+     optimal_point: np.ndarray | None,
+     f_fused: Callable | None,
+     f_hessp: Callable | None,
+     use_hess: bool,
+     method: minimize_method | Literal["BFGS", "L-BFGS-B"],
+ ):
+     """
+     Compute the Hessian matrix or its inverse based on the optimization result and the method used.
+
+     Downstream functions (e.g. the Laplace approximation) need the inverse Hessian matrix. This function computes it
+     in the cheapest way possible, depending on the optimization method used and the available compiled functions.
+
+     Parameters
+     ----------
+     optimizer_result : OptimizeResult, optional
+         The result of the optimization, containing the optimized parameters and possibly an approximate inverse Hessian.
+     optimal_point : np.ndarray, optional
+         The optimal point found by the optimizer, used to compute the Hessian if necessary. If not provided, it will be
+         extracted from the optimizer result.
+     f_fused : callable, optional
+         The compiled function representing the loss and possibly its gradient and Hessian.
+     f_hessp : callable, optional
+         The compiled function for Hessian-vector products, if available.
+     use_hess : bool
+         Whether the Hessian matrix was used in the optimization.
+     method : minimize_method
+         The optimization method used, which determines how the Hessian is computed.
+
+     Returns
+     -------
+     H_inv : np.ndarray
+         The inverse Hessian matrix, computed based on the optimization method and available functions.
+     """
+     if optimal_point is None and optimizer_result is None:
+         raise ValueError("At least one of `optimal_point` or `optimizer_result` must be provided.")
+
+     x_star = optimizer_result.x if optimizer_result is not None else optimal_point
+     n_vars = len(x_star)
+
+     if method == "BFGS" and optimizer_result is not None:
+         # If we used BFGS, the optimizer result will contain the inverse Hessian -- we can just use that rather than
+         # re-computing something
+         if hasattr(optimizer_result, "lowest_optimization_result"):
+             # We did basinhopping, need to get the inner optimizer results
+             H_inv = getattr(optimizer_result.lowest_optimization_result, "hess_inv", None)
+         else:
+             H_inv = getattr(optimizer_result, "hess_inv", None)
+
+     elif method == "L-BFGS-B" and optimizer_result is not None:
+         # Here we will have a LinearOperator representing the inverse Hessian-vector product.
+         if hasattr(optimizer_result, "lowest_optimization_result"):
+             # We did basinhopping, need to get the inner optimizer results
+             f_hessp_inv = getattr(optimizer_result.lowest_optimization_result, "hess_inv", None)
+         else:
+             f_hessp_inv = getattr(optimizer_result, "hess_inv", None)
+
+         if f_hessp_inv is not None:
+             basis = np.eye(n_vars)
+             H_inv = np.stack([f_hessp_inv(basis[:, i]) for i in range(n_vars)], axis=-1)
+         else:
+             H_inv = None
+
+     elif f_hessp is not None:
+         # When hessp was used, the results object does not save the inverse Hessian, so we compute it from
+         # the hessp function by probing it with Euclidean basis vectors.
+         basis = np.eye(n_vars)
+         H = np.stack([f_hessp(x_star, basis[:, i]) for i in range(n_vars)], axis=-1)
+         H_inv = np.linalg.inv(get_nearest_psd(H))
+
+     elif use_hess and f_fused is not None:
+         # If we compiled a hessian function, just use it
+         _, _, H = f_fused(x_star)
+         H_inv = np.linalg.inv(get_nearest_psd(H))
+
+     else:
+         H_inv = None
+
+     return H_inv
+
+
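+ # A sketch of the f_hessp branch above (hypothetical quadratic loss with a
+ # known Hessian H): probing f_hessp with the columns of the identity matrix
+ # rebuilds H, which is then projected to the nearest PSD matrix and inverted.
+ #
+ #     H = np.array([[2.0, 0.0], [0.0, 4.0]])
+ #     H_inv = _compute_inverse_hessian(
+ #         optimizer_result=None,
+ #         optimal_point=np.zeros(2),
+ #         f_fused=None,
+ #         f_hessp=lambda x, p: H @ p,
+ #         use_hess=False,
+ #         method="trust-ncg",
+ #     )
+ #     np.testing.assert_allclose(H_inv, np.linalg.inv(H))
+
+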
+ def find_MAP(
+     method: minimize_method | Literal["basinhopping"] = "L-BFGS-B",
+     *,
+     model: pm.Model | None = None,
+     use_grad: bool | None = None,
+     use_hessp: bool | None = None,
+     use_hess: bool | None = None,
+     initvals: dict | None = None,
+     random_seed: int | np.random.Generator | None = None,
+     jitter_rvs: list[TensorVariable] | None = None,
+     progressbar: bool = True,
+     include_transformed: bool = True,
+     gradient_backend: GradientBackend = "pytensor",
+     compile_kwargs: dict | None = None,
+     **optimizer_kwargs,
+ ) -> az.InferenceData:
+     """
+     Fit a PyMC model via maximum a posteriori (MAP) estimation using scipy.optimize, with gradients computed by
+     PyTensor or JAX.
+
+     Parameters
+     ----------
+     method : str
+         The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
+         trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
+
+         See the scipy.optimize.minimize documentation for details.
+     model : pm.Model
+         The PyMC model to be fit. If None, the current model context is used.
+     use_grad : bool | None, optional
+         Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
+         the ``method``.
+     use_hessp : bool | None, optional
+         Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this
+         automatically based on the ``method``.
+     use_hess : bool | None, optional
+         Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically
+         based on the ``method``.
+     initvals : None | dict, optional
+         Initial values for the model parameters, as str: ndarray key-value pairs. Partial initialization is permitted.
+         If None, the model's default initial values are used.
+     random_seed : None | int | np.random.Generator, optional
+         Seed for the random number generator, or a numpy Generator, for reproducibility.
+     jitter_rvs : list of TensorVariables, optional
+         Variables whose initial values should be jittered. If None, all variables are jittered.
+     progressbar : bool, optional
+         Whether to display a progress bar during optimization. Defaults to True.
+     include_transformed : bool, optional
+         Whether to include transformed variable values in the returned dictionary. Defaults to True.
+     gradient_backend : str, default "pytensor"
+         Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
+     compile_kwargs : dict, optional
+         Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
+     **optimizer_kwargs
+         Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
+         ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
+         ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
+
+     Returns
+     -------
+     map_result : az.InferenceData
+         Results of Maximum A Posteriori (MAP) estimation, including the optimized point, inverse Hessian, transformed
+         latent variables, and optimizer results.
+     """
+     model = pm.modelcontext(model) if model is None else model
+     frozen_model = freeze_dims_and_data(model)
+     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+
+     initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)
+
+     do_basinhopping = method == "basinhopping"
+     minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+     if do_basinhopping:
+         # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+         # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+         # if one isn't provided.
+         method = minimizer_kwargs.pop("method", "L-BFGS-B")
+         minimizer_kwargs["method"] = method
+
+     use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+         method, use_grad, use_hess, use_hessp
+     )
+
+     f_fused, f_hessp = scipy_optimize_funcs_from_loss(
+         loss=-frozen_model.logp(),
+         inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
+         initial_point_dict=DictToArrayBijection.rmap(initial_params),
+         use_grad=use_grad,
+         use_hess=use_hess,
+         use_hessp=use_hessp,
+         gradient_backend=gradient_backend,
+         compile_kwargs=compile_kwargs,
+     )
+
+     args = optimizer_kwargs.pop("args", ())
+
+     # better_optimize.minimize will check if f_fused is a fused loss+grad Op, and automatically assign the jac
+     # argument if so. That is why the jac argument is not passed here in either branch.
+     if do_basinhopping:
+         if "args" not in minimizer_kwargs:
+             minimizer_kwargs["args"] = args
+         if "hessp" not in minimizer_kwargs:
+             minimizer_kwargs["hessp"] = f_hessp
+         if "method" not in minimizer_kwargs:
+             minimizer_kwargs["method"] = method
+
+         optimizer_result = basinhopping(
+             func=f_fused,
+             x0=cast(np.ndarray[float], initial_params.data),
+             progressbar=progressbar,
+             minimizer_kwargs=minimizer_kwargs,
+             **optimizer_kwargs,
+         )
+
+     else:
+         optimizer_result = minimize(
+             f=f_fused,
+             x0=cast(np.ndarray[float], initial_params.data),
+             args=args,
+             hessp=f_hessp,
+             progressbar=progressbar,
+             method=method,
+             **optimizer_kwargs,
+         )
+
+     H_inv = _compute_inverse_hessian(
+         optimizer_result=optimizer_result,
+         optimal_point=None,
+         f_fused=f_fused,
+         f_hessp=f_hessp,
+         use_hess=use_hess,
+         method=method,
+     )
+
+     raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
+     unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
+     unobserved_vars_values = model.compile_fn(unobserved_vars, mode="FAST_COMPILE")(
+         DictToArrayBijection.rmap(raveled_optimized)
+     )
+
+     optimized_point = {
+         var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)
+     }
+
+     idata = map_results_to_inference_data(optimized_point, frozen_model)
+     idata = add_fit_to_inference_data(idata, raveled_optimized, H_inv)
+     idata = add_optimizer_result_to_inference_data(
+         idata, optimizer_result, method, raveled_optimized, model
+     )
+     idata = add_data_to_inference_data(
+         idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
+     )
+
+     return idata
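
A minimal usage sketch of the relocated find_MAP (the toy model and its variable names are invented for illustration; the package-level import follows the updated pymc_extras.inference exports above):

    import numpy as np
    import pymc as pm

    from pymc_extras.inference import find_MAP

    with pm.Model() as model:
        mu = pm.Normal("mu", 0.0, 1.0)
        pm.Normal("obs", mu=mu, sigma=1.0, observed=np.array([0.1, -0.3, 0.2]))

    idata = find_MAP(method="L-BFGS-B", model=model)
    # idata is an arviz InferenceData holding the optimized point, plus the
    # fit (inverse Hessian) and optimizer-result groups added by the
    # add_*_to_inference_data helpers in this module.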