pymc-extras 0.2.7__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. pymc_extras/inference/__init__.py +2 -2
  2. pymc_extras/inference/fit.py +1 -1
  3. pymc_extras/inference/laplace_approx/__init__.py +0 -0
  4. pymc_extras/inference/laplace_approx/find_map.py +354 -0
  5. pymc_extras/inference/laplace_approx/idata.py +393 -0
  6. pymc_extras/inference/laplace_approx/laplace.py +453 -0
  7. pymc_extras/inference/laplace_approx/scipy_interface.py +242 -0
  8. pymc_extras/inference/pathfinder/pathfinder.py +3 -4
  9. pymc_extras/linearmodel.py +3 -1
  10. pymc_extras/model/marginal/graph_analysis.py +4 -0
  11. pymc_extras/prior.py +38 -6
  12. pymc_extras/statespace/core/statespace.py +78 -52
  13. pymc_extras/statespace/filters/kalman_smoother.py +1 -1
  14. pymc_extras/statespace/models/structural/__init__.py +21 -0
  15. pymc_extras/statespace/models/structural/components/__init__.py +0 -0
  16. pymc_extras/statespace/models/structural/components/autoregressive.py +188 -0
  17. pymc_extras/statespace/models/structural/components/cycle.py +305 -0
  18. pymc_extras/statespace/models/structural/components/level_trend.py +257 -0
  19. pymc_extras/statespace/models/structural/components/measurement_error.py +137 -0
  20. pymc_extras/statespace/models/structural/components/regression.py +228 -0
  21. pymc_extras/statespace/models/structural/components/seasonality.py +445 -0
  22. pymc_extras/statespace/models/structural/core.py +900 -0
  23. pymc_extras/statespace/models/structural/utils.py +16 -0
  24. pymc_extras/statespace/models/utilities.py +285 -0
  25. pymc_extras/statespace/utils/constants.py +4 -4
  26. pymc_extras/statespace/utils/data_tools.py +3 -2
  27. {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/METADATA +6 -6
  28. {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/RECORD +30 -18
  29. pymc_extras/inference/find_map.py +0 -496
  30. pymc_extras/inference/laplace.py +0 -583
  31. pymc_extras/statespace/models/structural.py +0 -1679
  32. {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/WHEEL +0 -0
  33. {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/licenses/LICENSE +0 -0
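Most of the churn in this release is a reorganization rather than new behavior: find_map.py and laplace.py move out of pymc_extras/inference/ into a new laplace_approx/ subpackage, and the single statespace/models/structural.py module is split into a structural/ package with one file per component. A minimal sketch of how downstream code is likely affected, assuming the packages continue to re-export their public names from the new __init__.py files (as the small +2/-2 and +21/-0 changes to those files suggest):

import pymc as pm
from pymc_extras.inference import fit_laplace  # public entry point, now backed by inference/laplace_approx/laplace.py
# Code that previously deep-imported the removed module paths would need the new locations, e.g.:
# from pymc_extras.inference.laplace_approx.laplace import fit_laplace

with pm.Model() as m:
    x = pm.Normal("x")
    idata = fit_laplace(model=m)

Code that only used the top-level re-exports should keep working; only deep imports of the removed modules (inference/find_map.py, inference/laplace.py, statespace/models/structural.py) need updating.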
pymc_extras/inference/laplace_approx/laplace.py (new file, +453)
@@ -0,0 +1,453 @@
+ # Copyright 2024 The PyMC Developers
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ import logging
+
+ from collections.abc import Callable
+ from functools import partial
+ from typing import Literal
+ from typing import cast as type_cast
+
+ import arviz as az
+ import numpy as np
+ import pymc as pm
+ import pytensor
+ import pytensor.tensor as pt
+ import xarray as xr
+
+ from better_optimize.constants import minimize_method
+ from numpy.typing import ArrayLike
+ from pymc.blocking import DictToArrayBijection
+ from pymc.model.transform.optimization import freeze_dims_and_data
+ from pymc.pytensorf import join_nonshared_inputs
+ from pymc.util import get_default_varnames
+ from pytensor.graph import vectorize_graph
+ from pytensor.tensor import TensorVariable
+ from pytensor.tensor.optimize import minimize
+ from pytensor.tensor.type import Variable
+
+ from pymc_extras.inference.laplace_approx.find_map import (
+     _compute_inverse_hessian,
+     _make_initial_point,
+     find_MAP,
+ )
+ from pymc_extras.inference.laplace_approx.scipy_interface import (
+     GradientBackend,
+     scipy_optimize_funcs_from_loss,
+ )
+
+ _log = logging.getLogger(__name__)
+
+
+ def get_conditional_gaussian_approximation(
+     x: TensorVariable,
+     Q: TensorVariable | ArrayLike,
+     mu: TensorVariable | ArrayLike,
+     args: list[TensorVariable] | None = None,
+     model: pm.Model | None = None,
+     method: minimize_method = "BFGS",
+     use_jac: bool = True,
+     use_hess: bool = False,
+     optimizer_kwargs: dict | None = None,
+ ) -> Callable:
+     """
+     Returns a function to estimate the a posteriori log probability of a latent Gaussian field x and its mode x0 using the Laplace approximation.
+
+     That is:
+     y | x, sigma ~ N(Ax, sigma^2 W)
+     x | params ~ N(mu, Q(params)^-1)
+
+     We seek to estimate log(p(x | y, params)):
+
+     log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const
+
+     Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q).
+
+     This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We will estimate this using the Laplace approximation by Taylor expanding f(x) about the mode.
+
+     Thus:
+
+     1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0.
+
+     2. Substitute x0 into the Laplace approximation expanded about the mode: log(p(x | y, params)) ~= -0.5*x.T (-f''(x0) + Q) x + x.T (Q.mu + f'(x0) - f''(x0).x0) + 0.5*logdet(Q).
+
+     Parameters
+     ----------
+     x: TensorVariable
+         The parameter to maximize with respect to (that is, the variable in which to find the mode). In INLA, this is the latent field x ~ N(mu, Q^-1).
+     Q: TensorVariable | ArrayLike
+         The precision matrix of the latent field x.
+     mu: TensorVariable | ArrayLike
+         The mean of the latent field x.
+     args: list[TensorVariable]
+         Args to supply to the compiled function. That is, (x0, logp) = f(x, *args). If set to None, assumes the model RVs are args.
+     model: Model
+         PyMC model to use.
+     method: minimize_method
+         Which minimization algorithm to use.
+     use_jac: bool
+         If True, the minimizer will compute the gradient of log(p(x | y, params)).
+     use_hess: bool
+         If True, the minimizer will compute the Hessian of log(p(x | y, params)).
+     optimizer_kwargs: dict
+         Kwargs to pass to scipy.optimize.minimize.
+
+     Returns
+     -------
+     f: Callable
+         A function which accepts a value of x and args and returns [x0, log(p(x | y, params))], where x0 is the mode. x is currently both the point at which to evaluate logp and the initial guess for the minimizer.
+     """
+     model = pm.modelcontext(model)
+
+     if args is None:
+         args = model.continuous_value_vars + model.discrete_value_vars
+
+     # f = log(p(y | x, params))
+     f_x = model.logp()
+     jac = pytensor.gradient.grad(f_x, x)
+     hess = pytensor.gradient.jacobian(jac.flatten(), x)
+
+     # log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x)
+     log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu)
+
+     # Maximize log(p(x | y, params)) wrt x to find mode x0
+     x0, _ = minimize(
+         objective=-log_x_posterior,
+         x=x,
+         method=method,
+         jac=use_jac,
+         hess=use_hess,
+         optimizer_kwargs=optimizer_kwargs,
+     )
+
+     # require f'(x0) and f''(x0) for Laplace approx
+     jac = pytensor.graph.replace.graph_replace(jac, {x: x0})
+     hess = pytensor.graph.replace.graph_replace(hess, {x: x0})
+
+     # Full log(p(x | y, params)) using the Laplace approximation (up to a constant)
+     _, logdetQ = pt.nlinalg.slogdet(Q)
+     conditional_gaussian_approx = (
+         -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
+     )
+
+     # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is
+     # far from the mode x0 or in a neighbourhood which results in poor convergence.
+     return pytensor.function(args, [x0, conditional_gaussian_approx])
+
+
+ def _unconstrained_vector_to_constrained_rvs(model):
+     outputs = get_default_varnames(model.unobserved_value_vars, include_transformed=True)
+     constrained_names = [
+         x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
+     ]
+     names = [x.name for x in outputs]
+
+     unconstrained_names = [name for name in names if name not in constrained_names]
+
+     new_outputs, unconstrained_vector = join_nonshared_inputs(
+         model.initial_point(),
+         inputs=model.value_vars,
+         outputs=outputs,
+     )
+
+     constrained_rvs = [x for x, name in zip(new_outputs, names) if name in constrained_names]
+     value_rvs = [x for x in new_outputs if x not in constrained_rvs]
+
+     unconstrained_vector.name = "unconstrained_vector"
+
+     # Redo the names list to ensure it is sorted to match the return order
+     names = [*constrained_names, *unconstrained_names]
+
+     return names, constrained_rvs, value_rvs, unconstrained_vector
+
+
+ def model_to_laplace_approx(
+     model: pm.Model, unpacked_variable_names: list[str], chains: int = 1, draws: int = 500
+ ):
+     initial_point = model.initial_point()
+     raveled_vars = DictToArrayBijection.map(initial_point)
+     raveled_shape = raveled_vars.data.shape[0]
+
+     # temp_chain and temp_draw are a hack to allow sampling from the Laplace approximation. We only have one mu and cov,
+     # so we add batch dims (which correspond to chains and draws). But the names "chain" and "draw" are reserved.
+     names, constrained_rvs, value_rvs, unconstrained_vector = (
+         _unconstrained_vector_to_constrained_rvs(model)
+     )
+
+     coords = model.coords | {
+         "temp_chain": np.arange(chains),
+         "temp_draw": np.arange(draws),
+         "unpacked_variable_names": unpacked_variable_names,
+     }
+
+     with pm.Model(coords=coords, model=None) as laplace_model:
+         mu = pm.Flat("mean_vector", shape=(raveled_shape,))
+         cov = pm.Flat("covariance_matrix", shape=(raveled_shape, raveled_shape))
+         laplace_approximation = pm.MvNormal(
+             "laplace_approximation",
+             mu=mu,
+             cov=cov,
+             dims=["temp_chain", "temp_draw", "unpacked_variable_names"],
+             method="svd",
+         )
+
+         cast_to_var = partial(type_cast, Variable)
+         batched_rvs = vectorize_graph(
+             type_cast(list[Variable], constrained_rvs),
+             replace={cast_to_var(unconstrained_vector): cast_to_var(laplace_approximation)},
+         )
+
+         for name, batched_rv in zip(names, batched_rvs):
+             batch_dims = ("temp_chain", "temp_draw")
+             if batched_rv.ndim == 2:
+                 dims = batch_dims
+             elif name in model.named_vars_to_dims:
+                 dims = (*batch_dims, *model.named_vars_to_dims[name])
+             else:
+                 dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)])
+                 initval = initial_point.get(name, None)
+                 dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:]
+                 laplace_model.add_coords(
+                     {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
+                 )
+
+             pm.Deterministic(name, batched_rv, dims=dims)
+
+     return laplace_model
+
+
+ def unstack_laplace_draws(laplace_data, model, chains=2, draws=500):
+     """
+     The `model_to_laplace_approx` function returns a model with a single MvNormal distribution, draws from which are
+     in the unconstrained variable space. These might be interesting to the user, but since they come back stacked in a
+     single vector, it's not easy to work with.
+
+     This function unpacks each component of the vector into its own DataArray, with the appropriate dimensions and
+     coordinates, where possible.
+     """
+     initial_point = DictToArrayBijection.map(model.initial_point())
+
+     cursor = 0
+     unstacked_laplace_draws = {}
+     coords = model.coords | {"chain": range(chains), "draw": range(draws)}
+
+     # There are corner cases where the value_vars will not have the same dimensions as the random variable (e.g.
+     # simplex transform of a Dirichlet). In these cases, we don't try to guess what the labels should be, and just
+     # add an arviz-style default dim and label.
+     for rv, (name, shape, size, dtype) in zip(model.free_RVs, initial_point.point_map_info):
+         rv_dims = []
+         for i, dim in enumerate(
+             model.named_vars_to_dims.get(rv.name, [f"{name}_dim_{i}" for i in range(len(shape))])
+         ):
+             if coords.get(dim) and shape[i] == len(coords[dim]):
+                 rv_dims.append(dim)
+             else:
+                 rv_dims.append(f"{name}_dim_{i}")
+                 coords[f"{name}_dim_{i}"] = np.arange(shape[i])
+
+         dims = ("chain", "draw", *rv_dims)
+
+         values = (
+             laplace_data[..., cursor : cursor + size].reshape((chains, draws, *shape)).astype(dtype)
+         )
+         unstacked_laplace_draws[name] = xr.DataArray(
+             values, dims=dims, coords={dim: list(coords[dim]) for dim in dims}
+         )
+
+         cursor += size
+
+     unstacked_laplace_draws = xr.Dataset(unstacked_laplace_draws)
+
+     return unstacked_laplace_draws
+
+
+ def fit_laplace(
+     optimize_method: minimize_method | Literal["basinhopping"] = "BFGS",
+     *,
+     model: pm.Model | None = None,
+     use_grad: bool | None = None,
+     use_hessp: bool | None = None,
+     use_hess: bool | None = None,
+     initvals: dict | None = None,
+     random_seed: int | np.random.Generator | None = None,
+     jitter_rvs: list[pt.TensorVariable] | None = None,
+     progressbar: bool = True,
+     include_transformed: bool = True,
+     gradient_backend: GradientBackend = "pytensor",
+     chains: int = 2,
+     draws: int = 500,
+     optimizer_kwargs: dict | None = None,
+     compile_kwargs: dict | None = None,
+ ) -> az.InferenceData:
+     """
+     Create a Laplace (quadratic) approximation for a posterior distribution.
+
+     This function generates a Laplace approximation for a given posterior distribution using a specified
+     number of draws. This is useful for obtaining a parametric approximation to the posterior distribution
+     that can be used for further analysis.
+
+     Parameters
+     ----------
+     model : pm.Model
+         The PyMC model to be fit. If None, the current model context is used.
+     optimize_method : str
+         The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
+         trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
+
+         See scipy.optimize.minimize documentation for details.
+     use_grad : bool | None, optional
+         Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
+         the ``method``.
+     use_hessp : bool | None, optional
+         Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this automatically based on
+         the ``method``.
+     use_hess : bool | None, optional
+         Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically based on
+         the ``method``.
+     initvals : None | dict, optional
+         Initial values for the model parameters, as str:ndarray key-value pairs. Partial initialization is permitted.
+         If None, the model's default initial values are used.
+     random_seed : None | int | np.random.Generator, optional
+         Seed for the random number generator or a numpy Generator for reproducibility.
+     jitter_rvs : list of TensorVariables, optional
+         Variables whose initial values should be jittered. If None, all variables are jittered.
+     progressbar : bool, optional
+         Whether to display a progress bar during optimization. Defaults to True.
+     include_transformed: bool, default True
+         Whether to include transformed variables in the output. If True, transformed variables will be included in the
+         output InferenceData object. If False, only the original variables will be included.
+     gradient_backend: str, default "pytensor"
+         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
+     chains: int, default: 2
+         The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
+         because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
+         compatible with the ArviZ library.
+     draws: int, default: 500
+         The number of samples to draw from the approximated posterior. Total samples will be chains * draws.
+     optimizer_kwargs
+         Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
+         ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
+         ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
+     compile_kwargs: dict, optional
+         Additional keyword arguments to pass to pytensor.function.
+
+     Returns
+     -------
+     :class:`~arviz.InferenceData`
+         An InferenceData object containing the approximated posterior samples.
+
+     Examples
+     --------
+     >>> from pymc_extras.inference import fit_laplace
+     >>> import numpy as np
+     >>> import pymc as pm
+     >>> import arviz as az
+     >>> y = np.array([2642, 3503, 4358]*10)
+     >>> with pm.Model() as m:
+     >>>     logsigma = pm.Uniform("logsigma", 1, 100)
+     >>>     mu = pm.Uniform("mu", -10000, 10000)
+     >>>     yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
+     >>>     idata = fit_laplace()
+
+     Notes
+     -----
+     This method of approximation may not be suitable for all types of posterior distributions,
+     especially those with significant skewness or multimodality.
+
+     See Also
+     --------
+     fit : Calling the inference function 'fit' like pmx.fit(method="laplace", model=m)
+         will forward the call to 'fit_laplace'.
+
+     """
+     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+     optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
+     model = pm.modelcontext(model) if model is None else model
+
+     idata = find_MAP(
+         method=optimize_method,
+         model=model,
+         use_grad=use_grad,
+         use_hessp=use_hessp,
+         use_hess=use_hess,
+         initvals=initvals,
+         random_seed=random_seed,
+         jitter_rvs=jitter_rvs,
+         progressbar=progressbar,
+         include_transformed=include_transformed,
+         gradient_backend=gradient_backend,
+         compile_kwargs=compile_kwargs,
+         **optimizer_kwargs,
+     )
+
+     unpacked_variable_names = idata.fit["mean_vector"].coords["rows"].values.tolist()
+
+     if "covariance_matrix" not in idata.fit:
+         # The user didn't use `use_hess` or `use_hessp` (or an optimization method that returns an inverse Hessian), so
+         # we have to go back and compute the Hessian at the MAP point now.
+         frozen_model = freeze_dims_and_data(model)
+         initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)
+
+         _, f_hessp = scipy_optimize_funcs_from_loss(
+             loss=-frozen_model.logp(jacobian=False),
+             inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
+             initial_point_dict=DictToArrayBijection.rmap(initial_params),
+             use_grad=False,
+             use_hess=False,
+             use_hessp=True,
+             gradient_backend=gradient_backend,
+             compile_kwargs=compile_kwargs,
+         )
+         H_inv = _compute_inverse_hessian(
+             optimizer_result=None,
+             optimal_point=idata.fit.mean_vector.values,
+             f_fused=None,
+             f_hessp=f_hessp,
+             use_hess=False,
+             method=optimize_method,
+         )
+
+         idata.fit["covariance_matrix"] = xr.DataArray(
+             H_inv,
+             dims=("rows", "columns"),
+             coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names},
+         )
+
+     with model_to_laplace_approx(model, unpacked_variable_names, chains, draws) as laplace_model:
+         new_posterior = (
+             pm.sample_posterior_predictive(
+                 idata.fit.expand_dims(chain=[0], draw=[0]),
+                 extend_inferencedata=False,
+                 random_seed=random_seed,
+                 var_names=[
+                     "laplace_approximation",
+                     *[x.name for x in laplace_model.deterministics],
+                 ],
+             )
+             .posterior_predictive.squeeze(["chain", "draw"])
+             .drop_vars(["chain", "draw"])
+             .rename({"temp_chain": "chain", "temp_draw": "draw"})
+         )
+
+         if include_transformed:
+             idata.unconstrained_posterior = unstack_laplace_draws(
+                 new_posterior.laplace_approximation.values, model, chains=chains, draws=draws
+             )
+
+         idata.posterior = new_posterior.drop_vars(
+             ["laplace_approximation", "unpacked_variable_names"]
+         )
+
+     return idata
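The fit_laplace function above is the public entry point for this module: find_MAP supplies the mode (and, when available, an inverse Hessian), model_to_laplace_approx wraps the result in a batched MvNormal, and sample_posterior_predictive draws from it. A short usage sketch adapted from the docstring example shown above (the data and variable names are illustrative):

import numpy as np
import pymc as pm
from pymc_extras.inference import fit_laplace

y = np.array([2642, 3503, 4358] * 10)
with pm.Model() as m:
    logsigma = pm.Uniform("logsigma", 1, 100)
    mu = pm.Uniform("mu", -10000, 10000)
    pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
    idata = fit_laplace(chains=2, draws=500)

The returned InferenceData has a posterior group with the usual (chain, draw) dimensions expected by ArviZ, and, because include_transformed defaults to True, an unconstrained_posterior group holding the same draws on the unconstrained (transformed) scale.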
pymc_extras/inference/laplace_approx/scipy_interface.py (new file, +242)
@@ -0,0 +1,242 @@
+ from collections.abc import Callable
+ from importlib.util import find_spec
+ from typing import Literal, get_args
+
+ import numpy as np
+ import pymc as pm
+ import pytensor
+
+ from pymc import join_nonshared_inputs
+ from pytensor import tensor as pt
+ from pytensor.compile import Function
+ from pytensor.tensor import TensorVariable
+
+ GradientBackend = Literal["pytensor", "jax"]
+ VALID_BACKENDS = get_args(GradientBackend)
+
+
+ def _compile_grad_and_hess_to_jax(
+     f_fused: Function, use_hess: bool, use_hessp: bool
+ ) -> tuple[Callable | None, Callable | None]:
+     """
+     Compile loss function gradients using JAX.
+
+     Parameters
+     ----------
+     f_fused: Function
+         The loss function to compile gradients for. Expected to be a pytensor function that returns a scalar loss,
+         compiled with mode="JAX".
+     use_hess: bool
+         Whether to compile a function to compute the hessian of the loss function.
+     use_hessp: bool
+         Whether to compile a function to compute the hessian-vector product of the loss function.
+
+     Returns
+     -------
+     f_fused: Callable
+         The compiled loss function and gradient function, which may also compute the hessian if requested.
+     f_hessp: Callable | None
+         The compiled hessian-vector product function, or None if use_hessp is False.
+     """
+     import jax
+
+     f_hessp = None
+
+     orig_loss_fn = f_fused.vm.jit_fn
+
+     if use_hess:
+
+         @jax.jit
+         def loss_fn_fused(x):
+             loss_and_grad = jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
+             hess = jax.hessian(lambda x: orig_loss_fn(x)[0])(x)
+             return *loss_and_grad, hess
+
+     else:
+
+         @jax.jit
+         def loss_fn_fused(x):
+             return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
+
+     if use_hessp:
+
+         def f_hessp_jax(x, p):
+             y, u = jax.jvp(lambda x: loss_fn_fused(x)[1], (x,), (p,))
+             return jax.numpy.stack(u)
+
+         f_hessp = jax.jit(f_hessp_jax)
+
+     return loss_fn_fused, f_hessp
+
+
+ def _compile_functions_for_scipy_optimize(
+     loss: TensorVariable,
+     inputs: list[TensorVariable],
+     compute_grad: bool,
+     compute_hess: bool,
+     compute_hessp: bool,
+     compile_kwargs: dict | None = None,
+ ) -> list[Function] | list[Function, Function | None, Function | None]:
+     """
+     Compile loss functions for use with scipy.optimize.minimize.
+
+     Parameters
+     ----------
+     loss: TensorVariable
+         The loss function to compile.
+     inputs: list[TensorVariable]
+         A single flat vector input variable, collecting all inputs to the loss function. Scipy optimize routines
+         expect the function signature to be f(x, *args), where x is a 1D array of parameters.
+     compute_grad: bool
+         Whether to compile a function that computes the gradients of the loss function.
+     compute_hess: bool
+         Whether to compile a function that computes the Hessian of the loss function.
+     compute_hessp: bool
+         Whether to compile a function that computes the Hessian-vector product of the loss function.
+     compile_kwargs: dict, optional
+         Additional keyword arguments to pass to the ``pm.compile`` function.
+
+     Returns
+     -------
+     f_fused: Function
+         The compiled loss function, which may also include gradients and hessian if requested.
+     f_hessp: Function | None
+         The compiled hessian-vector product function, or None if compute_hessp is False.
+     """
+     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+
+     loss = pm.pytensorf.rewrite_pregrad(loss)
+     f_hessp = None
+
+     # In the simplest case, we only compile the loss function. Return it as a list to keep the return type consistent
+     # with the case where we also compute gradients, hessians, or hessian-vector products.
+     if not (compute_grad or compute_hess or compute_hessp):
+         f_loss = pm.compile(inputs, loss, **compile_kwargs)
+         return [f_loss]
+
+     # Otherwise there are three cases. If the user only wants the loss function and gradients, we compile a single
+     # fused function and return it. If the user also wants the hessian, the fused function will return the loss,
+     # gradients and hessian. If the user wants gradients and hess_p, we return a fused function that returns the loss
+     # and gradients, and a separate function for the hessian-vector product.
+
+     if compute_hessp:
+         # Handle this first, since it can be compiled alone.
+         p = pt.tensor("p", shape=inputs[0].type.shape)
+         hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
+         f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)
+
+     outputs = [loss]
+
+     if compute_grad:
+         grads = pytensor.gradient.grad(loss, inputs)
+         grad = pt.concatenate([grad.ravel() for grad in grads])
+         outputs.append(grad)
+
+         if compute_hess:
+             hess = pytensor.gradient.jacobian(grad, inputs)[0]
+             outputs.append(hess)
+
+     f_fused = pm.compile(inputs, outputs, **compile_kwargs)
+
+     return [f_fused, f_hessp]
+
+
+ def scipy_optimize_funcs_from_loss(
+     loss: TensorVariable,
+     inputs: list[TensorVariable],
+     initial_point_dict: dict[str, np.ndarray | float | int],
+     use_grad: bool,
+     use_hess: bool,
+     use_hessp: bool,
+     gradient_backend: GradientBackend = "pytensor",
+     compile_kwargs: dict | None = None,
+ ) -> tuple[Callable, ...]:
+     """
+     Compile loss functions for use with scipy.optimize.minimize.
+
+     Parameters
+     ----------
+     loss: TensorVariable
+         The loss function to compile.
+     inputs: list[TensorVariable]
+         The input variables to the loss function.
+     initial_point_dict: dict[str, np.ndarray | float | int]
+         Dictionary mapping variable names to initial values. Used to determine the shapes of the input variables.
+     use_grad: bool
+         Whether to compile a function that computes the gradients of the loss function.
+     use_hess: bool
+         Whether to compile a function that computes the Hessian of the loss function.
+     use_hessp: bool
+         Whether to compile a function that computes the Hessian-vector product of the loss function.
+     gradient_backend: str, default "pytensor"
+         Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
+     compile_kwargs:
+         Additional keyword arguments to pass to the ``pm.compile`` function.
+
+     Returns
+     -------
+     f_fused: Callable
+         The compiled loss function, which may also include gradients and hessian if requested.
+     f_hessp: Callable | None
+         The compiled hessian-vector product function, or None if use_hessp is False.
+     """
+
+     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+
+     if use_hess and not use_grad:
+         raise ValueError("Cannot compute hessian without also computing the gradient")
+
+     if gradient_backend not in VALID_BACKENDS:
+         raise ValueError(
+             f"Invalid gradient backend: {gradient_backend}. Must be one of {VALID_BACKENDS}"
+         )
+
+     use_jax_gradients = (gradient_backend == "jax") and use_grad
+     if use_jax_gradients and not find_spec("jax"):
+         raise ImportError("JAX must be installed to use JAX gradients")
+
+     mode = compile_kwargs.get("mode", None)
+     if mode is None and use_jax_gradients:
+         compile_kwargs["mode"] = "JAX"
+     elif mode != "JAX" and use_jax_gradients:
+         raise ValueError(
+             'jax gradients can only be used when ``compile_kwargs["mode"]`` is set to "JAX"'
+         )
+
+     if not isinstance(inputs, list):
+         inputs = [inputs]
+
+     [loss], flat_input = join_nonshared_inputs(
+         point=initial_point_dict, outputs=[loss], inputs=inputs
+     )
+
+     # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When
+     # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them
+     # away.
+     if use_jax_gradients:
+         from pymc.sampling.jax import _replace_shared_variables
+
+         [loss] = _replace_shared_variables([loss])
+
+     compute_grad = use_grad and not use_jax_gradients
+     compute_hess = use_hess and not use_jax_gradients
+     compute_hessp = use_hessp and not use_jax_gradients
+
+     funcs = _compile_functions_for_scipy_optimize(
+         loss=loss,
+         inputs=[flat_input],
+         compute_grad=compute_grad,
+         compute_hess=compute_hess,
+         compute_hessp=compute_hessp,
+         compile_kwargs=compile_kwargs,
+     )
+
+     # Depending on the requested functions, f_fused will either be the loss function, the loss function with gradients,
+     # or the loss function with gradients and hessian.
+     f_fused = funcs.pop(0)
+     f_hessp = funcs.pop(0) if compute_hessp else None
+
+     if use_jax_gradients:
+         f_fused, f_hessp = _compile_grad_and_hess_to_jax(f_fused, use_hess, use_hessp)
+
+     return f_fused, f_hessp
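The two callables returned by scipy_optimize_funcs_from_loss are shaped to plug directly into scipy.optimize.minimize: with use_grad=True the fused function returns the loss and its gradient, which matches jac=True, and f_hessp has the (x, p) signature expected by the hessp argument. A rough sketch of that wiring, under the assumption that the loss is a model's negative log-probability and the model is purely continuous (this is not an excerpt from the package):

import numpy as np
import pymc as pm
from pymc.blocking import DictToArrayBijection
from scipy.optimize import minimize

from pymc_extras.inference.laplace_approx.scipy_interface import scipy_optimize_funcs_from_loss

with pm.Model() as m:
    x = pm.Normal("x", shape=(3,))

start = m.initial_point()

# f_fused(x) -> (loss, grad); f_hessp(x, p) -> Hessian-vector product
f_fused, f_hessp = scipy_optimize_funcs_from_loss(
    loss=-m.logp(),
    inputs=m.continuous_value_vars + m.discrete_value_vars,
    initial_point_dict=start,
    use_grad=True,
    use_hess=False,
    use_hessp=True,
)

# Flatten the initial point the same way the loss inputs were flattened
x0 = DictToArrayBijection.map(start).data
res = minimize(f_fused, x0, jac=True, hessp=f_hessp, method="trust-ncg")
print(res.x)  # MAP estimate on the unconstrained scale

This is essentially the path fit_laplace takes internally when it needs a Hessian-vector product after the fact, except that there it negates the frozen model's logp and reuses the MAP point found by find_MAP.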