pymc-extras 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. pymc_extras/deserialize.py +10 -4
  2. pymc_extras/distributions/continuous.py +1 -1
  3. pymc_extras/distributions/histogram_utils.py +6 -4
  4. pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
  5. pymc_extras/distributions/timeseries.py +14 -12
  6. pymc_extras/inference/dadvi/dadvi.py +149 -128
  7. pymc_extras/inference/laplace_approx/find_map.py +16 -39
  8. pymc_extras/inference/laplace_approx/idata.py +22 -4
  9. pymc_extras/inference/laplace_approx/laplace.py +196 -151
  10. pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
  11. pymc_extras/inference/pathfinder/idata.py +517 -0
  12. pymc_extras/inference/pathfinder/pathfinder.py +71 -12
  13. pymc_extras/inference/smc/sampling.py +2 -2
  14. pymc_extras/model/marginal/distributions.py +4 -2
  15. pymc_extras/model/marginal/graph_analysis.py +2 -2
  16. pymc_extras/model/marginal/marginal_model.py +12 -2
  17. pymc_extras/model_builder.py +9 -4
  18. pymc_extras/prior.py +203 -8
  19. pymc_extras/statespace/core/compile.py +1 -1
  20. pymc_extras/statespace/core/statespace.py +2 -1
  21. pymc_extras/statespace/filters/distributions.py +15 -13
  22. pymc_extras/statespace/filters/kalman_filter.py +24 -22
  23. pymc_extras/statespace/filters/kalman_smoother.py +3 -5
  24. pymc_extras/statespace/filters/utilities.py +2 -5
  25. pymc_extras/statespace/models/DFM.py +12 -27
  26. pymc_extras/statespace/models/ETS.py +190 -198
  27. pymc_extras/statespace/models/SARIMAX.py +5 -17
  28. pymc_extras/statespace/models/VARMAX.py +15 -67
  29. pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
  30. pymc_extras/statespace/models/structural/components/regression.py +4 -26
  31. pymc_extras/statespace/models/utilities.py +7 -0
  32. pymc_extras/utils/model_equivalence.py +2 -2
  33. pymc_extras/utils/prior.py +10 -14
  34. pymc_extras/utils/spline.py +4 -10
  35. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/METADATA +4 -4
  36. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/RECORD +38 -37
  37. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/WHEEL +1 -1
  38. {pymc_extras-0.5.0.dist-info → pymc_extras-0.7.0.dist-info}/licenses/LICENSE +0 -0
pymc_extras/deserialize.py
@@ -13,10 +13,7 @@ Make use of the already registered deserializers:
 
     from pymc_extras.deserialize import deserialize
 
-    prior_class_data = {
-        "dist": "Normal",
-        "kwargs": {"mu": 0, "sigma": 1}
-    }
+    prior_class_data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
     prior = deserialize(prior_class_data)
     # Prior("Normal", mu=0, sigma=1)
 
@@ -26,6 +23,7 @@ Register custom class deserialization:
 
     from pymc_extras.deserialize import register_deserialization
 
+
     class MyClass:
         def __init__(self, value: int):
             self.value = value
@@ -34,6 +32,7 @@ Register custom class deserialization:
             # Example of what the to_dict method might look like.
             return {"value": self.value}
 
+
     register_deserialization(
         is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
         deserialize=lambda data: MyClass(value=data["value"]),
@@ -80,18 +79,23 @@ class Deserializer:
 
        from typing import Any
 
+
        class MyClass:
            def __init__(self, value: int):
                self.value = value
 
+
        from pymc_extras.deserialize import Deserializer
 
+
        def is_type(data: Any) -> bool:
            return data.keys() == {"value"} and isinstance(data["value"], int)
 
+
        def deserialize(data: dict) -> MyClass:
            return MyClass(value=data["value"])
 
+
        deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize)
 
    """
@@ -196,6 +200,7 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
 
        from pymc_extras.deserialize import register_deserialization
 
+
        class MyClass:
            def __init__(self, value: int):
                self.value = value
@@ -204,6 +209,7 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
                # Example of what the to_dict method might look like.
                return {"value": self.value}
 
+
        register_deserialization(
            is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
            deserialize=lambda data: MyClass(value=data["value"]),
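The reformatted docstring examples above compose into a full round trip. A short sketch of that flow (MyClass and its {"value": ...} schema are the illustrative names from the docstrings, not part of the library API):

    from pymc_extras.deserialize import deserialize, register_deserialization


    class MyClass:
        def __init__(self, value: int):
            self.value = value

        def to_dict(self) -> dict:
            # Example of what the to_dict method might look like.
            return {"value": self.value}


    register_deserialization(
        is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
        deserialize=lambda data: MyClass(value=data["value"]),
    )

    # Built-in deserializers keep working alongside the custom one.
    prior = deserialize({"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}})  # Prior("Normal", mu=0, sigma=1)
    restored = deserialize(MyClass(value=1).to_dict())
    assert restored.value == 1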
pymc_extras/distributions/continuous.py
@@ -265,7 +265,7 @@ class Chi:
         from pymc_extras.distributions import Chi
 
         with pm.Model():
-            x = Chi('x', nu=1)
+            x = Chi("x", nu=1)
     """
 
     @staticmethod
pymc_extras/distributions/histogram_utils.py
@@ -130,8 +130,7 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
     ...     m = pm.Normal("m", dims="tests")
     ...     s = pm.LogNormal("s", dims="tests")
     ...     pot = pmx.distributions.histogram_approximation(
-    ...         "pot", pm.Normal.dist(m, s),
-    ...         observed=measurements, n_quantiles=50
+    ...         "pot", pm.Normal.dist(m, s), observed=measurements, n_quantiles=50
     ...     )
 
     For special cases like Zero Inflation in Continuous variables there is a flag.
@@ -143,8 +142,11 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
     ...     m = pm.Normal("m", dims="tests")
     ...     s = pm.LogNormal("s", dims="tests")
     ...     pot = pmx.distributions.histogram_approximation(
-    ...         "pot", pm.Normal.dist(m, s),
-    ...         observed=measurements, n_quantiles=50, zero_inflation=True
+    ...         "pot",
+    ...         pm.Normal.dist(m, s),
+    ...         observed=measurements,
+    ...         n_quantiles=50,
+    ...         zero_inflation=True,
     ...     )
     """
     try:
pymc_extras/distributions/multivariate/r2d2m2cp.py
@@ -305,6 +305,7 @@ def R2D2M2CP(
        import pymc_extras as pmx
        import pymc as pm
        import numpy as np
+
        X = np.random.randn(10, 3)
        b = np.random.randn(3)
        y = X @ b + np.random.randn(10) * 0.04 + 5
@@ -339,7 +340,7 @@ def R2D2M2CP(
                # "c" - a must have in the relation
                variables_importance=[10, 1, 34],
                # NOTE: try both
-               centered=True
+               centered=True,
            )
            # intercept prior centering should be around prior predictive mean
            intercept = y.mean()
@@ -365,7 +366,7 @@ def R2D2M2CP(
                r2_std=0.2,
                # NOTE: if you know where a variable should go
                # if you do not know, leave as 0.5
-               centered=False
+               centered=False,
            )
            # intercept prior centering should be around prior predictive mean
            intercept = y.mean()
@@ -394,7 +395,7 @@ def R2D2M2CP(
                # if you do not know, leave as 0.5
                positive_probs=[0.8, 0.5, 0.1],
                # NOTE: try both
-               centered=True
+               centered=True,
            )
            intercept = y.mean()
            obs = pm.Normal("obs", intercept + X @ beta, eps, observed=y)
pymc_extras/distributions/timeseries.py
@@ -113,8 +113,10 @@ class DiscreteMarkovChain(Distribution):
 
        with pm.Model() as markov_chain:
            P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
-           init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
-           markov_chain = pmx.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
+           init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
+           markov_chain = pmx.DiscreteMarkovChain(
+               "markov_chain", P=P, init_dist=init_dist, shape=(100,)
+           )
 
    """
 
@@ -194,21 +196,20 @@ class DiscreteMarkovChain(Distribution):
        state_rng = pytensor.shared(np.random.default_rng())
 
        def transition(*args):
-           *states, transition_probs, old_rng = args
+           old_rng, *states, transition_probs = args
            p = transition_probs[tuple(states)]
            next_rng, next_state = pm.Categorical.dist(p=p, rng=old_rng).owner.outputs
-           return next_state, {old_rng: next_rng}
+           return next_rng, next_state
 
-       markov_chain, state_updates = pytensor.scan(
+       state_next_rng, markov_chain = pytensor.scan(
            transition,
-           non_sequences=[P_, state_rng],
-           outputs_info=_make_outputs_info(n_lags, init_dist_),
+           outputs_info=[state_rng, *_make_outputs_info(n_lags, init_dist_)],
+           non_sequences=[P_],
            n_steps=steps_,
            strict=True,
+           return_updates=False,
        )
 
-       (state_next_rng,) = tuple(state_updates.values())
-
        discrete_mc_ = pt.moveaxis(pt.concatenate([init_dist_, markov_chain], axis=0), 0, -1)
 
        discrete_mc_op = DiscreteMarkovChainRV(
@@ -237,16 +238,17 @@ def discrete_mc_moment(op, rv, P, steps, init_dist, state_rng):
    n_lags = op.n_lags
 
    def greedy_transition(*args):
-       *states, transition_probs, old_rng = args
+       *states, transition_probs = args
        p = transition_probs[tuple(states)]
        return pt.argmax(p)
 
-   chain_moment, moment_updates = pytensor.scan(
+   chain_moment = pytensor.scan(
        greedy_transition,
-       non_sequences=[P, state_rng],
+       non_sequences=[P],
        outputs_info=_make_outputs_info(n_lags, init_dist),
        n_steps=steps,
        strict=True,
+       return_updates=False,
    )
    chain_moment = pt.concatenate([init_dist_moment, chain_moment])
    return chain_moment
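The scan changes above replace the updates-dictionary pattern with an explicit RNG output: the shared generator goes into `outputs_info`, the step function returns the advanced RNG first, and `return_updates=False` suppresses the old updates return value. A standalone sketch of the same pattern, assuming a pytensor release recent enough for `scan` to accept RNGs in `outputs_info` and the `return_updates` flag:

    import numpy as np
    import pymc as pm
    import pytensor
    import pytensor.tensor as pt

    P = pt.as_tensor(np.full((3, 3), 1 / 3))  # toy 3-state transition matrix
    rng = pytensor.shared(np.random.default_rng(0))
    init_state = pt.zeros((), dtype="int64")


    def transition(old_rng, state, transition_probs):
        # Return the advanced RNG as the first output so scan threads it
        # as a recurrent state instead of collecting it from updates.
        p = transition_probs[state]
        next_rng, next_state = pm.Categorical.dist(p=p, rng=old_rng).owner.outputs
        return next_rng, next_state


    next_rngs, states = pytensor.scan(
        transition,
        outputs_info=[rng, init_state],
        non_sequences=[P],
        n_steps=10,
        strict=True,
        return_updates=False,
    )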
pymc_extras/inference/dadvi/dadvi.py
@@ -3,43 +3,45 @@ import numpy as np
 import pymc
 import pytensor
 import pytensor.tensor as pt
-import xarray
 
-from better_optimize import minimize
+from arviz import InferenceData
+from better_optimize import basinhopping, minimize
 from better_optimize.constants import minimize_method
 from pymc import DictToArrayBijection, Model, join_nonshared_inputs
-from pymc.backends.arviz import (
-    PointFunc,
-    apply_function_over_dataset,
-    coords_and_dims_for_inferencedata,
-)
-from pymc.util import RandomSeed, get_default_varnames
+from pymc.blocking import RaveledVars
+from pymc.util import RandomSeed
 from pytensor.tensor.variable import TensorVariable
 
-from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
+from pymc_extras.inference.laplace_approx.idata import (
+    add_data_to_inference_data,
+    add_optimizer_result_to_inference_data,
+)
+from pymc_extras.inference.laplace_approx.laplace import draws_from_laplace_approx
 from pymc_extras.inference.laplace_approx.scipy_interface import (
-    _compile_functions_for_scipy_optimize,
+    scipy_optimize_funcs_from_loss,
+    set_optimizer_function_defaults,
 )
 
 
 def fit_dadvi(
     model: Model | None = None,
     n_fixed_draws: int = 30,
-    random_seed: RandomSeed = None,
     n_draws: int = 1000,
-    keep_untransformed: bool = False,
+    include_transformed: bool = False,
     optimizer_method: minimize_method = "trust-ncg",
-    use_grad: bool = True,
-    use_hessp: bool = True,
-    use_hess: bool = False,
-    **minimize_kwargs,
+    use_grad: bool | None = None,
+    use_hessp: bool | None = None,
+    use_hess: bool | None = None,
+    gradient_backend: str = "pytensor",
+    compile_kwargs: dict | None = None,
+    random_seed: RandomSeed = None,
+    progressbar: bool = True,
+    **optimizer_kwargs,
 ) -> az.InferenceData:
     """
-    Does inference using deterministic ADVI (automatic differentiation
-    variational inference), DADVI for short.
+    Does inference using Deterministic ADVI (Automatic Differentiation Variational Inference), DADVI for short.
 
-    For full details see the paper cited in the references:
-    https://www.jmlr.org/papers/v25/23-1015.html
+    For full details see the paper cited in the references: https://www.jmlr.org/papers/v25/23-1015.html
 
     Parameters
     ----------
@@ -47,46 +49,48 @@ def fit_dadvi(
         The PyMC model to be fit. If None, the current model context is used.
 
     n_fixed_draws : int
-        The number of fixed draws to use for the optimisation. More
-        draws will result in more accurate estimates, but also
-        increase inference time. Usually, the default of 30 is a good
-        tradeoff.between speed and accuracy.
+        The number of fixed draws to use for the optimisation. More draws will result in more accurate estimates, but
+        also increase inference time. Usually, the default of 30 is a good tradeoff between speed and accuracy.
 
     random_seed: int
-        The random seed to use for the fixed draws. Running the optimisation
-        twice with the same seed should arrive at the same result.
+        The random seed to use for the fixed draws. Running the optimisation twice with the same seed should arrive at
+        the same result.
 
     n_draws: int
         The number of draws to return from the variational approximation.
 
-    keep_untransformed: bool
-        Whether or not to keep the unconstrained variables (such as
-        logs of positive-constrained parameters) in the output.
+    include_transformed: bool
+        Whether or not to keep the unconstrained variables (such as logs of positive-constrained parameters) in the
+        output.
 
     optimizer_method: str
-        Which optimization method to use. The function calls
-        ``scipy.optimize.minimize``, so any of the methods there can
-        be used. The default is trust-ncg, which uses second-order
-        information and is generally very reliable. Other methods such
-        as L-BFGS-B might be faster but potentially more brittle and
-        may not converge exactly to the optimum.
-
-    minimize_kwargs:
-        Additional keyword arguments to pass to the
-        ``scipy.optimize.minimize`` function. See the documentation of
-        that function for details.
+        Which optimization method to use. The function calls ``scipy.optimize.minimize``, so any of the methods there
+        can be used. The default is trust-ncg, which uses second-order information and is generally very reliable.
+        Other methods such as L-BFGS-B might be faster but potentially more brittle and may not converge exactly to
+        the optimum.
+
+    gradient_backend: str
+        Which backend to use to compute gradients. Must be one of "jax" or "pytensor". Default is "pytensor".
 
-    use_grad:
-        If True, pass the gradient function to
-        `scipy.optimize.minimize` (where it is referred to as `jac`).
+    compile_kwargs: dict, optional
+        Additional keyword arguments to pass to `pytensor.function`
 
-    use_hessp:
+    use_grad: bool, optional
+        If True, pass the gradient function to `scipy.optimize.minimize` (where it is referred to as `jac`).
+
+    use_hessp: bool, optional
         If True, pass the hessian vector product to `scipy.optimize.minimize`.
 
-    use_hess:
-        If True, pass the hessian to `scipy.optimize.minimize`. Note that
-        this is generally not recommended since its computation can be slow
-        and memory-intensive if there are many parameters.
+    use_hess: bool, optional
+        If True, pass the hessian to `scipy.optimize.minimize`. Note that this is generally not recommended since its
+        computation can be slow and memory-intensive if there are many parameters.
+
+    progressbar: bool
+        Whether or not to show a progress bar during optimization. Default is True.
+
+    optimizer_kwargs:
+        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. See the documentation of
+        that function for details.
 
     Returns
     -------
@@ -95,16 +99,25 @@ def fit_dadvi(
 
     References
     ----------
-    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box
-    Variational Inference with a Deterministic Objective: Faster, More
-    Accurate, and Even More Black Box. Journal of Machine Learning
-    Research, 25(18), 1–39.
+    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational Inference with a Deterministic Objective:
+    Faster, More Accurate, and Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.
     """
 
     model = pymc.modelcontext(model) if model is None else model
+    do_basinhopping = optimizer_method == "basinhopping"
+    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+    if do_basinhopping:
+        # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+        # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+        # if one isn't provided.
+
+        optimizer_method = minimizer_kwargs.pop("method", "L-BFGS-B")
+        minimizer_kwargs["method"] = optimizer_method
 
     initial_point_dict = model.initial_point()
-    n_params = DictToArrayBijection.map(initial_point_dict).data.shape[0]
+    initial_point = DictToArrayBijection.map(initial_point_dict)
+    n_params = initial_point.data.shape[0]
 
     var_params, objective = create_dadvi_graph(
         model,
@@ -113,44 +126,100 @@ def fit_dadvi(
         n_params=n_params,
     )
 
-    f_fused, f_hessp = _compile_functions_for_scipy_optimize(
-        objective,
-        [var_params],
-        compute_grad=use_grad,
-        compute_hessp=use_hessp,
-        compute_hess=use_hess,
+    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+        optimizer_method, use_grad, use_hess, use_hessp
     )
 
-    derivative_kwargs = {}
-
-    if use_grad:
-        derivative_kwargs["jac"] = True
-    if use_hessp:
-        derivative_kwargs["hessp"] = f_hessp
-    if use_hess:
-        derivative_kwargs["hess"] = True
+    f_fused, f_hessp = scipy_optimize_funcs_from_loss(
+        loss=objective,
+        inputs=[var_params],
+        initial_point_dict=None,
+        use_grad=use_grad,
+        use_hessp=use_hessp,
+        use_hess=use_hess,
+        gradient_backend=gradient_backend,
+        compile_kwargs=compile_kwargs,
+        inputs_are_flat=True,
+    )
 
-    result = minimize(
-        f_fused,
-        np.zeros(2 * n_params),
-        method=optimizer_method,
-        **derivative_kwargs,
-        **minimize_kwargs,
+    dadvi_initial_point = {
+        f"{var_name}_mu": np.zeros_like(value).ravel()
+        for var_name, value in initial_point_dict.items()
+    }
+    dadvi_initial_point.update(
+        {
+            f"{var_name}_sigma__log": np.zeros_like(value).ravel()
+            for var_name, value in initial_point_dict.items()
+        }
     )
 
+    dadvi_initial_point = DictToArrayBijection.map(dadvi_initial_point)
+    args = optimizer_kwargs.pop("args", ())
+
+    if do_basinhopping:
+        if "args" not in minimizer_kwargs:
+            minimizer_kwargs["args"] = args
+        if "hessp" not in minimizer_kwargs:
+            minimizer_kwargs["hessp"] = f_hessp
+        if "method" not in minimizer_kwargs:
+            minimizer_kwargs["method"] = optimizer_method
+
+        result = basinhopping(
+            func=f_fused,
+            x0=dadvi_initial_point.data,
+            progressbar=progressbar,
+            minimizer_kwargs=minimizer_kwargs,
+            **optimizer_kwargs,
+        )
+
+    else:
+        result = minimize(
+            f=f_fused,
+            x0=dadvi_initial_point.data,
+            args=args,
+            method=optimizer_method,
+            hessp=f_hessp,
+            progressbar=progressbar,
+            **optimizer_kwargs,
+        )
+
+    raveled_optimized = RaveledVars(result.x, dadvi_initial_point.point_map_info)
+
     opt_var_params = result.x
     opt_means, opt_log_sds = np.split(opt_var_params, 2)
 
-    # Make the draws:
-    generator = np.random.default_rng(seed=random_seed)
-    draws_raw = generator.standard_normal(size=(n_draws, n_params))
+    posterior, unconstrained_posterior = draws_from_laplace_approx(
+        mean=opt_means,
+        standard_deviation=np.exp(opt_log_sds),
+        draws=n_draws,
+        model=model,
+        vectorize_draws=False,
+        return_unconstrained=include_transformed,
+        random_seed=random_seed,
+    )
+    idata = InferenceData(posterior=posterior)
+    if include_transformed:
+        idata.add_groups(unconstrained_posterior=unconstrained_posterior)
+
+    var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
+    var_name_to_model_var.update(
+        {f"{var_name}_sigma__log": var_name for var_name in initial_point_dict.keys()}
+    )
 
-    draws = opt_means + draws_raw * np.exp(opt_log_sds)
-    draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)
+    idata = add_optimizer_result_to_inference_data(
+        idata=idata,
+        result=result,
+        method=optimizer_method,
+        mu=raveled_optimized,
+        model=model,
+        var_name_to_model_var=var_name_to_model_var,
+    )
 
-    transformed_draws = transform_draws(draws_arviz, model, keep_untransformed=keep_untransformed)
+    idata = add_data_to_inference_data(
+        idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
+    )
 
-    return transformed_draws
+    return idata
 
 
 def create_dadvi_graph(
@@ -211,51 +280,3 @@ def create_dadvi_graph(
     objective = -mean_log_density - entropy
 
     return var_params, objective
-
-
-def transform_draws(
-    unstacked_draws: xarray.Dataset,
-    model: Model,
-    keep_untransformed: bool = False,
-):
-    """
-    Transforms the unconstrained draws back into the constrained space.
-
-    Parameters
-    ----------
-    unstacked_draws : xarray.Dataset
-        The draws to constrain back into the original space.
-
-    model : Model
-        The PyMC model the variables were derived from.
-
-    n_draws: int
-        The number of draws to return from the variational approximation.
-
-    keep_untransformed: bool
-        Whether or not to keep the unconstrained variables in the output.
-
-    Returns
-    -------
-    :class:`~arviz.InferenceData`
-        Draws from the original constrained parameters.
-    """
-
-    filtered_var_names = model.unobserved_value_vars
-    vars_to_sample = list(
-        get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
-    )
-    fn = pytensor.function(model.value_vars, vars_to_sample)
-    point_func = PointFunc(fn)
-
-    coords, dims = coords_and_dims_for_inferencedata(model)
-
-    transformed_result = apply_function_over_dataset(
-        point_func,
-        unstacked_draws,
-        output_var_names=[x.name for x in vars_to_sample],
-        coords=coords,
-        dims=dims,
-    )
-
-    return transformed_result
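Putting the new fit_dadvi signature together, a minimal call against the 0.7.0 API might look like the sketch below (the toy model is illustrative; fit_dadvi is imported from the module touched in this diff):

    import numpy as np
    import pymc as pm

    from pymc_extras.inference.dadvi.dadvi import fit_dadvi

    y = np.random.default_rng(0).normal(loc=1.0, scale=2.0, size=100)

    with pm.Model():
        mu = pm.Normal("mu", 0, 10)
        sigma = pm.HalfNormal("sigma", 5)
        pm.Normal("obs", mu=mu, sigma=sigma, observed=y)

        idata = fit_dadvi(
            n_fixed_draws=30,
            n_draws=1000,
            include_transformed=False,  # renamed from keep_untransformed
            optimizer_method="trust-ncg",
            gradient_backend="pytensor",
            random_seed=42,
            progressbar=True,
        )

    # For a global-then-local search, pass optimizer_method="basinhopping" and choose
    # the inner optimizer via minimizer_kwargs={"method": "L-BFGS-B"}.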
pymc_extras/inference/laplace_approx/find_map.py
@@ -7,7 +7,7 @@ import numpy as np
 import pymc as pm
 
 from better_optimize import basinhopping, minimize
-from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
+from better_optimize.constants import minimize_method
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.initial_point import make_initial_point_fn
 from pymc.model.transform.optimization import freeze_dims_and_data
@@ -24,40 +24,12 @@ from pymc_extras.inference.laplace_approx.idata import (
 from pymc_extras.inference.laplace_approx.scipy_interface import (
     GradientBackend,
     scipy_optimize_funcs_from_loss,
+    set_optimizer_function_defaults,
 )
 
 _log = logging.getLogger(__name__)
 
 
-def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
-    method_info = MINIMIZE_MODE_KWARGS[method].copy()
-
-    if use_hess and use_hessp:
-        _log.warning(
-            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
-            'same time. When possible "use_hessp" is preferred because its is computationally more efficient. '
-            'Setting "use_hess" to False.'
-        )
-        use_hess = False
-
-    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-
-    if use_hessp is not None and use_hess is None:
-        use_hess = not use_hessp
-
-    elif use_hess is not None and use_hessp is None:
-        use_hessp = not use_hess
-
-    elif use_hessp is None and use_hess is None:
-        use_hessp = method_info["uses_hessp"]
-        use_hess = method_info["uses_hess"]
-        if use_hessp and use_hess:
-            # If a method could use either hess or hessp, we default to using hessp
-            use_hess = False
-
-    return use_grad, use_hess, use_hessp
-
-
 def get_nearest_psd(A: np.ndarray) -> np.ndarray:
     """
     Compute the nearest positive semi-definite matrix to a given matrix.
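The removed helper now lives in scipy_interface (note the new import above) and is shared with fit_dadvi. A small sketch of how its defaults resolve, following the logic just deleted here; the commented results are what that logic implies for trust-ncg (which supports grad, hess, and hessp), not verified output:

    from pymc_extras.inference.laplace_approx.scipy_interface import (
        set_optimizer_function_defaults,
    )

    # All None: the method's capabilities decide, and hessp is preferred over hess
    # when a method can use either.
    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
        "trust-ncg", None, None, None
    )
    print(use_grad, use_hess, use_hessp)  # expected: True False True

    # Explicitly requesting the full Hessian turns the hessp default off.
    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
        "trust-ncg", True, True, None
    )
    print(use_grad, use_hess, use_hessp)  # expected: True True False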
@@ -196,6 +168,7 @@ def find_MAP(
     jitter_rvs: list[TensorVariable] | None = None,
     progressbar: bool = True,
     include_transformed: bool = True,
+    freeze_model: bool = True,
     gradient_backend: GradientBackend = "pytensor",
     compile_kwargs: dict | None = None,
     compute_hessian: bool = False,
@@ -238,6 +211,10 @@ def find_MAP(
         Whether to display a progress bar during optimization. Defaults to True.
     include_transformed: bool, optional
         Whether to include transformed variable values in the returned dictionary. Defaults to True.
+    freeze_model: bool, optional
+        If True, freeze_dims_and_data will be called on the model before compiling the loss functions. This is
+        sometimes necessary for JAX, and can sometimes improve performance by allowing constant folding. Defaults to
+        True.
     gradient_backend: str, default "pytensor"
         Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
     compute_hessian: bool
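A short usage sketch of the new flag (toy model for illustration; the other keywords follow the existing find_MAP signature):

    import numpy as np
    import pymc as pm

    from pymc_extras.inference.laplace_approx.find_map import find_MAP

    y = np.random.default_rng(0).normal(loc=0.5, scale=1.5, size=200)

    with pm.Model():
        mu = pm.Normal("mu", 0, 5)
        sigma = pm.HalfNormal("sigma", 2)
        pm.Normal("obs", mu=mu, sigma=sigma, observed=y)

        idata = find_MAP(
            method="L-BFGS-B",
            # freeze_model=True (the default) calls freeze_dims_and_data before
            # compiling; set it to False to keep mutable dims and data mutable.
            freeze_model=True,
            gradient_backend="pytensor",
            include_transformed=True,
            progressbar=False,
        )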
@@ -257,11 +234,13 @@ def find_MAP(
         Results of Maximum A Posteriori (MAP) estimation, including the optimized point, inverse Hessian, transformed
         latent variables, and optimizer results.
     """
-    model = pm.modelcontext(model) if model is None else model
-    frozen_model = freeze_dims_and_data(model)
     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+    model = pm.modelcontext(model) if model is None else model
+
+    if freeze_model:
+        model = freeze_dims_and_data(model)
 
-    initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)
+    initial_params = _make_initial_point(model, initvals, random_seed, jitter_rvs)
 
     do_basinhopping = method == "basinhopping"
     minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
@@ -279,8 +258,8 @@ def find_MAP(
     )
 
     f_fused, f_hessp = scipy_optimize_funcs_from_loss(
-        loss=-frozen_model.logp(),
-        inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
+        loss=-model.logp(),
+        inputs=model.continuous_value_vars + model.discrete_value_vars,
         initial_point_dict=DictToArrayBijection.rmap(initial_params),
         use_grad=use_grad,
         use_hess=use_hess,
@@ -344,12 +323,10 @@ def find_MAP(
     }
 
     idata = map_results_to_inference_data(
-        map_point=optimized_point, model=frozen_model, include_transformed=include_transformed
+        map_point=optimized_point, model=model, include_transformed=include_transformed
     )
 
-    idata = add_fit_to_inference_data(
-        idata=idata, mu=raveled_optimized, H_inv=H_inv, model=frozen_model
-    )
+    idata = add_fit_to_inference_data(idata=idata, mu=raveled_optimized, H_inv=H_inv, model=model)
 
     idata = add_optimizer_result_to_inference_data(
         idata=idata, result=optimizer_result, method=method, mu=raveled_optimized, model=model