pymc-extras 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. pymc_extras/deserialize.py +10 -4
  2. pymc_extras/distributions/continuous.py +1 -1
  3. pymc_extras/distributions/histogram_utils.py +6 -4
  4. pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
  5. pymc_extras/distributions/timeseries.py +4 -2
  6. pymc_extras/inference/dadvi/dadvi.py +162 -72
  7. pymc_extras/inference/laplace_approx/find_map.py +16 -39
  8. pymc_extras/inference/laplace_approx/idata.py +22 -4
  9. pymc_extras/inference/laplace_approx/laplace.py +23 -6
  10. pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
  11. pymc_extras/inference/pathfinder/idata.py +517 -0
  12. pymc_extras/inference/pathfinder/pathfinder.py +61 -7
  13. pymc_extras/model/marginal/graph_analysis.py +2 -2
  14. pymc_extras/model_builder.py +9 -4
  15. pymc_extras/prior.py +203 -8
  16. pymc_extras/statespace/core/compile.py +1 -1
  17. pymc_extras/statespace/filters/kalman_filter.py +12 -11
  18. pymc_extras/statespace/filters/kalman_smoother.py +1 -3
  19. pymc_extras/statespace/filters/utilities.py +2 -5
  20. pymc_extras/statespace/models/DFM.py +12 -27
  21. pymc_extras/statespace/models/ETS.py +190 -198
  22. pymc_extras/statespace/models/SARIMAX.py +5 -17
  23. pymc_extras/statespace/models/VARMAX.py +15 -67
  24. pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
  25. pymc_extras/statespace/models/structural/components/regression.py +4 -26
  26. pymc_extras/statespace/models/utilities.py +7 -0
  27. pymc_extras/utils/model_equivalence.py +2 -2
  28. pymc_extras/utils/prior.py +10 -14
  29. pymc_extras/utils/spline.py +4 -10
  30. {pymc_extras-0.5.0.dist-info → pymc_extras-0.6.0.dist-info}/METADATA +3 -3
  31. {pymc_extras-0.5.0.dist-info → pymc_extras-0.6.0.dist-info}/RECORD +33 -32
  32. {pymc_extras-0.5.0.dist-info → pymc_extras-0.6.0.dist-info}/WHEEL +1 -1
  33. {pymc_extras-0.5.0.dist-info → pymc_extras-0.6.0.dist-info}/licenses/LICENSE +0 -0
pymc_extras/deserialize.py
@@ -13,10 +13,7 @@ Make use of the already registered deserializers:
 
     from pymc_extras.deserialize import deserialize
 
-    prior_class_data = {
-        "dist": "Normal",
-        "kwargs": {"mu": 0, "sigma": 1}
-    }
+    prior_class_data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
     prior = deserialize(prior_class_data)
     # Prior("Normal", mu=0, sigma=1)
 
@@ -26,6 +23,7 @@ Register custom class deserialization:
 
     from pymc_extras.deserialize import register_deserialization
 
+
     class MyClass:
         def __init__(self, value: int):
             self.value = value
@@ -34,6 +32,7 @@ Register custom class deserialization:
             # Example of what the to_dict method might look like.
             return {"value": self.value}
 
+
     register_deserialization(
         is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
         deserialize=lambda data: MyClass(value=data["value"]),
@@ -80,18 +79,23 @@ class Deserializer:
 
         from typing import Any
 
+
         class MyClass:
             def __init__(self, value: int):
                 self.value = value
 
+
         from pymc_extras.deserialize import Deserializer
 
+
        def is_type(data: Any) -> bool:
            return data.keys() == {"value"} and isinstance(data["value"], int)
 
+
        def deserialize(data: dict) -> MyClass:
            return MyClass(value=data["value"])
 
+
        deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize)
 
    """
@@ -196,6 +200,7 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
 
    from pymc_extras.deserialize import register_deserialization
 
+
    class MyClass:
        def __init__(self, value: int):
            self.value = value
@@ -204,6 +209,7 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
            # Example of what the to_dict method might look like.
            return {"value": self.value}
 
+
    register_deserialization(
        is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
        deserialize=lambda data: MyClass(value=data["value"]),
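Taken together, the two snippets above describe a round trip: register a custom rule once, and the generic deserialize entry point can then rebuild the object from its dict form. A minimal sketch assembled only from the pieces shown in this diff (MyClass is the illustrative class from the docstring, not part of the package):

    from pymc_extras.deserialize import deserialize, register_deserialization

    class MyClass:
        def __init__(self, value: int):
            self.value = value

    register_deserialization(
        is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
        deserialize=lambda data: MyClass(value=data["value"]),
    )

    obj = deserialize({"value": 3})  # returns MyClass(value=3)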
pymc_extras/distributions/continuous.py
@@ -265,7 +265,7 @@ class Chi:
        from pymc_extras.distributions import Chi
 
        with pm.Model():
-           x = Chi('x', nu=1)
+           x = Chi("x", nu=1)
    """
 
    @staticmethod
pymc_extras/distributions/histogram_utils.py
@@ -130,8 +130,7 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
    ...     m = pm.Normal("m", dims="tests")
    ...     s = pm.LogNormal("s", dims="tests")
    ...     pot = pmx.distributions.histogram_approximation(
-   ...         "pot", pm.Normal.dist(m, s),
-   ...         observed=measurements, n_quantiles=50
+   ...         "pot", pm.Normal.dist(m, s), observed=measurements, n_quantiles=50
    ...     )
 
    For special cases like Zero Inflation in Continuous variables there is a flag.
@@ -143,8 +142,11 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
    ...     m = pm.Normal("m", dims="tests")
    ...     s = pm.LogNormal("s", dims="tests")
    ...     pot = pmx.distributions.histogram_approximation(
-   ...         "pot", pm.Normal.dist(m, s),
-   ...         observed=measurements, n_quantiles=50, zero_inflation=True
+   ...         "pot",
+   ...         pm.Normal.dist(m, s),
+   ...         observed=measurements,
+   ...         n_quantiles=50,
+   ...         zero_inflation=True,
    ...     )
    """
    try:
pymc_extras/distributions/multivariate/r2d2m2cp.py
@@ -305,6 +305,7 @@ def R2D2M2CP(
        import pymc_extras as pmx
        import pymc as pm
        import numpy as np
+
        X = np.random.randn(10, 3)
        b = np.random.randn(3)
        y = X @ b + np.random.randn(10) * 0.04 + 5
@@ -339,7 +340,7 @@ def R2D2M2CP(
            # "c" - a must have in the relation
            variables_importance=[10, 1, 34],
            # NOTE: try both
-           centered=True
+           centered=True,
        )
        # intercept prior centering should be around prior predictive mean
        intercept = y.mean()
@@ -365,7 +366,7 @@ def R2D2M2CP(
            r2_std=0.2,
            # NOTE: if you know where a variable should go
            # if you do not know, leave as 0.5
-           centered=False
+           centered=False,
        )
        # intercept prior centering should be around prior predictive mean
        intercept = y.mean()
@@ -394,7 +395,7 @@ def R2D2M2CP(
            # if you do not know, leave as 0.5
            positive_probs=[0.8, 0.5, 0.1],
            # NOTE: try both
-           centered=True
+           centered=True,
        )
        intercept = y.mean()
        obs = pm.Normal("obs", intercept + X @ beta, eps, observed=y)
pymc_extras/distributions/timeseries.py
@@ -113,8 +113,10 @@ class DiscreteMarkovChain(Distribution):
 
        with pm.Model() as markov_chain:
            P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
-           init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
-           markov_chain = pmx.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
+           init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
+           markov_chain = pmx.DiscreteMarkovChain(
+               "markov_chain", P=P, init_dist=init_dist, shape=(100,)
+           )
 
    """
 
pymc_extras/inference/dadvi/dadvi.py
@@ -5,7 +5,7 @@ import pytensor
 import pytensor.tensor as pt
 import xarray
 
-from better_optimize import minimize
+from better_optimize import basinhopping, minimize
 from better_optimize.constants import minimize_method
 from pymc import DictToArrayBijection, Model, join_nonshared_inputs
 from pymc.backends.arviz import (
@@ -13,33 +13,40 @@ from pymc.backends.arviz import (
     apply_function_over_dataset,
     coords_and_dims_for_inferencedata,
 )
+from pymc.blocking import RaveledVars
 from pymc.util import RandomSeed, get_default_varnames
 from pytensor.tensor.variable import TensorVariable
 
+from pymc_extras.inference.laplace_approx.idata import (
+    add_data_to_inference_data,
+    add_optimizer_result_to_inference_data,
+)
 from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
 from pymc_extras.inference.laplace_approx.scipy_interface import (
-    _compile_functions_for_scipy_optimize,
+    scipy_optimize_funcs_from_loss,
+    set_optimizer_function_defaults,
 )
 
 
 def fit_dadvi(
     model: Model | None = None,
     n_fixed_draws: int = 30,
-    random_seed: RandomSeed = None,
     n_draws: int = 1000,
-    keep_untransformed: bool = False,
+    include_transformed: bool = False,
     optimizer_method: minimize_method = "trust-ncg",
-    use_grad: bool = True,
-    use_hessp: bool = True,
-    use_hess: bool = False,
-    **minimize_kwargs,
+    use_grad: bool | None = None,
+    use_hessp: bool | None = None,
+    use_hess: bool | None = None,
+    gradient_backend: str = "pytensor",
+    compile_kwargs: dict | None = None,
+    random_seed: RandomSeed = None,
+    progressbar: bool = True,
+    **optimizer_kwargs,
 ) -> az.InferenceData:
     """
-    Does inference using deterministic ADVI (automatic differentiation
-    variational inference), DADVI for short.
+    Does inference using Deterministic ADVI (Automatic Differentiation Variational Inference), DADVI for short.
 
-    For full details see the paper cited in the references:
-    https://www.jmlr.org/papers/v25/23-1015.html
+    For full details see the paper cited in the references: https://www.jmlr.org/papers/v25/23-1015.html
 
     Parameters
     ----------
@@ -47,46 +54,48 @@ def fit_dadvi(
        The PyMC model to be fit. If None, the current model context is used.
 
    n_fixed_draws : int
-       The number of fixed draws to use for the optimisation. More
-       draws will result in more accurate estimates, but also
-       increase inference time. Usually, the default of 30 is a good
-       tradeoff.between speed and accuracy.
+       The number of fixed draws to use for the optimisation. More draws will result in more accurate estimates, but
+       also increase inference time. Usually, the default of 30 is a good tradeoff between speed and accuracy.
 
    random_seed: int
-       The random seed to use for the fixed draws. Running the optimisation
-       twice with the same seed should arrive at the same result.
+       The random seed to use for the fixed draws. Running the optimisation twice with the same seed should arrive at
+       the same result.
 
    n_draws: int
        The number of draws to return from the variational approximation.
 
-   keep_untransformed: bool
-       Whether or not to keep the unconstrained variables (such as
-       logs of positive-constrained parameters) in the output.
+   include_transformed: bool
+       Whether or not to keep the unconstrained variables (such as logs of positive-constrained parameters) in the
+       output.
 
    optimizer_method: str
-       Which optimization method to use. The function calls
-       ``scipy.optimize.minimize``, so any of the methods there can
-       be used. The default is trust-ncg, which uses second-order
-       information and is generally very reliable. Other methods such
-       as L-BFGS-B might be faster but potentially more brittle and
-       may not converge exactly to the optimum.
-
-   minimize_kwargs:
-       Additional keyword arguments to pass to the
-       ``scipy.optimize.minimize`` function. See the documentation of
-       that function for details.
+       Which optimization method to use. The function calls ``scipy.optimize.minimize``, so any of the methods there
+       can be used. The default is trust-ncg, which uses second-order information and is generally very reliable.
+       Other methods such as L-BFGS-B might be faster but potentially more brittle and may not converge exactly to
+       the optimum.
 
-   use_grad:
-       If True, pass the gradient function to
-       `scipy.optimize.minimize` (where it is referred to as `jac`).
+   gradient_backend: str
+       Which backend to use to compute gradients. Must be one of "jax" or "pytensor". Default is "pytensor".
 
-   use_hessp:
+   compile_kwargs: dict, optional
+       Additional keyword arguments to pass to `pytensor.function`
+
+   use_grad: bool, optional
+       If True, pass the gradient function to `scipy.optimize.minimize` (where it is referred to as `jac`).
+
+   use_hessp: bool, optional
        If True, pass the hessian vector product to `scipy.optimize.minimize`.
 
-   use_hess:
-       If True, pass the hessian to `scipy.optimize.minimize`. Note that
-       this is generally not recommended since its computation can be slow
-       and memory-intensive if there are many parameters.
+   use_hess: bool, optional
+       If True, pass the hessian to `scipy.optimize.minimize`. Note that this is generally not recommended since its
+       computation can be slow and memory-intensive if there are many parameters.
+
+   progressbar: bool
+       Whether or not to show a progress bar during optimization. Default is True.
+
+   optimizer_kwargs:
+       Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. See the documentation of
+       that function for details.
 
    Returns
    -------
@@ -95,16 +104,25 @@ def fit_dadvi(
 
    References
    ----------
-   Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box
-   Variational Inference with a Deterministic Objective: Faster, More
-   Accurate, and Even More Black Box. Journal of Machine Learning
-   Research, 25(18), 1–39.
+   Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational Inference with a Deterministic Objective:
+   Faster, More Accurate, and Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.
    """
 
    model = pymc.modelcontext(model) if model is None else model
+   do_basinhopping = optimizer_method == "basinhopping"
+   minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+   if do_basinhopping:
+       # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+       # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+       # if one isn't provided.
+
+       optimizer_method = minimizer_kwargs.pop("method", "L-BFGS-B")
+       minimizer_kwargs["method"] = optimizer_method
 
    initial_point_dict = model.initial_point()
-   n_params = DictToArrayBijection.map(initial_point_dict).data.shape[0]
+   initial_point = DictToArrayBijection.map(initial_point_dict)
+   n_params = initial_point.data.shape[0]
 
    var_params, objective = create_dadvi_graph(
        model,
@@ -113,31 +131,65 @@
        n_params=n_params,
    )
 
-   f_fused, f_hessp = _compile_functions_for_scipy_optimize(
-       objective,
-       [var_params],
-       compute_grad=use_grad,
-       compute_hessp=use_hessp,
-       compute_hess=use_hess,
+   use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+       optimizer_method, use_grad, use_hess, use_hessp
    )
 
-   derivative_kwargs = {}
-
-   if use_grad:
-       derivative_kwargs["jac"] = True
-   if use_hessp:
-       derivative_kwargs["hessp"] = f_hessp
-   if use_hess:
-       derivative_kwargs["hess"] = True
+   f_fused, f_hessp = scipy_optimize_funcs_from_loss(
+       loss=objective,
+       inputs=[var_params],
+       initial_point_dict=None,
+       use_grad=use_grad,
+       use_hessp=use_hessp,
+       use_hess=use_hess,
+       gradient_backend=gradient_backend,
+       compile_kwargs=compile_kwargs,
+       inputs_are_flat=True,
+   )
 
-   result = minimize(
-       f_fused,
-       np.zeros(2 * n_params),
-       method=optimizer_method,
-       **derivative_kwargs,
-       **minimize_kwargs,
+   dadvi_initial_point = {
+       f"{var_name}_mu": np.zeros_like(value).ravel()
+       for var_name, value in initial_point_dict.items()
+   }
+   dadvi_initial_point.update(
+       {
+           f"{var_name}_sigma__log": np.zeros_like(value).ravel()
+           for var_name, value in initial_point_dict.items()
+       }
    )
 
+   dadvi_initial_point = DictToArrayBijection.map(dadvi_initial_point)
+   args = optimizer_kwargs.pop("args", ())
+
+   if do_basinhopping:
+       if "args" not in minimizer_kwargs:
+           minimizer_kwargs["args"] = args
+       if "hessp" not in minimizer_kwargs:
+           minimizer_kwargs["hessp"] = f_hessp
+       if "method" not in minimizer_kwargs:
+           minimizer_kwargs["method"] = optimizer_method
+
+       result = basinhopping(
+           func=f_fused,
+           x0=dadvi_initial_point.data,
+           progressbar=progressbar,
+           minimizer_kwargs=minimizer_kwargs,
+           **optimizer_kwargs,
+       )
+
+   else:
+       result = minimize(
+           f=f_fused,
+           x0=dadvi_initial_point.data,
+           args=args,
+           method=optimizer_method,
+           hessp=f_hessp,
+           progressbar=progressbar,
+           **optimizer_kwargs,
+       )
+
+   raveled_optimized = RaveledVars(result.x, dadvi_initial_point.point_map_info)
+
    opt_var_params = result.x
    opt_means, opt_log_sds = np.split(opt_var_params, 2)
 
@@ -148,9 +200,29 @@ def fit_dadvi(
    draws = opt_means + draws_raw * np.exp(opt_log_sds)
    draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)
 
-   transformed_draws = transform_draws(draws_arviz, model, keep_untransformed=keep_untransformed)
+   idata = dadvi_result_to_idata(
+       draws_arviz, model, include_transformed=include_transformed, progressbar=progressbar
+   )
 
-   return transformed_draws
+   var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
+   var_name_to_model_var.update(
+       {f"{var_name}_sigma__log": var_name for var_name in initial_point_dict.keys()}
+   )
+
+   idata = add_optimizer_result_to_inference_data(
+       idata=idata,
+       result=result,
+       method=optimizer_method,
+       mu=raveled_optimized,
+       model=model,
+       var_name_to_model_var=var_name_to_model_var,
+   )
+
+   idata = add_data_to_inference_data(
+       idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
+   )
+
+   return idata
 
 
 def create_dadvi_graph(
@@ -213,10 +285,11 @@ def create_dadvi_graph(
    return var_params, objective
 
 
-def transform_draws(
+def dadvi_result_to_idata(
    unstacked_draws: xarray.Dataset,
    model: Model,
-   keep_untransformed: bool = False,
+   include_transformed: bool = False,
+   progressbar: bool = True,
 ):
    """
    Transforms the unconstrained draws back into the constrained space.
@@ -232,9 +305,12 @@ def transform_draws(
    n_draws: int
        The number of draws to return from the variational approximation.
 
-   keep_untransformed: bool
+   include_transformed: bool
        Whether or not to keep the unconstrained variables in the output.
 
+   progressbar: bool
+       Whether or not to show a progress bar during the transformation. Default is True.
+
    Returns
    -------
    :class:`~arviz.InferenceData`
@@ -243,7 +319,7 @@
 
    filtered_var_names = model.unobserved_value_vars
    vars_to_sample = list(
-       get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
+       get_default_varnames(filtered_var_names, include_transformed=include_transformed)
    )
    fn = pytensor.function(model.value_vars, vars_to_sample)
    point_func = PointFunc(fn)
@@ -256,6 +332,20 @@
        output_var_names=[x.name for x in vars_to_sample],
        coords=coords,
        dims=dims,
+       progressbar=progressbar,
    )
 
-   return transformed_result
+   constrained_names = [
+       x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
+   ]
+   all_varnames = [
+       x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=True)
+   ]
+   unconstrained_names = sorted(set(all_varnames) - set(constrained_names))
+
+   idata = az.InferenceData(posterior=transformed_result[constrained_names])
+
+   if unconstrained_names and include_transformed:
+       idata["unconstrained_posterior"] = transformed_result[unconstrained_names]
+
+   return idata
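Taken together, the dadvi.py changes above give fit_dadvi a richer interface: keep_untransformed is renamed to include_transformed; gradient_backend, compile_kwargs and progressbar are new; extra keyword arguments flow through **optimizer_kwargs; and optimizer_method="basinhopping" routes the fit through better_optimize's basinhopping with an inner optimizer configured via minimizer_kwargs. A minimal usage sketch based only on the signature shown in this diff; the import path follows the file list above, and a shorter public alias may also exist:

    import numpy as np
    import pymc as pm

    from pymc_extras.inference.dadvi.dadvi import fit_dadvi

    with pm.Model() as model:
        mu = pm.Normal("mu", 0, 1)
        sigma = pm.HalfNormal("sigma", 1)
        pm.Normal("y", mu, sigma, observed=np.random.randn(100))

    # Default path: scipy.optimize.minimize with the trust-ncg method.
    idata = fit_dadvi(model, n_fixed_draws=30, n_draws=1000, random_seed=1)

    # New in this release: global optimisation via basinhopping, with the
    # inner optimizer chosen through minimizer_kwargs.
    idata_bh = fit_dadvi(
        model,
        optimizer_method="basinhopping",
        minimizer_kwargs={"method": "L-BFGS-B"},
        include_transformed=True,
    )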
pymc_extras/inference/laplace_approx/find_map.py
@@ -7,7 +7,7 @@ import numpy as np
 import pymc as pm
 
 from better_optimize import basinhopping, minimize
-from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
+from better_optimize.constants import minimize_method
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.initial_point import make_initial_point_fn
 from pymc.model.transform.optimization import freeze_dims_and_data
@@ -24,40 +24,12 @@ from pymc_extras.inference.laplace_approx.idata import (
 from pymc_extras.inference.laplace_approx.scipy_interface import (
     GradientBackend,
     scipy_optimize_funcs_from_loss,
+    set_optimizer_function_defaults,
 )
 
 _log = logging.getLogger(__name__)
 
 
-def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
-    method_info = MINIMIZE_MODE_KWARGS[method].copy()
-
-    if use_hess and use_hessp:
-        _log.warning(
-            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
-            'same time. When possible "use_hessp" is preferred because its is computationally more efficient. '
-            'Setting "use_hess" to False.'
-        )
-        use_hess = False
-
-    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-
-    if use_hessp is not None and use_hess is None:
-        use_hess = not use_hessp
-
-    elif use_hess is not None and use_hessp is None:
-        use_hessp = not use_hess
-
-    elif use_hessp is None and use_hess is None:
-        use_hessp = method_info["uses_hessp"]
-        use_hess = method_info["uses_hess"]
-        if use_hessp and use_hess:
-            # If a method could use either hess or hessp, we default to using hessp
-            use_hess = False
-
-    return use_grad, use_hess, use_hessp
-
-
 def get_nearest_psd(A: np.ndarray) -> np.ndarray:
     """
     Compute the nearest positive semi-definite matrix to a given matrix.
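The helper removed here is not deleted from the package: per the import hunk above, set_optimizer_function_defaults now lives in scipy_interface.py, and both find_map.py and dadvi.py import it from there. A hedged sketch of the equivalent call after this change, using only the signature visible in the removed code:

    from pymc_extras.inference.laplace_approx.scipy_interface import (
        set_optimizer_function_defaults,
    )

    # Decide which derivative functions to hand to scipy.optimize.minimize,
    # letting unspecified flags fall back to the chosen method's capabilities.
    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
        "trust-ncg", use_grad=None, use_hess=None, use_hessp=None
    )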
@@ -196,6 +168,7 @@ def find_MAP(
    jitter_rvs: list[TensorVariable] | None = None,
    progressbar: bool = True,
    include_transformed: bool = True,
+   freeze_model: bool = True,
    gradient_backend: GradientBackend = "pytensor",
    compile_kwargs: dict | None = None,
    compute_hessian: bool = False,
@@ -238,6 +211,10 @@
        Whether to display a progress bar during optimization. Defaults to True.
    include_transformed: bool, optional
        Whether to include transformed variable values in the returned dictionary. Defaults to True.
+   freeze_model: bool, optional
+       If True, freeze_dims_and_data will be called on the model before compiling the loss functions. This is
+       sometimes necessary for JAX, and can sometimes improve performance by allowing constant folding. Defaults to
+       True.
    gradient_backend: str, default "pytensor"
        Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
    compute_hessian: bool
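For orientation, a usage sketch of the new freeze_model flag. This is an illustrative example rather than code from the diff; the import path follows the file list above, and a shorter public alias may exist:

    import numpy as np
    import pymc as pm

    from pymc_extras.inference.laplace_approx.find_map import find_MAP

    with pm.Model() as model:
        mu = pm.Normal("mu", 0, 1)
        pm.Normal("y", mu, 1, observed=np.random.randn(50))

    # freeze_model=True (the default) runs freeze_dims_and_data before compiling
    # the loss functions; pass False to keep dims and data mutable.
    idata = find_MAP(
        model=model,
        method="L-BFGS-B",
        freeze_model=True,
        gradient_backend="pytensor",
        include_transformed=True,
    )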
@@ -257,11 +234,13 @@
        Results of Maximum A Posteriori (MAP) estimation, including the optimized point, inverse Hessian, transformed
        latent variables, and optimizer results.
    """
-   model = pm.modelcontext(model) if model is None else model
-   frozen_model = freeze_dims_and_data(model)
    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+   model = pm.modelcontext(model) if model is None else model
+
+   if freeze_model:
+       model = freeze_dims_and_data(model)
 
-   initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)
+   initial_params = _make_initial_point(model, initvals, random_seed, jitter_rvs)
 
    do_basinhopping = method == "basinhopping"
    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
@@ -279,8 +258,8 @@
    )
 
    f_fused, f_hessp = scipy_optimize_funcs_from_loss(
-       loss=-frozen_model.logp(),
-       inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
+       loss=-model.logp(),
+       inputs=model.continuous_value_vars + model.discrete_value_vars,
        initial_point_dict=DictToArrayBijection.rmap(initial_params),
        use_grad=use_grad,
        use_hess=use_hess,
@@ -344,12 +323,10 @@
    }
 
    idata = map_results_to_inference_data(
-       map_point=optimized_point, model=frozen_model, include_transformed=include_transformed
+       map_point=optimized_point, model=model, include_transformed=include_transformed
    )
 
-   idata = add_fit_to_inference_data(
-       idata=idata, mu=raveled_optimized, H_inv=H_inv, model=frozen_model
-   )
+   idata = add_fit_to_inference_data(idata=idata, mu=raveled_optimized, H_inv=H_inv, model=model)
 
    idata = add_optimizer_result_to_inference_data(
        idata=idata, result=optimizer_result, method=method, mu=raveled_optimized, model=model