pymc-extras 0.2.7__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff reflects the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- pymc_extras/inference/__init__.py +2 -2
- pymc_extras/inference/fit.py +1 -1
- pymc_extras/inference/laplace_approx/__init__.py +0 -0
- pymc_extras/inference/laplace_approx/find_map.py +354 -0
- pymc_extras/inference/laplace_approx/idata.py +393 -0
- pymc_extras/inference/laplace_approx/laplace.py +453 -0
- pymc_extras/inference/laplace_approx/scipy_interface.py +242 -0
- pymc_extras/inference/pathfinder/pathfinder.py +3 -4
- pymc_extras/linearmodel.py +3 -1
- pymc_extras/model/marginal/graph_analysis.py +4 -0
- pymc_extras/prior.py +38 -6
- pymc_extras/statespace/core/statespace.py +78 -52
- pymc_extras/statespace/filters/kalman_smoother.py +1 -1
- pymc_extras/statespace/models/structural/__init__.py +21 -0
- pymc_extras/statespace/models/structural/components/__init__.py +0 -0
- pymc_extras/statespace/models/structural/components/autoregressive.py +188 -0
- pymc_extras/statespace/models/structural/components/cycle.py +305 -0
- pymc_extras/statespace/models/structural/components/level_trend.py +257 -0
- pymc_extras/statespace/models/structural/components/measurement_error.py +137 -0
- pymc_extras/statespace/models/structural/components/regression.py +228 -0
- pymc_extras/statespace/models/structural/components/seasonality.py +445 -0
- pymc_extras/statespace/models/structural/core.py +900 -0
- pymc_extras/statespace/models/structural/utils.py +16 -0
- pymc_extras/statespace/models/utilities.py +285 -0
- pymc_extras/statespace/utils/constants.py +4 -4
- pymc_extras/statespace/utils/data_tools.py +3 -2
- {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/METADATA +6 -6
- {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/RECORD +30 -18
- pymc_extras/inference/find_map.py +0 -496
- pymc_extras/inference/laplace.py +0 -583
- pymc_extras/statespace/models/structural.py +0 -1679
- {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/WHEEL +0 -0
- {pymc_extras-0.2.7.dist-info → pymc_extras-0.4.0.dist-info}/licenses/LICENSE +0 -0

```diff
--- a/pymc_extras/inference/find_map.py
+++ /dev/null
@@ -1,496 +0,0 @@
-import logging
-
-from collections.abc import Callable
-from importlib.util import find_spec
-from typing import Literal, cast, get_args
-
-import numpy as np
-import pymc as pm
-import pytensor
-import pytensor.tensor as pt
-
-from better_optimize import basinhopping, minimize
-from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
-from pymc.blocking import DictToArrayBijection, RaveledVars
-from pymc.initial_point import make_initial_point_fn
-from pymc.model.transform.optimization import freeze_dims_and_data
-from pymc.pytensorf import join_nonshared_inputs
-from pymc.util import get_default_varnames
-from pytensor.compile import Function
-from pytensor.compile.mode import Mode
-from pytensor.tensor import TensorVariable
-from scipy.optimize import OptimizeResult
-
-_log = logging.getLogger(__name__)
-
-GradientBackend = Literal["pytensor", "jax"]
-VALID_BACKENDS = get_args(GradientBackend)
-
-
-def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
-    method_info = MINIMIZE_MODE_KWARGS[method].copy()
-
-    if use_hess and use_hessp:
-        _log.warning(
-            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
-            'same time. When possible "use_hessp" is preferred because its is computationally more efficient. '
-            'Setting "use_hess" to False.'
-        )
-        use_hess = False
-
-    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-
-    if use_hessp is not None and use_hess is None:
-        use_hess = not use_hessp
-
-    elif use_hess is not None and use_hessp is None:
-        use_hessp = not use_hess
-
-    elif use_hessp is None and use_hess is None:
-        use_hessp = method_info["uses_hessp"]
-        use_hess = method_info["uses_hess"]
-        if use_hessp and use_hess:
-            # If a method could use either hess or hessp, we default to using hessp
-            use_hess = False
-
-    return use_grad, use_hess, use_hessp
-
-
-def get_nearest_psd(A: np.ndarray) -> np.ndarray:
-    """
-    Compute the nearest positive semi-definite matrix to a given matrix.
-
-    This function takes a square matrix and returns the nearest positive semi-definite matrix using
-    eigenvalue decomposition. It ensures all eigenvalues are non-negative. The "nearest" matrix is defined in terms
-    of the Frobenius norm.
-
-    Parameters
-    ----------
-    A : np.ndarray
-        Input square matrix.
-
-    Returns
-    -------
-    np.ndarray
-        The nearest positive semi-definite matrix to the input matrix.
-    """
-    C = (A + A.T) / 2
-    eigval, eigvec = np.linalg.eigh(C)
-    eigval[eigval < 0] = 0
-
-    return eigvec @ np.diag(eigval) @ eigvec.T
-
-
-def _unconstrained_vector_to_constrained_rvs(model):
-    constrained_rvs, unconstrained_vector = join_nonshared_inputs(
-        model.initial_point(),
-        inputs=model.value_vars,
-        outputs=get_default_varnames(model.unobserved_value_vars, include_transformed=False),
-    )
-
-    unconstrained_vector.name = "unconstrained_vector"
-    return constrained_rvs, unconstrained_vector
-
-
-def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model, chains, draws):
-    X = pt.tensor("transformed_draws", shape=(chains, draws, H_inv.shape[0]))
-    out = []
-    for rv, idx in slices.items():
-        f = model.rvs_to_transforms[rv]
-        untransformed_X = f.backward(X[..., idx]) if f is not None else X[..., idx]
-
-        if rv in out_shapes:
-            new_shape = (chains, draws) + out_shapes[rv]
-            untransformed_X = untransformed_X.reshape(new_shape)
-
-        out.append(untransformed_X)
-
-    f_untransform = pytensor.function(
-        inputs=[pytensor.In(X, borrow=True)],
-        outputs=pytensor.Out(out, borrow=True),
-        mode=Mode(linker="py", optimizer="FAST_COMPILE"),
-    )
-    return f_untransform(posterior_draws)
-
-
-def _compile_grad_and_hess_to_jax(
-    f_loss: Function, use_hess: bool, use_hessp: bool
-) -> tuple[Callable | None, Callable | None]:
-    """
-    Compile loss function gradients using JAX.
-
-    Parameters
-    ----------
-    f_loss: Function
-        The loss function to compile gradients for. Expected to be a pytensor function that returns a scalar loss,
-        compiled with mode="JAX".
-    use_hess: bool
-        Whether to compile a function to compute the hessian of the loss function.
-    use_hessp: bool
-        Whether to compile a function to compute the hessian-vector product of the loss function.
-
-    Returns
-    -------
-    f_loss_and_grad: Callable
-        The compiled loss function and gradient function.
-    f_hess: Callable | None
-        The compiled hessian function, or None if use_hess is False.
-    f_hessp: Callable | None
-        The compiled hessian-vector product function, or None if use_hessp is False.
-    """
-    import jax
-
-    f_hess = None
-    f_hessp = None
-
-    orig_loss_fn = f_loss.vm.jit_fn
-
-    @jax.jit
-    def loss_fn_jax_grad(x):
-        return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
-
-    f_loss_and_grad = loss_fn_jax_grad
-
-    if use_hessp:
-
-        def f_hessp_jax(x, p):
-            y, u = jax.jvp(lambda x: f_loss_and_grad(x)[1], (x,), (p,))
-            return jax.numpy.stack(u)
-
-        f_hessp = jax.jit(f_hessp_jax)
-
-    if use_hess:
-        _f_hess_jax = jax.jacfwd(lambda x: f_loss_and_grad(x)[1])
-
-        def f_hess_jax(x):
-            return jax.numpy.stack(_f_hess_jax(x))
-
-        f_hess = jax.jit(f_hess_jax)
-
-    return f_loss_and_grad, f_hess, f_hessp
-
-
-def _compile_functions_for_scipy_optimize(
-    loss: TensorVariable,
-    inputs: list[TensorVariable],
-    compute_grad: bool,
-    compute_hess: bool,
-    compute_hessp: bool,
-    compile_kwargs: dict | None = None,
-) -> list[Function] | list[Function, Function | None, Function | None]:
-    """
-    Compile loss functions for use with scipy.optimize.minimize.
-
-    Parameters
-    ----------
-    loss: TensorVariable
-        The loss function to compile.
-    inputs: list[TensorVariable]
-        A single flat vector input variable, collecting all inputs to the loss function. Scipy optimize routines
-        expect the function signature to be f(x, *args), where x is a 1D array of parameters.
-    compute_grad: bool
-        Whether to compile a function that computes the gradients of the loss function.
-    compute_hess: bool
-        Whether to compile a function that computes the Hessian of the loss function.
-    compute_hessp: bool
-        Whether to compile a function that computes the Hessian-vector product of the loss function.
-    compile_kwargs: dict, optional
-        Additional keyword arguments to pass to the ``pm.compile`` function.
-
-    Returns
-    -------
-    f_loss: Function
-
-    f_hess: Function | None
-    f_hessp: Function | None
-    """
-    loss = pm.pytensorf.rewrite_pregrad(loss)
-    f_hess = None
-    f_hessp = None
-
-    if compute_grad:
-        grads = pytensor.gradient.grad(loss, inputs)
-        grad = pt.concatenate([grad.ravel() for grad in grads])
-        f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
-    else:
-        f_loss = pm.compile(inputs, loss, **compile_kwargs)
-        return [f_loss]
-
-    if compute_hess:
-        hess = pytensor.gradient.jacobian(grad, inputs)[0]
-        f_hess = pm.compile(inputs, hess, **compile_kwargs)
-
-    if compute_hessp:
-        p = pt.tensor("p", shape=inputs[0].type.shape)
-        hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
-        f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)
-
-    return [f_loss_and_grad, f_hess, f_hessp]
-
-
-def scipy_optimize_funcs_from_loss(
-    loss: TensorVariable,
-    inputs: list[TensorVariable],
-    initial_point_dict: dict[str, np.ndarray | float | int],
-    use_grad: bool,
-    use_hess: bool,
-    use_hessp: bool,
-    gradient_backend: GradientBackend = "pytensor",
-    compile_kwargs: dict | None = None,
-) -> tuple[Callable, ...]:
-    """
-    Compile loss functions for use with scipy.optimize.minimize.
-
-    Parameters
-    ----------
-    loss: TensorVariable
-        The loss function to compile.
-    inputs: list[TensorVariable]
-        The input variables to the loss function.
-    initial_point_dict: dict[str, np.ndarray | float | int]
-        Dictionary mapping variable names to initial values. Used to determine the shapes of the input variables.
-    use_grad: bool
-        Whether to compile a function that computes the gradients of the loss function.
-    use_hess: bool
-        Whether to compile a function that computes the Hessian of the loss function.
-    use_hessp: bool
-        Whether to compile a function that computes the Hessian-vector product of the loss function.
-    gradient_backend: str, default "pytensor"
-        Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
-    compile_kwargs:
-        Additional keyword arguments to pass to the ``pm.compile`` function.
-
-    Returns
-    -------
-    f_loss: Callable
-        The compiled loss function.
-    f_hess: Callable | None
-        The compiled hessian function, or None if use_hess is False.
-    f_hessp: Callable | None
-        The compiled hessian-vector product function, or None if use_hessp is False.
-    """
-
-    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
-
-    if (use_hess or use_hessp) and not use_grad:
-        raise ValueError(
-            "Cannot compute hessian or hessian-vector product without also computing the gradient"
-        )
-
-    if gradient_backend not in VALID_BACKENDS:
-        raise ValueError(
-            f"Invalid gradient backend: {gradient_backend}. Must be one of {VALID_BACKENDS}"
-        )
-
-    use_jax_gradients = (gradient_backend == "jax") and use_grad
-    if use_jax_gradients and not find_spec("jax"):
-        raise ImportError("JAX must be installed to use JAX gradients")
-
-    mode = compile_kwargs.get("mode", None)
-    if mode is None and use_jax_gradients:
-        compile_kwargs["mode"] = "JAX"
-    elif mode != "JAX" and use_jax_gradients:
-        raise ValueError(
-            'jax gradients can only be used when ``compile_kwargs["mode"]`` is set to "JAX"'
-        )
-
-    if not isinstance(inputs, list):
-        inputs = [inputs]
-
-    [loss], flat_input = join_nonshared_inputs(
-        point=initial_point_dict, outputs=[loss], inputs=inputs
-    )
-
-    # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When
-    # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them
-    # away.
-    if use_jax_gradients:
-        from pymc.sampling.jax import _replace_shared_variables
-
-        [loss] = _replace_shared_variables([loss])
-
-    compute_grad = use_grad and not use_jax_gradients
-    compute_hess = use_hess and not use_jax_gradients
-    compute_hessp = use_hessp and not use_jax_gradients
-
-    funcs = _compile_functions_for_scipy_optimize(
-        loss=loss,
-        inputs=[flat_input],
-        compute_grad=compute_grad,
-        compute_hess=compute_hess,
-        compute_hessp=compute_hessp,
-        compile_kwargs=compile_kwargs,
-    )
-
-    # f_loss here is f_loss_and_grad if compute_grad = True. The name is unchanged to simplify the return values
-    f_loss = funcs.pop(0)
-    f_hess = funcs.pop(0) if compute_grad else None
-    f_hessp = funcs.pop(0) if compute_grad else None
-
-    if use_jax_gradients:
-        # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-        f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)
-
-    return f_loss, f_hess, f_hessp
-
-
-def find_MAP(
-    method: minimize_method | Literal["basinhopping"],
-    *,
-    model: pm.Model | None = None,
-    use_grad: bool | None = None,
-    use_hessp: bool | None = None,
-    use_hess: bool | None = None,
-    initvals: dict | None = None,
-    random_seed: int | np.random.Generator | None = None,
-    return_raw: bool = False,
-    jitter_rvs: list[TensorVariable] | None = None,
-    progressbar: bool = True,
-    include_transformed: bool = True,
-    gradient_backend: GradientBackend = "pytensor",
-    compile_kwargs: dict | None = None,
-    **optimizer_kwargs,
-) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]:
-    """
-    Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.optimize.
-
-    Parameters
-    ----------
-    model : pm.Model
-        The PyMC model to be fit. If None, the current model context is used.
-    method : str
-        The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
-        trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
-
-        See scipy.optimize.minimize documentation for details.
-    use_grad : bool | None, optional
-        Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
-        the ``method``.
-    use_hessp : bool | None, optional
-        Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this automatically based on
-        the ``method``.
-    use_hess : bool | None, optional
-        Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically based on
-        the ``method``.
-    initvals : None | dict, optional
-        Initial values for the model parameters, as str:ndarray key-value pairs. Paritial initialization is permitted.
-        If None, the model's default initial values are used.
-    random_seed : None | int | np.random.Generator, optional
-        Seed for the random number generator or a numpy Generator for reproducibility
-    return_raw: bool | False, optinal
-        Whether to also return the full output of `scipy.optimize.minimize`
-    jitter_rvs : list of TensorVariables, optional
-        Variables whose initial values should be jittered. If None, all variables are jittered.
-    progressbar : bool, optional
-        Whether to display a progress bar during optimization. Defaults to True.
-    include_transformed: bool, optional
-        Whether to include transformed variable values in the returned dictionary. Defaults to True.
-    gradient_backend: str, default "pytensor"
-        Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
-    compile_kwargs: dict, optional
-        Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
-    **optimizer_kwargs
-        Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
-        ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
-        ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
-
-    Returns
-    -------
-    optimizer_result: dict[str, np.ndarray] or tuple[dict[str, np.ndarray], OptimizerResult]
-        Dictionary with names of random variables as keys, and optimization results as values. If return_raw is True,
-        also returns the object returned by ``scipy.optimize.minimize``.
-    """
-    model = pm.modelcontext(model)
-    frozen_model = freeze_dims_and_data(model)
-
-    jitter_rvs = [] if jitter_rvs is None else jitter_rvs
-    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
-
-    ipfn = make_initial_point_fn(
-        model=frozen_model,
-        jitter_rvs=set(jitter_rvs),
-        return_transformed=True,
-        overrides=initvals,
-    )
-
-    start_dict = ipfn(random_seed)
-    vars_dict = {var.name: var for var in frozen_model.continuous_value_vars}
-    initial_params = DictToArrayBijection.map(
-        {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
-    )
-
-    do_basinhopping = method == "basinhopping"
-    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
-
-    if do_basinhopping:
-        # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
-        # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
-        # if one isn't provided.
-
-        method = minimizer_kwargs.pop("method", "L-BFGS-B")
-        minimizer_kwargs["method"] = method
-
-    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
-        method, use_grad, use_hess, use_hessp
-    )
-
-    f_logp, f_hess, f_hessp = scipy_optimize_funcs_from_loss(
-        loss=-frozen_model.logp(jacobian=False),
-        inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
-        initial_point_dict=start_dict,
-        use_grad=use_grad,
-        use_hess=use_hess,
-        use_hessp=use_hessp,
-        gradient_backend=gradient_backend,
-        compile_kwargs=compile_kwargs,
-    )
-
-    args = optimizer_kwargs.pop("args", None)
-
-    # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument
-    # if so. That is why the jac argument is not passed here in either branch.
-
-    if do_basinhopping:
-        if "args" not in minimizer_kwargs:
-            minimizer_kwargs["args"] = args
-        if "hess" not in minimizer_kwargs:
-            minimizer_kwargs["hess"] = f_hess
-        if "hessp" not in minimizer_kwargs:
-            minimizer_kwargs["hessp"] = f_hessp
-        if "method" not in minimizer_kwargs:
-            minimizer_kwargs["method"] = method
-
-        optimizer_result = basinhopping(
-            func=f_logp,
-            x0=cast(np.ndarray[float], initial_params.data),
-            progressbar=progressbar,
-            minimizer_kwargs=minimizer_kwargs,
-            **optimizer_kwargs,
-        )
-
-    else:
-        optimizer_result = minimize(
-            f=f_logp,
-            x0=cast(np.ndarray[float], initial_params.data),
-            args=args,
-            hess=f_hess,
-            hessp=f_hessp,
-            progressbar=progressbar,
-            method=method,
-            **optimizer_kwargs,
-        )
-
-    raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
-    unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
-    unobserved_vars_values = model.compile_fn(unobserved_vars, mode="FAST_COMPILE")(
-        DictToArrayBijection.rmap(raveled_optimized)
-    )
-
-    optimized_point = {
-        var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)
-    }
-
-    if return_raw:
-        return optimized_point, optimizer_result
-
-    return optimized_point
```