pymc-extras 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +29 -0
- pymc_extras/distributions/__init__.py +40 -0
- pymc_extras/distributions/continuous.py +351 -0
- pymc_extras/distributions/discrete.py +399 -0
- pymc_extras/distributions/histogram_utils.py +163 -0
- pymc_extras/distributions/multivariate/__init__.py +3 -0
- pymc_extras/distributions/multivariate/r2d2m2cp.py +446 -0
- pymc_extras/distributions/timeseries.py +356 -0
- pymc_extras/gp/__init__.py +18 -0
- pymc_extras/gp/latent_approx.py +183 -0
- pymc_extras/inference/__init__.py +18 -0
- pymc_extras/inference/find_map.py +431 -0
- pymc_extras/inference/fit.py +44 -0
- pymc_extras/inference/laplace.py +570 -0
- pymc_extras/inference/pathfinder.py +134 -0
- pymc_extras/inference/smc/__init__.py +13 -0
- pymc_extras/inference/smc/sampling.py +451 -0
- pymc_extras/linearmodel.py +130 -0
- pymc_extras/model/__init__.py +0 -0
- pymc_extras/model/marginal/__init__.py +0 -0
- pymc_extras/model/marginal/distributions.py +276 -0
- pymc_extras/model/marginal/graph_analysis.py +372 -0
- pymc_extras/model/marginal/marginal_model.py +595 -0
- pymc_extras/model/model_api.py +56 -0
- pymc_extras/model/transforms/__init__.py +0 -0
- pymc_extras/model/transforms/autoreparam.py +434 -0
- pymc_extras/model_builder.py +759 -0
- pymc_extras/preprocessing/__init__.py +0 -0
- pymc_extras/preprocessing/standard_scaler.py +17 -0
- pymc_extras/printing.py +182 -0
- pymc_extras/statespace/__init__.py +13 -0
- pymc_extras/statespace/core/__init__.py +7 -0
- pymc_extras/statespace/core/compile.py +48 -0
- pymc_extras/statespace/core/representation.py +438 -0
- pymc_extras/statespace/core/statespace.py +2268 -0
- pymc_extras/statespace/filters/__init__.py +15 -0
- pymc_extras/statespace/filters/distributions.py +453 -0
- pymc_extras/statespace/filters/kalman_filter.py +820 -0
- pymc_extras/statespace/filters/kalman_smoother.py +126 -0
- pymc_extras/statespace/filters/utilities.py +59 -0
- pymc_extras/statespace/models/ETS.py +670 -0
- pymc_extras/statespace/models/SARIMAX.py +536 -0
- pymc_extras/statespace/models/VARMAX.py +393 -0
- pymc_extras/statespace/models/__init__.py +6 -0
- pymc_extras/statespace/models/structural.py +1651 -0
- pymc_extras/statespace/models/utilities.py +387 -0
- pymc_extras/statespace/utils/__init__.py +0 -0
- pymc_extras/statespace/utils/constants.py +74 -0
- pymc_extras/statespace/utils/coord_tools.py +0 -0
- pymc_extras/statespace/utils/data_tools.py +182 -0
- pymc_extras/utils/__init__.py +23 -0
- pymc_extras/utils/linear_cg.py +290 -0
- pymc_extras/utils/pivoted_cholesky.py +69 -0
- pymc_extras/utils/prior.py +200 -0
- pymc_extras/utils/spline.py +131 -0
- pymc_extras/version.py +11 -0
- pymc_extras/version.txt +1 -0
- pymc_extras-0.2.0.dist-info/LICENSE +212 -0
- pymc_extras-0.2.0.dist-info/METADATA +99 -0
- pymc_extras-0.2.0.dist-info/RECORD +101 -0
- pymc_extras-0.2.0.dist-info/WHEEL +5 -0
- pymc_extras-0.2.0.dist-info/top_level.txt +2 -0
- tests/__init__.py +13 -0
- tests/distributions/__init__.py +19 -0
- tests/distributions/test_continuous.py +185 -0
- tests/distributions/test_discrete.py +210 -0
- tests/distributions/test_discrete_markov_chain.py +258 -0
- tests/distributions/test_multivariate.py +304 -0
- tests/model/__init__.py +0 -0
- tests/model/marginal/__init__.py +0 -0
- tests/model/marginal/test_distributions.py +131 -0
- tests/model/marginal/test_graph_analysis.py +182 -0
- tests/model/marginal/test_marginal_model.py +867 -0
- tests/model/test_model_api.py +29 -0
- tests/statespace/__init__.py +0 -0
- tests/statespace/test_ETS.py +411 -0
- tests/statespace/test_SARIMAX.py +405 -0
- tests/statespace/test_VARMAX.py +184 -0
- tests/statespace/test_coord_assignment.py +116 -0
- tests/statespace/test_distributions.py +270 -0
- tests/statespace/test_kalman_filter.py +326 -0
- tests/statespace/test_representation.py +175 -0
- tests/statespace/test_statespace.py +818 -0
- tests/statespace/test_statespace_JAX.py +156 -0
- tests/statespace/test_structural.py +829 -0
- tests/statespace/utilities/__init__.py +0 -0
- tests/statespace/utilities/shared_fixtures.py +9 -0
- tests/statespace/utilities/statsmodel_local_level.py +42 -0
- tests/statespace/utilities/test_helpers.py +310 -0
- tests/test_blackjax_smc.py +222 -0
- tests/test_find_map.py +98 -0
- tests/test_histogram_approximation.py +109 -0
- tests/test_laplace.py +238 -0
- tests/test_linearmodel.py +208 -0
- tests/test_model_builder.py +306 -0
- tests/test_pathfinder.py +45 -0
- tests/test_pivoted_cholesky.py +24 -0
- tests/test_printing.py +98 -0
- tests/test_prior_from_trace.py +172 -0
- tests/test_splines.py +77 -0
- tests/utils.py +31 -0
pymc_extras/inference/find_map.py
@@ -0,0 +1,431 @@
import logging

from collections.abc import Callable
from typing import Literal, cast, get_args

import jax
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

from better_optimize import minimize
from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.initial_point import make_initial_point_fn
from pymc.model.transform.optimization import freeze_dims_and_data
from pymc.pytensorf import join_nonshared_inputs
from pymc.util import get_default_varnames
from pytensor.compile import Function
from pytensor.compile.mode import Mode
from pytensor.tensor import TensorVariable
from scipy.optimize import OptimizeResult

_log = logging.getLogger(__name__)

GradientBackend = Literal["pytensor", "jax"]
VALID_BACKENDS = get_args(GradientBackend)


def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
    method_info = MINIMIZE_MODE_KWARGS[method].copy()

    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
    use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
    use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]

    if use_hess and use_hessp:
        use_hess = False

    return use_grad, use_hess, use_hessp


def get_nearest_psd(A: np.ndarray) -> np.ndarray:
    """
    Compute the nearest positive semi-definite matrix to a given matrix.

    This function takes a square matrix and returns the nearest positive semi-definite matrix using
    eigenvalue decomposition. It ensures all eigenvalues are non-negative. The "nearest" matrix is defined in terms
    of the Frobenius norm.

    Parameters
    ----------
    A : np.ndarray
        Input square matrix.

    Returns
    -------
    np.ndarray
        The nearest positive semi-definite matrix to the input matrix.
    """
    C = (A + A.T) / 2
    eigval, eigvec = np.linalg.eig(C)
    eigval[eigval < 0] = 0

    return eigvec @ np.diag(eigval) @ eigvec.T


def _unconstrained_vector_to_constrained_rvs(model):
    constrained_rvs, unconstrained_vector = join_nonshared_inputs(
        model.initial_point(),
        inputs=model.value_vars,
        outputs=get_default_varnames(model.unobserved_value_vars, include_transformed=False),
    )

    unconstrained_vector.name = "unconstrained_vector"
    return constrained_rvs, unconstrained_vector


def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model, chains, draws):
    X = pt.tensor("transformed_draws", shape=(chains, draws, H_inv.shape[0]))
    out = []
    for rv, idx in slices.items():
        f = model.rvs_to_transforms[rv]
        untransformed_X = f.backward(X[..., idx]) if f is not None else X[..., idx]

        if rv in out_shapes:
            new_shape = (chains, draws) + out_shapes[rv]
            untransformed_X = untransformed_X.reshape(new_shape)

        out.append(untransformed_X)

    f_untransform = pytensor.function(
        inputs=[pytensor.In(X, borrow=True)],
        outputs=pytensor.Out(out, borrow=True),
        mode=Mode(linker="py", optimizer="FAST_COMPILE"),
    )
    return f_untransform(posterior_draws)


def _compile_jax_gradients(
    f_loss: Function, use_hess: bool, use_hessp: bool
) -> tuple[Callable | None, Callable | None]:
    """
    Compile loss function gradients using JAX.

    Parameters
    ----------
    f_loss: Function
        The loss function to compile gradients for. Expected to be a pytensor function that returns a scalar loss,
        compiled with mode="JAX".
    use_hess: bool
        Whether to compile a function to compute the hessian of the loss function.
    use_hessp: bool
        Whether to compile a function to compute the hessian-vector product of the loss function.

    Returns
    -------
    f_loss_and_grad: Callable
        The compiled loss function and gradient function.
    f_hess: Callable | None
        The compiled hessian function, or None if use_hess is False.
    f_hessp: Callable | None
        The compiled hessian-vector product function, or None if use_hessp is False.
    """
    f_hess = None
    f_hessp = None

    orig_loss_fn = f_loss.vm.jit_fn

    @jax.jit
    def loss_fn_jax_grad(x, *shared):
        return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)

    f_loss_and_grad = loss_fn_jax_grad

    if use_hessp:

        def f_hessp_jax(x, p):
            y, u = jax.jvp(lambda x: f_loss_and_grad(x)[1], (x,), (p,))
            return jax.numpy.stack(u)

        f_hessp = jax.jit(f_hessp_jax)

    if use_hess:
        _f_hess_jax = jax.jacfwd(lambda x: f_loss_and_grad(x)[1])

        def f_hess_jax(x):
            return jax.numpy.stack(_f_hess_jax(x))

        f_hess = jax.jit(f_hess_jax)

    return f_loss_and_grad, f_hess, f_hessp


def _compile_functions(
    loss: TensorVariable,
    inputs: list[TensorVariable],
    compute_grad: bool,
    compute_hess: bool,
    compute_hessp: bool,
    compile_kwargs: dict | None = None,
) -> list[Function] | list[Function, Function | None, Function | None]:
    """
    Compile loss functions for use with scipy.optimize.minimize.

    Parameters
    ----------
    loss: TensorVariable
        The loss function to compile.
    inputs: list[TensorVariable]
        A single flat vector input variable, collecting all inputs to the loss function. Scipy optimize routines
        expect the function signature to be f(x, *args), where x is a 1D array of parameters.
    compute_grad: bool
        Whether to compile a function that computes the gradients of the loss function.
    compute_hess: bool
        Whether to compile a function that computes the Hessian of the loss function.
    compute_hessp: bool
        Whether to compile a function that computes the Hessian-vector product of the loss function.
    compile_kwargs: dict, optional
        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.

    Returns
    -------
    f_loss: Function

    f_hess: Function | None
    f_hessp: Function | None
    """
    loss = pm.pytensorf.rewrite_pregrad(loss)
    f_hess = None
    f_hessp = None

    if compute_grad:
        grads = pytensor.gradient.grad(loss, inputs)
        grad = pt.concatenate([grad.ravel() for grad in grads])
        f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs)
    else:
        f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs)
        return [f_loss]

    if compute_hess:
        hess = pytensor.gradient.jacobian(grad, inputs)[0]
        f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs)

    if compute_hessp:
        p = pt.tensor("p", shape=inputs[0].type.shape)
        hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
        f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs)

    return [f_loss_and_grad, f_hess, f_hessp]


def scipy_optimize_funcs_from_loss(
    loss: TensorVariable,
    inputs: list[TensorVariable],
    initial_point_dict: dict[str, np.ndarray | float | int],
    use_grad: bool,
    use_hess: bool,
    use_hessp: bool,
    gradient_backend: GradientBackend = "pytensor",
    compile_kwargs: dict | None = None,
) -> tuple[Callable, ...]:
    """
    Compile loss functions for use with scipy.optimize.minimize.

    Parameters
    ----------
    loss: TensorVariable
        The loss function to compile.
    inputs: list[TensorVariable]
        The input variables to the loss function.
    initial_point_dict: dict[str, np.ndarray | float | int]
        Dictionary mapping variable names to initial values. Used to determine the shapes of the input variables.
    use_grad: bool
        Whether to compile a function that computes the gradients of the loss function.
    use_hess: bool
        Whether to compile a function that computes the Hessian of the loss function.
    use_hessp: bool
        Whether to compile a function that computes the Hessian-vector product of the loss function.
    gradient_backend: str, default "pytensor"
        Which backend to use to compute gradients. Must be one of "jax" or "pytensor".
    compile_kwargs:
        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.

    Returns
    -------
    f_loss: Callable
        The compiled loss function.
    f_hess: Callable | None
        The compiled hessian function, or None if use_hess is False.
    f_hessp: Callable | None
        The compiled hessian-vector product function, or None if use_hessp is False.
    """

    compile_kwargs = {} if compile_kwargs is None else compile_kwargs

    if (use_hess or use_hessp) and not use_grad:
        raise ValueError(
            "Cannot compute hessian or hessian-vector product without also computing the gradient"
        )

    if gradient_backend not in VALID_BACKENDS:
        raise ValueError(
            f"Invalid gradient backend: {gradient_backend}. Must be one of {VALID_BACKENDS}"
        )

    use_jax_gradients = (gradient_backend == "jax") and use_grad

    mode = compile_kwargs.get("mode", None)
    if mode is None and use_jax_gradients:
        compile_kwargs["mode"] = "JAX"
    elif mode != "JAX" and use_jax_gradients:
        raise ValueError(
            'jax gradients can only be used when ``compile_kwargs["mode"]`` is set to "JAX"'
        )

    if not isinstance(inputs, list):
        inputs = [inputs]

    [loss], flat_input = join_nonshared_inputs(
        point=initial_point_dict, outputs=[loss], inputs=inputs
    )

    compute_grad = use_grad and not use_jax_gradients
    compute_hess = use_hess and not use_jax_gradients
    compute_hessp = use_hessp and not use_jax_gradients

    funcs = _compile_functions(
        loss=loss,
        inputs=[flat_input],
        compute_grad=compute_grad,
        compute_hess=compute_hess,
        compute_hessp=compute_hessp,
        compile_kwargs=compile_kwargs,
    )

    # f_loss here is f_loss_and_grad if compute_grad = True. The name is unchanged to simplify the return values
    f_loss = funcs.pop(0)
    f_hess = funcs.pop(0) if compute_grad else None
    f_hessp = funcs.pop(0) if compute_grad else None

    if use_jax_gradients:
        # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
        f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)

    return f_loss, f_hess, f_hessp


def find_MAP(
    method: minimize_method,
    *,
    model: pm.Model | None = None,
    use_grad: bool | None = None,
    use_hessp: bool | None = None,
    use_hess: bool | None = None,
    initvals: dict | None = None,
    random_seed: int | np.random.Generator | None = None,
    return_raw: bool = False,
    jitter_rvs: list[TensorVariable] | None = None,
    progressbar: bool = True,
    include_transformed: bool = True,
    gradient_backend: GradientBackend = "pytensor",
    compile_kwargs: dict | None = None,
    **optimizer_kwargs,
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]:
    """
    Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.minimize.

    Parameters
    ----------
    model : pm.Model
        The PyMC model to be fit. If None, the current model context is used.
    method : str
        The optimization method to use. See scipy.optimize.minimize documentation for details.
    use_grad : bool | None, optional
        Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
        the ``method``.
    use_hessp : bool | None, optional
        Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this
        automatically based on the ``method``.
    use_hess : bool | None, optional
        Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically
        based on the ``method``.
    initvals : None | dict, optional
        Initial values for the model parameters, as str:ndarray key-value pairs. Partial initialization is permitted.
        If None, the model's default initial values are used.
    random_seed : None | int | np.random.Generator, optional
        Seed for the random number generator or a numpy Generator for reproducibility.
    return_raw : bool, optional
        Whether to also return the full output of ``scipy.optimize.minimize``. Defaults to False.
    jitter_rvs : list of TensorVariables, optional
        Variables whose initial values should be jittered. If None, all variables are jittered.
    progressbar : bool, optional
        Whether to display a progress bar during optimization. Defaults to True.
    include_transformed: bool, optional
        Whether to include transformed variable values in the returned dictionary. Defaults to True.
    gradient_backend: str, default "pytensor"
        Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
    compile_kwargs: dict, optional
        Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
    **optimizer_kwargs
        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function.

    Returns
    -------
    optimizer_result: dict[str, np.ndarray] or tuple[dict[str, np.ndarray], OptimizeResult]
        Dictionary with names of random variables as keys, and optimization results as values. If return_raw is True,
        also returns the object returned by ``scipy.optimize.minimize``.
    """
    model = pm.modelcontext(model)
    frozen_model = freeze_dims_and_data(model)

    jitter_rvs = [] if jitter_rvs is None else jitter_rvs
    compile_kwargs = {} if compile_kwargs is None else compile_kwargs

    ipfn = make_initial_point_fn(
        model=frozen_model,
        jitter_rvs=set(jitter_rvs),
        return_transformed=True,
        overrides=initvals,
    )

    start_dict = ipfn(random_seed)
    vars_dict = {var.name: var for var in frozen_model.continuous_value_vars}
    initial_params = DictToArrayBijection.map(
        {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
    )
    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
        method, use_grad, use_hess, use_hessp
    )

    f_logp, f_hess, f_hessp = scipy_optimize_funcs_from_loss(
        loss=-frozen_model.logp(jacobian=False),
        inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
        initial_point_dict=start_dict,
        use_grad=use_grad,
        use_hess=use_hess,
        use_hessp=use_hessp,
        gradient_backend=gradient_backend,
        compile_kwargs=compile_kwargs,
    )

    args = optimizer_kwargs.pop("args", None)

    # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument
    # if so. That is why it is not set here, regardless of user settings.
    optimizer_result = minimize(
        f=f_logp,
        x0=cast(np.ndarray[float], initial_params.data),
        args=args,
        hess=f_hess,
        hessp=f_hessp,
        progressbar=progressbar,
        method=method,
        **optimizer_kwargs,
    )

    raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
    unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
    unobserved_vars_values = model.compile_fn(unobserved_vars, mode="FAST_COMPILE")(
        DictToArrayBijection.rmap(raveled_optimized)
    )

    optimized_point = {
        var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)
    }

    if return_raw:
        return optimized_point, optimizer_result

    return optimized_point
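
The two snippets below are editorial usage sketches, not part of the wheel contents. The first checks get_nearest_psd numerically: a symmetric matrix with eigenvalues 3 and -1 is projected onto the positive semi-definite cone by zeroing the negative eigenvalue.

import numpy as np

from pymc_extras.inference.find_map import get_nearest_psd

A = np.array([[1.0, 2.0], [2.0, 1.0]])  # symmetric, eigenvalues 3 and -1, so not PSD

A_psd = get_nearest_psd(A)
print(np.linalg.eigvalsh(A_psd))  # ~[0., 3.]; A_psd is [[1.5, 1.5], [1.5, 1.5]]

The second sketches a call to find_MAP. The toy model, variable names, and the "L-BFGS-B" choice are illustrative assumptions; note that find_map.py imports jax and better_optimize at module level, so both must be installed even when gradient_backend="pytensor".

import numpy as np
import pymc as pm

from pymc_extras.inference.find_map import find_MAP

rng = np.random.default_rng(0)
observed = rng.normal(loc=1.0, scale=2.0, size=100)

with pm.Model():
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=observed)

    # Any scipy.optimize.minimize method can be passed; use_grad/use_hess/use_hessp
    # default to the method's capabilities via set_optimizer_function_defaults.
    map_point = find_MAP(method="L-BFGS-B", progressbar=False, random_seed=0)

print(map_point["mu"], map_point["sigma"])
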
pymc_extras/inference/fit.py
@@ -0,0 +1,44 @@
# Copyright 2022 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from importlib.util import find_spec


def fit(method, **kwargs):
    """
    Fit a model with an inference algorithm

    Parameters
    ----------
    method : str
        Which inference method to run.
        Supported: pathfinder or laplace

    kwargs are passed on.

    Returns
    -------
    arviz.InferenceData
    """
    if method == "pathfinder":
        if find_spec("blackjax") is None:
            raise RuntimeError("Need BlackJAX to use `pathfinder`")

        from pymc_extras.inference.pathfinder import fit_pathfinder

        return fit_pathfinder(**kwargs)

    if method == "laplace":
        from pymc_extras.inference.laplace import fit_laplace

        return fit_laplace(**kwargs)
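
A final editorial sketch of the fit dispatcher above, which simply forwards every keyword argument to fit_pathfinder or fit_laplace. It assumes fit_laplace can pick the model up from the active model context and needs no further arguments; its actual signature lives in pymc_extras/inference/laplace.py, listed above but not shown in this hunk.

import pymc as pm

from pymc_extras.inference.fit import fit

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu=mu, sigma=1.0, observed=[0.1, -0.3, 0.2])

    # "laplace" routes to fit_laplace; "pathfinder" would route to fit_pathfinder
    # and additionally requires blackjax to be importable.
    idata = fit(method="laplace")

print(idata)  # an arviz.InferenceData, per the fit docstring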