pymc-extras 0.4.1__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (37)
  1. pymc_extras/deserialize.py +10 -4
  2. pymc_extras/distributions/continuous.py +1 -1
  3. pymc_extras/distributions/histogram_utils.py +6 -4
  4. pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
  5. pymc_extras/distributions/timeseries.py +4 -2
  6. pymc_extras/inference/__init__.py +8 -1
  7. pymc_extras/inference/dadvi/__init__.py +0 -0
  8. pymc_extras/inference/dadvi/dadvi.py +351 -0
  9. pymc_extras/inference/fit.py +5 -0
  10. pymc_extras/inference/laplace_approx/find_map.py +32 -47
  11. pymc_extras/inference/laplace_approx/idata.py +27 -6
  12. pymc_extras/inference/laplace_approx/laplace.py +24 -6
  13. pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
  14. pymc_extras/inference/pathfinder/idata.py +517 -0
  15. pymc_extras/inference/pathfinder/pathfinder.py +61 -7
  16. pymc_extras/model/marginal/graph_analysis.py +2 -2
  17. pymc_extras/model_builder.py +9 -4
  18. pymc_extras/prior.py +203 -8
  19. pymc_extras/statespace/core/compile.py +1 -1
  20. pymc_extras/statespace/filters/kalman_filter.py +12 -11
  21. pymc_extras/statespace/filters/kalman_smoother.py +1 -3
  22. pymc_extras/statespace/filters/utilities.py +2 -5
  23. pymc_extras/statespace/models/DFM.py +834 -0
  24. pymc_extras/statespace/models/ETS.py +190 -198
  25. pymc_extras/statespace/models/SARIMAX.py +9 -21
  26. pymc_extras/statespace/models/VARMAX.py +22 -74
  27. pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
  28. pymc_extras/statespace/models/structural/components/regression.py +4 -26
  29. pymc_extras/statespace/models/utilities.py +7 -0
  30. pymc_extras/statespace/utils/constants.py +3 -1
  31. pymc_extras/utils/model_equivalence.py +2 -2
  32. pymc_extras/utils/prior.py +10 -14
  33. pymc_extras/utils/spline.py +4 -10
  34. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/METADATA +3 -3
  35. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/RECORD +37 -33
  36. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/WHEEL +1 -1
  37. {pymc_extras-0.4.1.dist-info → pymc_extras-0.6.0.dist-info}/licenses/LICENSE +0 -0
pymc_extras/deserialize.py
@@ -13,10 +13,7 @@ Make use of the already registered deserializers:
 
  from pymc_extras.deserialize import deserialize
 
- prior_class_data = {
- "dist": "Normal",
- "kwargs": {"mu": 0, "sigma": 1}
- }
+ prior_class_data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
  prior = deserialize(prior_class_data)
  # Prior("Normal", mu=0, sigma=1)
 
@@ -26,6 +23,7 @@ Register custom class deserialization:
 
  from pymc_extras.deserialize import register_deserialization
 
+
  class MyClass:
  def __init__(self, value: int):
  self.value = value
@@ -34,6 +32,7 @@ Register custom class deserialization:
  # Example of what the to_dict method might look like.
  return {"value": self.value}
 
+
  register_deserialization(
  is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
  deserialize=lambda data: MyClass(value=data["value"]),
@@ -80,18 +79,23 @@ class Deserializer:
 
  from typing import Any
 
+
  class MyClass:
  def __init__(self, value: int):
  self.value = value
 
+
  from pymc_extras.deserialize import Deserializer
 
+
  def is_type(data: Any) -> bool:
  return data.keys() == {"value"} and isinstance(data["value"], int)
 
+
  def deserialize(data: dict) -> MyClass:
  return MyClass(value=data["value"])
 
+
  deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize)
 
  """
@@ -196,6 +200,7 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
 
  from pymc_extras.deserialize import register_deserialization
 
+
  class MyClass:
  def __init__(self, value: int):
  self.value = value
@@ -204,6 +209,7 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
  # Example of what the to_dict method might look like.
  return {"value": self.value}
 
+
  register_deserialization(
  is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
  deserialize=lambda data: MyClass(value=data["value"]),
pymc_extras/distributions/continuous.py
@@ -265,7 +265,7 @@ class Chi:
  from pymc_extras.distributions import Chi
 
  with pm.Model():
- x = Chi('x', nu=1)
+ x = Chi("x", nu=1)
  """
 
  @staticmethod
pymc_extras/distributions/histogram_utils.py
@@ -130,8 +130,7 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
  ... m = pm.Normal("m", dims="tests")
  ... s = pm.LogNormal("s", dims="tests")
  ... pot = pmx.distributions.histogram_approximation(
- ... "pot", pm.Normal.dist(m, s),
- ... observed=measurements, n_quantiles=50
+ ... "pot", pm.Normal.dist(m, s), observed=measurements, n_quantiles=50
  ... )
 
  For special cases like Zero Inflation in Continuous variables there is a flag.
@@ -143,8 +142,11 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
  ... m = pm.Normal("m", dims="tests")
  ... s = pm.LogNormal("s", dims="tests")
  ... pot = pmx.distributions.histogram_approximation(
- ... "pot", pm.Normal.dist(m, s),
- ... observed=measurements, n_quantiles=50, zero_inflation=True
+ ... "pot",
+ ... pm.Normal.dist(m, s),
+ ... observed=measurements,
+ ... n_quantiles=50,
+ ... zero_inflation=True,
  ... )
  """
  try:
pymc_extras/distributions/multivariate/r2d2m2cp.py
@@ -305,6 +305,7 @@ def R2D2M2CP(
  import pymc_extras as pmx
  import pymc as pm
  import numpy as np
+
  X = np.random.randn(10, 3)
  b = np.random.randn(3)
  y = X @ b + np.random.randn(10) * 0.04 + 5
@@ -339,7 +340,7 @@ def R2D2M2CP(
  # "c" - a must have in the relation
  variables_importance=[10, 1, 34],
  # NOTE: try both
- centered=True
+ centered=True,
  )
  # intercept prior centering should be around prior predictive mean
  intercept = y.mean()
@@ -365,7 +366,7 @@ def R2D2M2CP(
  r2_std=0.2,
  # NOTE: if you know where a variable should go
  # if you do not know, leave as 0.5
- centered=False
+ centered=False,
  )
  # intercept prior centering should be around prior predictive mean
  intercept = y.mean()
@@ -394,7 +395,7 @@ def R2D2M2CP(
  # if you do not know, leave as 0.5
  positive_probs=[0.8, 0.5, 0.1],
  # NOTE: try both
- centered=True
+ centered=True,
  )
  intercept = y.mean()
  obs = pm.Normal("obs", intercept + X @ beta, eps, observed=y)
pymc_extras/distributions/timeseries.py
@@ -113,8 +113,10 @@ class DiscreteMarkovChain(Distribution):
 
  with pm.Model() as markov_chain:
  P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
- init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
- markov_chain = pmx.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
+ init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
+ markov_chain = pmx.DiscreteMarkovChain(
+ "markov_chain", P=P, init_dist=init_dist, shape=(100,)
+ )
 
  """
 
pymc_extras/inference/__init__.py
@@ -12,9 +12,16 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
+ from pymc_extras.inference.dadvi.dadvi import fit_dadvi
  from pymc_extras.inference.fit import fit
  from pymc_extras.inference.laplace_approx.find_map import find_MAP
  from pymc_extras.inference.laplace_approx.laplace import fit_laplace
  from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
 
- __all__ = ["find_MAP", "fit", "fit_laplace", "fit_pathfinder"]
+ __all__ = [
+ "find_MAP",
+ "fit",
+ "fit_laplace",
+ "fit_pathfinder",
+ "fit_dadvi",
+ ]
pymc_extras/inference/dadvi/__init__.py
File without content to show (new, empty __init__.py)
pymc_extras/inference/dadvi/dadvi.py
@@ -0,0 +1,351 @@
+ import arviz as az
+ import numpy as np
+ import pymc
+ import pytensor
+ import pytensor.tensor as pt
+ import xarray
+
+ from better_optimize import basinhopping, minimize
+ from better_optimize.constants import minimize_method
+ from pymc import DictToArrayBijection, Model, join_nonshared_inputs
+ from pymc.backends.arviz import (
+ PointFunc,
+ apply_function_over_dataset,
+ coords_and_dims_for_inferencedata,
+ )
+ from pymc.blocking import RaveledVars
+ from pymc.util import RandomSeed, get_default_varnames
+ from pytensor.tensor.variable import TensorVariable
+
+ from pymc_extras.inference.laplace_approx.idata import (
+ add_data_to_inference_data,
+ add_optimizer_result_to_inference_data,
+ )
+ from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
+ from pymc_extras.inference.laplace_approx.scipy_interface import (
+ scipy_optimize_funcs_from_loss,
+ set_optimizer_function_defaults,
+ )
+
+
+ def fit_dadvi(
+ model: Model | None = None,
+ n_fixed_draws: int = 30,
+ n_draws: int = 1000,
+ include_transformed: bool = False,
+ optimizer_method: minimize_method = "trust-ncg",
+ use_grad: bool | None = None,
+ use_hessp: bool | None = None,
+ use_hess: bool | None = None,
+ gradient_backend: str = "pytensor",
+ compile_kwargs: dict | None = None,
+ random_seed: RandomSeed = None,
+ progressbar: bool = True,
+ **optimizer_kwargs,
+ ) -> az.InferenceData:
+ """
+ Does inference using Deterministic ADVI (Automatic Differentiation Variational Inference), DADVI for short.
+
+ For full details see the paper cited in the references: https://www.jmlr.org/papers/v25/23-1015.html
+
+ Parameters
+ ----------
+ model : pm.Model
+ The PyMC model to be fit. If None, the current model context is used.
+
+ n_fixed_draws : int
+ The number of fixed draws to use for the optimisation. More draws will result in more accurate estimates, but
+ also increase inference time. Usually, the default of 30 is a good tradeoff between speed and accuracy.
+
+ random_seed: int
+ The random seed to use for the fixed draws. Running the optimisation twice with the same seed should arrive at
+ the same result.
+
+ n_draws: int
+ The number of draws to return from the variational approximation.
+
+ include_transformed: bool
+ Whether or not to keep the unconstrained variables (such as logs of positive-constrained parameters) in the
+ output.
+
+ optimizer_method: str
+ Which optimization method to use. The function calls ``scipy.optimize.minimize``, so any of the methods there
+ can be used. The default is trust-ncg, which uses second-order information and is generally very reliable.
+ Other methods such as L-BFGS-B might be faster but potentially more brittle and may not converge exactly to
+ the optimum.
+
+ gradient_backend: str
+ Which backend to use to compute gradients. Must be one of "jax" or "pytensor". Default is "pytensor".
+
+ compile_kwargs: dict, optional
+ Additional keyword arguments to pass to `pytensor.function`
+
+ use_grad: bool, optional
+ If True, pass the gradient function to `scipy.optimize.minimize` (where it is referred to as `jac`).
+
+ use_hessp: bool, optional
+ If True, pass the hessian vector product to `scipy.optimize.minimize`.
+
+ use_hess: bool, optional
+ If True, pass the hessian to `scipy.optimize.minimize`. Note that this is generally not recommended since its
+ computation can be slow and memory-intensive if there are many parameters.
+
+ progressbar: bool
+ Whether or not to show a progress bar during optimization. Default is True.
+
+ optimizer_kwargs:
+ Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. See the documentation of
+ that function for details.
+
+ Returns
+ -------
+ :class:`~arviz.InferenceData`
+ The inference data containing the results of the DADVI algorithm.
+
+ References
+ ----------
+ Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational Inference with a Deterministic Objective:
+ Faster, More Accurate, and Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.
+ """
+
+ model = pymc.modelcontext(model) if model is None else model
+ do_basinhopping = optimizer_method == "basinhopping"
+ minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+ if do_basinhopping:
+ # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+ # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+ # if one isn't provided.
+
+ optimizer_method = minimizer_kwargs.pop("method", "L-BFGS-B")
+ minimizer_kwargs["method"] = optimizer_method
+
+ initial_point_dict = model.initial_point()
+ initial_point = DictToArrayBijection.map(initial_point_dict)
+ n_params = initial_point.data.shape[0]
+
+ var_params, objective = create_dadvi_graph(
+ model,
+ n_fixed_draws=n_fixed_draws,
+ random_seed=random_seed,
+ n_params=n_params,
+ )
+
+ use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+ optimizer_method, use_grad, use_hess, use_hessp
+ )
+
+ f_fused, f_hessp = scipy_optimize_funcs_from_loss(
+ loss=objective,
+ inputs=[var_params],
+ initial_point_dict=None,
+ use_grad=use_grad,
+ use_hessp=use_hessp,
+ use_hess=use_hess,
+ gradient_backend=gradient_backend,
+ compile_kwargs=compile_kwargs,
+ inputs_are_flat=True,
+ )
+
+ dadvi_initial_point = {
+ f"{var_name}_mu": np.zeros_like(value).ravel()
+ for var_name, value in initial_point_dict.items()
+ }
+ dadvi_initial_point.update(
+ {
+ f"{var_name}_sigma__log": np.zeros_like(value).ravel()
+ for var_name, value in initial_point_dict.items()
+ }
+ )
+
+ dadvi_initial_point = DictToArrayBijection.map(dadvi_initial_point)
+ args = optimizer_kwargs.pop("args", ())
+
+ if do_basinhopping:
+ if "args" not in minimizer_kwargs:
+ minimizer_kwargs["args"] = args
+ if "hessp" not in minimizer_kwargs:
+ minimizer_kwargs["hessp"] = f_hessp
+ if "method" not in minimizer_kwargs:
+ minimizer_kwargs["method"] = optimizer_method
+
+ result = basinhopping(
+ func=f_fused,
+ x0=dadvi_initial_point.data,
+ progressbar=progressbar,
+ minimizer_kwargs=minimizer_kwargs,
+ **optimizer_kwargs,
+ )
+
+ else:
+ result = minimize(
+ f=f_fused,
+ x0=dadvi_initial_point.data,
+ args=args,
+ method=optimizer_method,
+ hessp=f_hessp,
+ progressbar=progressbar,
+ **optimizer_kwargs,
+ )
+
+ raveled_optimized = RaveledVars(result.x, dadvi_initial_point.point_map_info)
+
+ opt_var_params = result.x
+ opt_means, opt_log_sds = np.split(opt_var_params, 2)
+
+ # Make the draws:
+ generator = np.random.default_rng(seed=random_seed)
+ draws_raw = generator.standard_normal(size=(n_draws, n_params))
+
+ draws = opt_means + draws_raw * np.exp(opt_log_sds)
+ draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)
+
+ idata = dadvi_result_to_idata(
+ draws_arviz, model, include_transformed=include_transformed, progressbar=progressbar
+ )
+
+ var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
+ var_name_to_model_var.update(
+ {f"{var_name}_sigma__log": var_name for var_name in initial_point_dict.keys()}
+ )
+
+ idata = add_optimizer_result_to_inference_data(
+ idata=idata,
+ result=result,
+ method=optimizer_method,
+ mu=raveled_optimized,
+ model=model,
+ var_name_to_model_var=var_name_to_model_var,
+ )
+
+ idata = add_data_to_inference_data(
+ idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
+ )
+
+ return idata
+
+
+ def create_dadvi_graph(
+ model: Model,
+ n_params: int,
+ n_fixed_draws: int = 30,
+ random_seed: RandomSeed = None,
+ ) -> tuple[TensorVariable, TensorVariable]:
+ """
+ Sets up the DADVI graph in pytensor and returns it.
+
+ Parameters
+ ----------
+ model : pm.Model
+ The PyMC model to be fit.
+
+ n_params: int
+ The total number of parameters in the model.
+
+ n_fixed_draws : int
+ The number of fixed draws to use.
+
+ random_seed: int
+ The random seed to use for the fixed draws.
+
+ Returns
+ -------
+ Tuple[TensorVariable, TensorVariable]
+ A tuple whose first element contains the variational parameters,
+ and whose second contains the DADVI objective.
+ """
+
+ # Make the fixed draws
+ generator = np.random.default_rng(seed=random_seed)
+ draws = generator.standard_normal(size=(n_fixed_draws, n_params))
+
+ inputs = model.continuous_value_vars + model.discrete_value_vars
+ initial_point_dict = model.initial_point()
+ logp = model.logp()
+
+ # Graph in terms of a flat input
+ [logp], flat_input = join_nonshared_inputs(
+ point=initial_point_dict, outputs=[logp], inputs=inputs
+ )
+
+ var_params = pt.vector(name="eta", shape=(2 * n_params,))
+
+ means, log_sds = pt.split(var_params, axis=0, splits_size=[n_params, n_params], n_splits=2)
+
+ draw_matrix = pt.constant(draws)
+ samples = means + pt.exp(log_sds) * draw_matrix
+
+ logp_vectorized_draws = pytensor.graph.vectorize_graph(logp, replace={flat_input: samples})
+
+ mean_log_density = pt.mean(logp_vectorized_draws)
+ entropy = pt.sum(log_sds)
+
+ objective = -mean_log_density - entropy
+
+ return var_params, objective
+
+
+ def dadvi_result_to_idata(
+ unstacked_draws: xarray.Dataset,
+ model: Model,
+ include_transformed: bool = False,
+ progressbar: bool = True,
+ ):
+ """
+ Transforms the unconstrained draws back into the constrained space.
+
+ Parameters
+ ----------
+ unstacked_draws : xarray.Dataset
+ The draws to constrain back into the original space.
+
+ model : Model
+ The PyMC model the variables were derived from.
+
+ n_draws: int
+ The number of draws to return from the variational approximation.
+
+ include_transformed: bool
+ Whether or not to keep the unconstrained variables in the output.
+
+ progressbar: bool
+ Whether or not to show a progress bar during the transformation. Default is True.
+
+ Returns
+ -------
+ :class:`~arviz.InferenceData`
+ Draws from the original constrained parameters.
+ """
+
+ filtered_var_names = model.unobserved_value_vars
+ vars_to_sample = list(
+ get_default_varnames(filtered_var_names, include_transformed=include_transformed)
+ )
+ fn = pytensor.function(model.value_vars, vars_to_sample)
+ point_func = PointFunc(fn)
+
+ coords, dims = coords_and_dims_for_inferencedata(model)
+
+ transformed_result = apply_function_over_dataset(
+ point_func,
+ unstacked_draws,
+ output_var_names=[x.name for x in vars_to_sample],
+ coords=coords,
+ dims=dims,
+ progressbar=progressbar,
+ )
+
+ constrained_names = [
+ x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
+ ]
+ all_varnames = [
+ x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=True)
+ ]
+ unconstrained_names = sorted(set(all_varnames) - set(constrained_names))
+
+ idata = az.InferenceData(posterior=transformed_result[constrained_names])
+
+ if unconstrained_names and include_transformed:
+ idata["unconstrained_posterior"] = transformed_result[unconstrained_names]
+
+ return idata
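
The new fit_dadvi entry point added in this file can be called directly on a model. A minimal usage sketch, assuming the 0.6.0 API exactly as it appears in the diff above; the toy model and data are invented for illustration:

    import numpy as np
    import pymc as pm

    from pymc_extras.inference import fit_dadvi

    # Toy regression data, purely illustrative
    rng = np.random.default_rng(0)
    x = rng.normal(size=100)
    y = 2.0 * x + rng.normal(scale=0.5, size=100)

    with pm.Model():
        beta = pm.Normal("beta", 0, 1)
        sigma = pm.HalfNormal("sigma", 1)
        pm.Normal("obs", mu=beta * x, sigma=sigma, observed=y)

        # Optimize the fixed-draw DADVI objective, then sample from the
        # fitted mean-field approximation
        idata = fit_dadvi(n_fixed_draws=30, n_draws=1000, random_seed=1)

    print(idata.posterior["beta"].mean())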
pymc_extras/inference/fit.py
@@ -40,3 +40,8 @@ def fit(method: str, **kwargs) -> az.InferenceData:
  from pymc_extras.inference import fit_laplace
 
  return fit_laplace(**kwargs)
+
+ if method == "dadvi":
+ from pymc_extras.inference import fit_dadvi
+
+ return fit_dadvi(**kwargs)
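
The new branch in fit() exposes the same code path through the generic dispatcher. An equivalent call, sketched under the assumptions that keyword arguments are forwarded unchanged to fit_dadvi and that fit is re-exported at the package top level as in previous releases:

    import pymc_extras as pmx

    # Inside a `with pm.Model():` block this is equivalent to calling
    # fit_dadvi(...) directly; all keyword arguments are passed through.
    idata = pmx.fit(method="dadvi", n_fixed_draws=30, n_draws=1000, random_seed=1)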
pymc_extras/inference/laplace_approx/find_map.py
@@ -7,7 +7,7 @@ import numpy as np
  import pymc as pm
 
  from better_optimize import basinhopping, minimize
- from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
+ from better_optimize.constants import minimize_method
  from pymc.blocking import DictToArrayBijection, RaveledVars
  from pymc.initial_point import make_initial_point_fn
  from pymc.model.transform.optimization import freeze_dims_and_data
@@ -24,40 +24,12 @@ from pymc_extras.inference.laplace_approx.idata import (
  from pymc_extras.inference.laplace_approx.scipy_interface import (
  GradientBackend,
  scipy_optimize_funcs_from_loss,
+ set_optimizer_function_defaults,
  )
 
  _log = logging.getLogger(__name__)
 
 
- def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
- method_info = MINIMIZE_MODE_KWARGS[method].copy()
-
- if use_hess and use_hessp:
- _log.warning(
- 'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
- 'same time. When possible "use_hessp" is preferred because its is computationally more efficient. '
- 'Setting "use_hess" to False.'
- )
- use_hess = False
-
- use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-
- if use_hessp is not None and use_hess is None:
- use_hess = not use_hessp
-
- elif use_hess is not None and use_hessp is None:
- use_hessp = not use_hess
-
- elif use_hessp is None and use_hess is None:
- use_hessp = method_info["uses_hessp"]
- use_hess = method_info["uses_hess"]
- if use_hessp and use_hess:
- # If a method could use either hess or hessp, we default to using hessp
- use_hess = False
-
- return use_grad, use_hess, use_hessp
-
-
  def get_nearest_psd(A: np.ndarray) -> np.ndarray:
  """
  Compute the nearest positive semi-definite matrix to a given matrix.
@@ -196,8 +168,10 @@ def find_MAP(
  jitter_rvs: list[TensorVariable] | None = None,
  progressbar: bool = True,
  include_transformed: bool = True,
+ freeze_model: bool = True,
  gradient_backend: GradientBackend = "pytensor",
  compile_kwargs: dict | None = None,
+ compute_hessian: bool = False,
  **optimizer_kwargs,
  ) -> (
  dict[str, np.ndarray]
@@ -237,8 +211,16 @@ def find_MAP(
  Whether to display a progress bar during optimization. Defaults to True.
  include_transformed: bool, optional
  Whether to include transformed variable values in the returned dictionary. Defaults to True.
+ freeze_model: bool, optional
+ If True, freeze_dims_and_data will be called on the model before compiling the loss functions. This is
+ sometimes necessary for JAX, and can sometimes improve performance by allowing constant folding. Defaults to
+ True.
  gradient_backend: str, default "pytensor"
  Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
+ compute_hessian: bool
+ If True, the inverse Hessian matrix at the optimum will be computed and included in the returned
+ InferenceData object. This is needed for the Laplace approximation, but can be computationally expensive for
+ high-dimensional problems. Defaults to False.
  compile_kwargs: dict, optional
  Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
  **optimizer_kwargs
@@ -252,11 +234,13 @@ def find_MAP(
  Results of Maximum A Posteriori (MAP) estimation, including the optimized point, inverse Hessian, transformed
  latent variables, and optimizer results.
  """
- model = pm.modelcontext(model) if model is None else model
- frozen_model = freeze_dims_and_data(model)
  compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+ model = pm.modelcontext(model) if model is None else model
+
+ if freeze_model:
+ model = freeze_dims_and_data(model)
 
- initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)
+ initial_params = _make_initial_point(model, initvals, random_seed, jitter_rvs)
 
  do_basinhopping = method == "basinhopping"
  minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
@@ -274,8 +258,8 @@
  )
 
  f_fused, f_hessp = scipy_optimize_funcs_from_loss(
- loss=-frozen_model.logp(),
- inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
+ loss=-model.logp(),
+ inputs=model.continuous_value_vars + model.discrete_value_vars,
  initial_point_dict=DictToArrayBijection.rmap(initial_params),
  use_grad=use_grad,
  use_hess=use_hess,
@@ -316,14 +300,17 @@
  **optimizer_kwargs,
  )
 
- H_inv = _compute_inverse_hessian(
- optimizer_result=optimizer_result,
- optimal_point=None,
- f_fused=f_fused,
- f_hessp=f_hessp,
- use_hess=use_hess,
- method=method,
- )
+ if compute_hessian:
+ H_inv = _compute_inverse_hessian(
+ optimizer_result=optimizer_result,
+ optimal_point=None,
+ f_fused=f_fused,
+ f_hessp=f_hessp,
+ use_hess=use_hess,
+ method=method,
+ )
+ else:
+ H_inv = None
 
  raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
  unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed=True)
@@ -336,12 +323,10 @@
  }
 
  idata = map_results_to_inference_data(
- map_point=optimized_point, model=frozen_model, include_transformed=include_transformed
+ map_point=optimized_point, model=model, include_transformed=include_transformed
  )
 
- idata = add_fit_to_inference_data(
- idata=idata, mu=raveled_optimized, H_inv=H_inv, model=frozen_model
- )
+ idata = add_fit_to_inference_data(idata=idata, mu=raveled_optimized, H_inv=H_inv, model=model)
 
  idata = add_optimizer_result_to_inference_data(
  idata=idata, result=optimizer_result, method=method, mu=raveled_optimized, model=model
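
The find_MAP changes above make the inverse-Hessian computation opt-in (compute_hessian) and make the freeze_dims_and_data call controllable (freeze_model). A usage sketch of the new keywords, assuming the rest of the signature is unchanged from 0.4.1; the model here is invented for illustration:

    import pymc as pm

    from pymc_extras.inference import find_MAP

    with pm.Model():
        mu = pm.Normal("mu", 0, 10)
        pm.Normal("obs", mu=mu, sigma=1, observed=[1.2, 0.8, 1.1])

        # compute_hessian=True keeps the pre-0.6.0 behaviour of storing the
        # inverse Hessian in the result (needed downstream by fit_laplace);
        # freeze_model=True (the default) still applies freeze_dims_and_data.
        result = find_MAP(
            method="L-BFGS-B",
            compute_hessian=True,
            freeze_model=True,
        )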