pymc-extras 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +2 -0
- pymc_extras/inference/find_map.py +36 -16
- pymc_extras/inference/laplace.py +17 -10
- pymc_extras/model/marginal/marginal_model.py +2 -1
- pymc_extras/statespace/core/compile.py +1 -1
- pymc_extras/version.txt +1 -1
- {pymc_extras-0.2.2.dist-info → pymc_extras-0.2.3.dist-info}/METADATA +4 -3
- {pymc_extras-0.2.2.dist-info → pymc_extras-0.2.3.dist-info}/RECORD +13 -13
- tests/test_find_map.py +19 -14
- tests/test_laplace.py +42 -15
- {pymc_extras-0.2.2.dist-info → pymc_extras-0.2.3.dist-info}/LICENSE +0 -0
- {pymc_extras-0.2.2.dist-info → pymc_extras-0.2.3.dist-info}/WHEEL +0 -0
- {pymc_extras-0.2.2.dist-info → pymc_extras-0.2.3.dist-info}/top_level.txt +0 -0
pymc_extras/__init__.py
CHANGED
```diff
@@ -15,7 +15,9 @@ import logging
 
 from pymc_extras import gp, statespace, utils
 from pymc_extras.distributions import *
+from pymc_extras.inference.find_map import find_MAP
 from pymc_extras.inference.fit import fit
+from pymc_extras.inference.laplace import fit_laplace
 from pymc_extras.model.marginal.marginal_model import (
     MarginalModel,
     marginalize,
```
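With these two re-exports, the MAP and Laplace entry points become importable from the package root. A minimal usage sketch (the toy model and argument values here are invented for illustration; see each function's docstring for the full signatures):

```python
import numpy as np
import pymc as pm
import pymc_extras as pmx

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("y", mu=mu, sigma=1, observed=np.random.normal(size=10))

    # Both names now resolve at the package root, not just in their submodules.
    map_point = pmx.find_MAP(method="L-BFGS-B", progressbar=False)
    idata = pmx.fit_laplace()  # Laplace approximation around the MAP
```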
pymc_extras/inference/find_map.py
CHANGED

```diff
@@ -1,9 +1,9 @@
 import logging
 
 from collections.abc import Callable
+from importlib.util import find_spec
 from typing import Literal, cast, get_args
 
-import jax
 import numpy as np
 import pymc as pm
 import pytensor
```
```diff
@@ -30,13 +30,29 @@ VALID_BACKENDS = get_args(GradientBackend)
 def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
     method_info = MINIMIZE_MODE_KWARGS[method].copy()
 
-    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-    use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
-    use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]
-
     if use_hess and use_hessp:
+        _log.warning(
+            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
+            'same time. When possible "use_hessp" is preferred because its is computationally more efficient. '
+            'Setting "use_hess" to False.'
+        )
         use_hess = False
 
+    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+
+    if use_hessp is not None and use_hess is None:
+        use_hess = not use_hessp
+
+    elif use_hess is not None and use_hessp is None:
+        use_hessp = not use_hess
+
+    elif use_hessp is None and use_hess is None:
+        use_hessp = method_info["uses_hessp"]
+        use_hess = method_info["uses_hess"]
+        if use_hessp and use_hess:
+            # If a method could use either hess or hessp, we default to using hessp
+            use_hess = False
+
     return use_grad, use_hess, use_hessp
 
 
```
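The rewrite changes the precedence: an explicit `use_hess`/`use_hessp` choice from the user now wins over the method table, and choosing one implies the opposite default for the other. A standalone sketch of the resolution rules (the two-entry method table below is made up for the example; the real defaults come from `MINIMIZE_MODE_KWARGS`):

```python
# Hypothetical stand-in for MINIMIZE_MODE_KWARGS: which inputs each
# scipy.optimize.minimize method can consume.
METHOD_TABLE = {
    "trust-ncg": {"uses_grad": True, "uses_hess": True, "uses_hessp": True},
    "BFGS": {"uses_grad": True, "uses_hess": False, "uses_hessp": False},
}


def resolve_defaults(method, use_grad=None, use_hess=None, use_hessp=None):
    info = METHOD_TABLE[method]
    if use_hess and use_hessp:
        use_hess = False  # scipy never uses both; hessp is cheaper
    use_grad = use_grad if use_grad is not None else info["uses_grad"]
    if use_hessp is not None and use_hess is None:
        use_hess = not use_hessp  # user picked one; the other follows from it
    elif use_hess is not None and use_hessp is None:
        use_hessp = not use_hess
    elif use_hessp is None and use_hess is None:
        use_hessp = info["uses_hessp"]  # nothing chosen: fall back to the table
        use_hess = info["uses_hess"]
        if use_hessp and use_hess:
            use_hess = False  # prefer the cheaper Hessian-vector product
    return use_grad, use_hess, use_hessp


print(resolve_defaults("trust-ncg"))                 # (True, False, True)
print(resolve_defaults("trust-ncg", use_hess=True))  # (True, True, False)
```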
```diff
@@ -59,7 +75,7 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
     The nearest positive semi-definite matrix to the input matrix.
     """
     C = (A + A.T) / 2
-    eigval, eigvec = np.linalg.
+    eigval, eigvec = np.linalg.eigh(C)
     eigval[eigval < 0] = 0
 
     return eigvec @ np.diag(eigval) @ eigvec.T
```
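The new line settles on `np.linalg.eigh`, the symmetric eigendecomposition. The projection is simple enough to verify in isolation; a self-contained replica of the function above:

```python
import numpy as np


def nearest_psd(A: np.ndarray) -> np.ndarray:
    # Symmetrize, clip negative eigenvalues to zero, reassemble.
    C = (A + A.T) / 2
    eigval, eigvec = np.linalg.eigh(C)
    eigval[eigval < 0] = 0
    return eigvec @ np.diag(eigval) @ eigvec.T


A = np.array([[1.0, 2.0], [2.0, 1.0]])  # indefinite: eigenvalues are 3 and -1
B = nearest_psd(A)
assert np.all(np.linalg.eigvalsh(B) >= -1e-12)  # now PSD up to round-off
```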
```diff
@@ -97,7 +113,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
     return f_untransform(posterior_draws)
 
 
-def _compile_jax_gradients(
+def _compile_grad_and_hess_to_jax(
     f_loss: Function, use_hess: bool, use_hessp: bool
 ) -> tuple[Callable | None, Callable | None]:
     """
@@ -122,6 +138,8 @@ def _compile_jax_gradients(
     f_hessp: Callable | None
         The compiled hessian-vector product function, or None if use_hessp is False.
     """
+    import jax
+
     f_hess = None
     f_hessp = None
 
```
```diff
@@ -152,7 +170,7 @@ def _compile_jax_gradients(
     return f_loss_and_grad, f_hess, f_hessp
 
 
-def _compile_functions(
+def _compile_functions_for_scipy_optimize(
     loss: TensorVariable,
     inputs: list[TensorVariable],
     compute_grad: bool,
@@ -177,7 +195,7 @@ def _compile_functions(
     compute_hessp: bool
         Whether to compile a function that computes the Hessian-vector product of the loss function.
     compile_kwargs: dict, optional
-        Additional keyword arguments to pass to the ``pm.
+        Additional keyword arguments to pass to the ``pm.compile`` function.
 
     Returns
     -------
@@ -193,19 +211,19 @@ def _compile_functions(
     if compute_grad:
         grads = pytensor.gradient.grad(loss, inputs)
         grad = pt.concatenate([grad.ravel() for grad in grads])
-        f_loss_and_grad = pm.
+        f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
     else:
-        f_loss = pm.
+        f_loss = pm.compile(inputs, loss, **compile_kwargs)
         return [f_loss]
 
     if compute_hess:
         hess = pytensor.gradient.jacobian(grad, inputs)[0]
-        f_hess = pm.
+        f_hess = pm.compile(inputs, hess, **compile_kwargs)
 
     if compute_hessp:
         p = pt.tensor("p", shape=inputs[0].type.shape)
         hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
-        f_hessp = pm.
+        f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)
 
     return [f_loss_and_grad, f_hess, f_hessp]
```
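Every truncated compile call in this hunk resolves to `pm.compile`, the name PyMC exposes for its function-compilation helper (the bumped `pymc>=5.20` requirement in METADATA below presumably guarantees its availability). A toy illustration of the loss-and-gradient pattern, with an invented graph:

```python
import pymc as pm
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
loss = (x**2).sum()
grad = pytensor.grad(loss, x)

# Mirrors f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
f_loss_and_grad = pm.compile([x], [loss, grad])
value, gradient = f_loss_and_grad([1.0, 2.0])  # 5.0, [2.0, 4.0]
```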
```diff
@@ -240,7 +258,7 @@ def scipy_optimize_funcs_from_loss(
     gradient_backend: str, default "pytensor"
         Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
     compile_kwargs:
-        Additional keyword arguments to pass to the ``pm.
+        Additional keyword arguments to pass to the ``pm.compile`` function.
 
     Returns
     -------
@@ -265,6 +283,8 @@ def scipy_optimize_funcs_from_loss(
     )
 
     use_jax_gradients = (gradient_backend == "jax") and use_grad
+    if use_jax_gradients and not find_spec("jax"):
+        raise ImportError("JAX must be installed to use JAX gradients")
 
     mode = compile_kwargs.get("mode", None)
     if mode is None and use_jax_gradients:
```
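Together with the removed module-level `import jax` and the function-local import added above, this guard makes jax a genuinely optional dependency: `find_spec` only looks the module up without importing it, and the import itself is deferred to the JAX code path. The pattern in isolation (`make_identity_jit` is a made-up example function):

```python
from importlib.util import find_spec


def make_identity_jit(use_jax: bool):
    if use_jax and not find_spec("jax"):
        # Cheap availability check; nothing is imported yet.
        raise ImportError("JAX must be installed to use JAX gradients")
    if use_jax:
        import jax  # deferred: only paid for when the JAX path is taken

        return jax.jit(lambda x: x)
    return lambda x: x
```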
```diff
@@ -285,7 +305,7 @@ def scipy_optimize_funcs_from_loss(
     compute_hess = use_hess and not use_jax_gradients
     compute_hessp = use_hessp and not use_jax_gradients
 
-    funcs = _compile_functions(
+    funcs = _compile_functions_for_scipy_optimize(
         loss=loss,
         inputs=[flat_input],
         compute_grad=compute_grad,
@@ -301,7 +321,7 @@ def scipy_optimize_funcs_from_loss(
 
     if use_jax_gradients:
         # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-        f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)
+        f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)
 
     return f_loss, f_hess, f_hessp
```
pymc_extras/inference/laplace.py
CHANGED
```diff
@@ -16,6 +16,7 @@
 import logging
 
 from functools import reduce
+from importlib.util import find_spec
 from itertools import product
 from typing import Literal
 
@@ -231,7 +232,7 @@ def add_data_to_inferencedata(
     return idata
 
 
-def fit_mvn_to_MAP(
+def fit_mvn_at_MAP(
     optimized_point: dict[str, np.ndarray],
     model: pm.Model | None = None,
     on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
@@ -276,6 +277,9 @@ def fit_mvn_to_MAP(
     inverse_hessian: np.ndarray
         The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     """
+    if gradient_backend == "jax" and not find_spec("jax"):
+        raise ImportError("JAX must be installed to use JAX gradients")
+
     model = pm.modelcontext(model)
     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
     frozen_model = freeze_dims_and_data(model)
@@ -344,8 +348,10 @@ def sample_laplace_posterior(
 
     Parameters
     ----------
-    mu
-
+    mu: RaveledVars
+        The MAP estimate of the model parameters.
+    H_inv: np.ndarray
+        The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     model : Model
         A PyMC model
     chains : int
@@ -384,9 +390,7 @@ def sample_laplace_posterior(
             constrained_rvs, replace={unconstrained_vector: batched_values}
         )
 
-        f_constrain = pm.
-            inputs=[batched_values], outputs=batched_rvs, **compile_kwargs
-        )
+        f_constrain = pm.compile(inputs=[batched_values], outputs=batched_rvs, **compile_kwargs)
         posterior_draws = f_constrain(posterior_draws)
 
     else:
@@ -472,15 +476,17 @@ def fit_laplace(
         and 1).
 
         .. warning::
-            This
+            This argument should be considered highly experimental. It has not been verified if this method produces
             valid draws from the posterior. **Use at your own risk**.
 
     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
     chains: int, default: 2
-        The number of
+        The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
+        because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
+        compatible with the ArviZ library.
     draws: int, default: 500
-        The number of samples to draw from the approximated posterior.
+        The number of samples to draw from the approximated posterior. Totals samples will be chains * draws.
     on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
         What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
         If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
```
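Because the chains dimension is purely an output-layout convention for ArviZ rather than independent MCMC chains, the returned posterior has shape `(chains, draws, ...)`. A hedged sketch (toy model; all other arguments left at their defaults):

```python
import numpy as np
import pymc as pm

from pymc_extras.inference.laplace import fit_laplace

with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("y", mu=mu, sigma=1, observed=np.zeros(25))
    idata = fit_laplace(chains=2, draws=500)

# Total samples = chains * draws, laid out the way ArviZ expects.
assert idata.posterior["mu"].shape == (2, 500)
```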
```diff
@@ -547,11 +553,12 @@ def fit_laplace(
         **optimizer_kwargs,
     )
 
-    mu, H_inv = fit_mvn_to_MAP(
+    mu, H_inv = fit_mvn_at_MAP(
         optimized_point=optimized_point,
         model=model,
         on_bad_cov=on_bad_cov,
         transform_samples=fit_in_unconstrained_space,
+        gradient_backend=gradient_backend,
         zero_tol=zero_tol,
         diag_jitter=diag_jitter,
         compile_kwargs=compile_kwargs,
```
pymc_extras/model/marginal/marginal_model.py
CHANGED

```diff
@@ -19,7 +19,8 @@ from pymc.model.fgraph import (
     model_free_rv,
     model_from_fgraph,
 )
-from pymc.pytensorf import collect_default_updates,
+from pymc.pytensorf import collect_default_updates, constant_fold, toposort_replace
+from pymc.pytensorf import compile as compile_pymc
 from pymc.util import RandomState, _get_seeds_per_chain
 from pytensor import In, Out
 from pytensor.compile import SharedVariable
```
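The second import line aliases PyMC's renamed compile helper back to the `compile_pymc` name, presumably so the rest of the module's call sites stay untouched. The alias behaves like any other compiled-function factory:

```python
import pytensor.tensor as pt

from pymc.pytensorf import compile as compile_pymc

x = pt.scalar("x")
f = compile_pymc([x], x + 1)
assert f(1.0) == 2.0
```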
pymc_extras/statespace/core/compile.py
CHANGED

```diff
@@ -30,7 +30,7 @@ def compile_statespace(
 
     inputs = list(pytensor.graph.basic.explicit_graph_inputs(outputs))
 
-    _f = pm.
+    _f = pm.compile(inputs, outputs, on_unused_input="ignore", **compile_kwargs)
 
     def f(*, draws=1, **params):
         if isinstance(steps, pt.Variable):
```
pymc_extras/version.txt
CHANGED
```diff
@@ -1 +1 @@
-0.2.2
+0.2.3
```
{pymc_extras-0.2.2.dist-info → pymc_extras-0.2.3.dist-info}/METADATA
CHANGED

```diff
@@ -1,11 +1,11 @@
 Metadata-Version: 2.2
 Name: pymc-extras
-Version: 0.2.2
+Version: 0.2.3
 Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
 Home-page: http://github.com/pymc-devs/pymc-extras
 Maintainer: PyMC Developers
 Maintainer-email: pymc.devs@gmail.com
-License: Apache
+License: Apache-2.0
 Classifier: Development Status :: 5 - Production/Stable
 Classifier: Programming Language :: Python
 Classifier: Programming Language :: Python :: 3
@@ -20,8 +20,9 @@ Classifier: Operating System :: OS Independent
 Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: pymc>=5.
+Requires-Dist: pymc>=5.20
 Requires-Dist: scikit-learn
+Requires-Dist: better-optimize
 Provides-Extra: dask-histogram
 Requires-Dist: dask[complete]; extra == "dask-histogram"
 Requires-Dist: xhistogram; extra == "dask-histogram"
```
{pymc_extras-0.2.2.dist-info → pymc_extras-0.2.3.dist-info}/RECORD
CHANGED

```diff
@@ -1,9 +1,9 @@
-pymc_extras/__init__.py,sha256=
+pymc_extras/__init__.py,sha256=IFIEZdPX_Ugq57Bu7jlyrJLpKng-P0FBAAAzl2pFXLE,1266
 pymc_extras/linearmodel.py,sha256=6eitl15Ec15mSZu7zoHZ7Wwy4U1DPwqfAgwEt6ILeIc,3920
 pymc_extras/model_builder.py,sha256=sAw77fxdiy046BvDPjocuMlbJ0Efj-CDAGtmcwYmoG0,26361
 pymc_extras/printing.py,sha256=G8mj9dRd6i0PcsbcEWZm56ek6V8mmil78RI4MUhywBs,6506
 pymc_extras/version.py,sha256=VxPGCBzhtSegu-Jp5cjzn0n4DGU0wuPUh-KyZKB6uPM,240
-pymc_extras/version.txt,sha256=
+pymc_extras/version.txt,sha256=OrlMBNJJhvOvKIuhzaLAu928Wonf8JcYKAX1RXjh6nU,6
 pymc_extras/distributions/__init__.py,sha256=gTX7tvX8NcgP7V72URV7GeqF1aAEjGVbuW8LMxhXceY,1295
 pymc_extras/distributions/continuous.py,sha256=z-nvQgGncYISdRY8cWsa-56V0bQGq70jYwU-i8VZ0Uk,11253
 pymc_extras/distributions/discrete.py,sha256=vrARNuiQAEXrs7yQgImV1PO8AV1uyEC_LBhr6F9IcOg,13032
@@ -14,9 +14,9 @@ pymc_extras/distributions/multivariate/r2d2m2cp.py,sha256=bUj9bB-hQi6CpaJfvJjgNP
 pymc_extras/gp/__init__.py,sha256=sFHw2y3lEl5tG_FDQHZUonQ_k0DF1JRf0Rp8dpHmge0,745
 pymc_extras/gp/latent_approx.py,sha256=cDEMM6H1BL2qyKg7BZU-ISrKn2HJe7hDaM4Y8GgQDf4,6682
 pymc_extras/inference/__init__.py,sha256=5cXpaQQnW0mJJ3x8wSxmYu63l--Xab5D_gMtjA6Q3uU,666
-pymc_extras/inference/find_map.py,sha256=
+pymc_extras/inference/find_map.py,sha256=vl5l0ei48PnX-uTuHVTr-9QpCEHc8xog-KK6sOnJ8LU,16513
 pymc_extras/inference/fit.py,sha256=S9R48dh74s6K0MC9Iys4NAwVjP6rVRfx6SF-kPiR70E,1165
-pymc_extras/inference/laplace.py,sha256=
+pymc_extras/inference/laplace.py,sha256=uOZGp8ssQuhvCHV_Y_v3icsr4rhcYgr_qlr9dS7pcSM,21761
 pymc_extras/inference/pathfinder/__init__.py,sha256=FhAYrCWNx_dCrynEdjg2CZ9tIinvcVLBm67pNx_Y3kA,101
 pymc_extras/inference/pathfinder/importance_sampling.py,sha256=VvmuaE3aw_Mo3tMwswfF0rqe19mnhOCpzIScaJzjA1Y,6159
 pymc_extras/inference/pathfinder/lbfgs.py,sha256=P0UIOVtspdLzDU6alK-y91qzVAzXjYAXPuGmZ1nRqMo,5715
@@ -28,14 +28,14 @@ pymc_extras/model/model_api.py,sha256=UHMfQXxWBujeSiUySU0fDUC5Sd_BjT8FoVz3iBxQH_
 pymc_extras/model/marginal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pymc_extras/model/marginal/distributions.py,sha256=iM1yT7_BmivgUSloQPKE2QXGPgjvLqDMY_OTBGsdAWg,15563
 pymc_extras/model/marginal/graph_analysis.py,sha256=0hWUH_PjfpgneQ3NaT__pWHS1fh50zNbI86kH4Nub0E,15693
-pymc_extras/model/marginal/marginal_model.py,sha256=
+pymc_extras/model/marginal/marginal_model.py,sha256=oIdikaSnefCkyMxmzAe222qGXNucxZpHYk7548fK6iA,23631
 pymc_extras/model/transforms/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pymc_extras/model/transforms/autoreparam.py,sha256=_NltGWmNqi_X9sHCqAvWcBveLTPxVy11-wENFTcN6kk,12377
 pymc_extras/preprocessing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 pymc_extras/preprocessing/standard_scaler.py,sha256=Vajp33ma6OkwlU54JYtSS8urHbMJ3CRiRFxZpvFNuus,600
 pymc_extras/statespace/__init__.py,sha256=0MtZj7yT6jcyERvITnn-nkhyY8fO6Za4_vV53CF6ND0,429
 pymc_extras/statespace/core/__init__.py,sha256=huHEiXAm8zV2MZyZ8GBHp6q7_fnWqveM7lC6ilpb3iE,309
-pymc_extras/statespace/core/compile.py,sha256=
+pymc_extras/statespace/core/compile.py,sha256=9FZfE8Bi3VfElxujfOIKRVvmyL9M5R0WfNEqPc5kbVQ,1603
 pymc_extras/statespace/core/representation.py,sha256=DwNIun6wdeEA20oWBx5M4govyWTf5JI87aGQ_E6Mb4U,18956
 pymc_extras/statespace/core/statespace.py,sha256=K_WVnWKlI6sR2kgriq9sctQVvwXCeAirm14TthDpmRM,96860
 pymc_extras/statespace/filters/__init__.py,sha256=N9Q4D0gAq_ZtT-GtrqiX1HkSg6Orv7o1TbrWUtnbTJE,420
@@ -61,9 +61,9 @@ pymc_extras/utils/prior.py,sha256=QlWVr7uKIK9VncBw7Fz3YgaASKGDfqpORZHc-vz_9gQ,68
 pymc_extras/utils/spline.py,sha256=qGq0gcoMG5dpdazKFzG0RXkkCWP8ADPPXN-653-oFn4,4820
 tests/__init__.py,sha256=-ree9OWVCyTeXLR944OWjrQX2os15HXrRNkhJ7QdRjc,603
 tests/test_blackjax_smc.py,sha256=jcNgcMBxaKyPg9UvHnWQtwoL79LXlSpZfALe3RGEZnQ,7233
-tests/test_find_map.py,sha256=
+tests/test_find_map.py,sha256=B8ThnXNyfTQeem24QaLoTitFrsxKoq2VQINUdOwzna0,3379
 tests/test_histogram_approximation.py,sha256=w-xb2Rr0Qft6sm6F3BTmXXnpuqyefC1SUL6YxzqA5X4,4674
-tests/test_laplace.py,sha256=
+tests/test_laplace.py,sha256=u4o-0y4v1emaTMYr_rOyL_EKY_bQIz0DUXFuwuDbfNg,9314
 tests/test_linearmodel.py,sha256=iB8ApNqIX9_nUHoo-Tm51xuPdrva5t4VLLut6qXB5Ao,6906
 tests/test_model_builder.py,sha256=QiINEihBR9rx8xM4Nqlg4urZKoyo58aTKDtxl9SJF1s,11249
 tests/test_pathfinder.py,sha256=GnSbZJ9QuFW9UVbkWaVgMVqQZTCttOyz_rSflxhQ-EA,4955
@@ -98,8 +98,8 @@ tests/statespace/utilities/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJ
 tests/statespace/utilities/shared_fixtures.py,sha256=SNw8Bvj1Yw11TxAW6n20Bq0B8oaYtVTiFFEVNH_wnp4,164
 tests/statespace/utilities/statsmodel_local_level.py,sha256=SQAzaYaSDwiVhUQ1iWjt4MgfAd54RuzVtnslIs3xdS8,1225
 tests/statespace/utilities/test_helpers.py,sha256=oH24a6Q45NFFFI3Kx9mhKbxsCvo9ErCorKFoTjDB3-4,9159
-pymc_extras-0.2.
-pymc_extras-0.2.
-pymc_extras-0.2.
-pymc_extras-0.2.
-pymc_extras-0.2.
+pymc_extras-0.2.3.dist-info/LICENSE,sha256=WjiLhUKEysJvy5e9jk6WwFv9tmAPtnov1uJ6gcH1kIs,11720
+pymc_extras-0.2.3.dist-info/METADATA,sha256=ZTiMM7hvVRF3O_liRu4Aea_EuxJc4vHfTD2CbRRQrcU,5152
+pymc_extras-0.2.3.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
+pymc_extras-0.2.3.dist-info/top_level.txt,sha256=D6RkgBiXiZCel0nvsYg_zYEoT_VuwocyIY98EMaulj0,18
+pymc_extras-0.2.3.dist-info/RECORD,,
```
tests/test_find_map.py
CHANGED
```diff
@@ -54,24 +54,28 @@ def test_jax_functions_from_graph(gradient_backend: GradientBackend):
 
 
 @pytest.mark.parametrize(
-    "method, use_grad, use_hess",
+    "method, use_grad, use_hess, use_hessp",
     [
-        ("nelder-mead", False, False),
-        ("powell", False, False),
-        ("CG", True, False),
-        ("BFGS", True, False),
-        ("L-BFGS-B", True, False),
-        ("TNC", True, False),
-        ("SLSQP", True, False),
-        ("dogleg", True, True),
-        ("
-        ("
-        ("trust-
-        ("trust-
+        ("nelder-mead", False, False, False),
+        ("powell", False, False, False),
+        ("CG", True, False, False),
+        ("BFGS", True, False, False),
+        ("L-BFGS-B", True, False, False),
+        ("TNC", True, False, False),
+        ("SLSQP", True, False, False),
+        ("dogleg", True, True, False),
+        ("Newton-CG", True, True, False),
+        ("Newton-CG", True, False, True),
+        ("trust-ncg", True, True, False),
+        ("trust-ncg", True, False, True),
+        ("trust-exact", True, True, False),
+        ("trust-krylov", True, True, False),
+        ("trust-krylov", True, False, True),
+        ("trust-constr", True, True, False),
     ],
 )
 @pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
-def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend, rng):
+def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: GradientBackend, rng):
     extra_kwargs = {}
     if method == "dogleg":
         # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point
```
```diff
@@ -88,6 +92,7 @@ def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend,
         **extra_kwargs,
         use_grad=use_grad,
         use_hess=use_hess,
+        use_hessp=use_hessp,
         progressbar=False,
         gradient_backend=gradient_backend,
         compile_kwargs={"mode": "JAX"},
```
tests/test_laplace.py
CHANGED
```diff
@@ -19,10 +19,10 @@ import pytest
 
 import pymc_extras as pmx
 
-from pymc_extras.inference.find_map import find_MAP
+from pymc_extras.inference.find_map import GradientBackend, find_MAP
 from pymc_extras.inference.laplace import (
     fit_laplace,
-    fit_mvn_to_MAP,
+    fit_mvn_at_MAP,
     sample_laplace_posterior,
 )
 
```
```diff
@@ -37,7 +37,11 @@ def rng():
     "ignore:hessian will stop negating the output in a future version of PyMC.\n"
     + "To suppress this warning set `negate_output=False`:FutureWarning",
 )
-def test_laplace():
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_laplace(mode, gradient_backend: GradientBackend):
     # Example originates from Bayesian Data Analyses, 3rd Edition
     # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
     # Aki Vehtari, and Donald Rubin.
```
```diff
@@ -55,7 +59,13 @@ def test_laplace():
     vars = [mu, logsigma]
 
     idata = pmx.fit(
-        method="laplace",
+        method="laplace",
+        optimize_method="trust-ncg",
+        draws=draws,
+        random_seed=173300,
+        chains=1,
+        compile_kwargs={"mode": mode},
+        gradient_backend=gradient_backend,
     )
 
     assert idata.posterior["mu"].shape == (1, draws)
```
```diff
@@ -71,7 +81,11 @@ def test_laplace():
     np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)
 
 
-def test_laplace_only_fit():
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_laplace_only_fit(mode, gradient_backend: GradientBackend):
     # Example originates from Bayesian Data Analyses, 3rd Edition
     # By Andrew Gelman, John Carlin, Hal Stern, David Dunson,
     # Aki Vehtari, and Donald Rubin.
```
```diff
@@ -90,8 +104,8 @@ def test_laplace_only_fit():
         method="laplace",
         optimize_method="BFGS",
         progressbar=True,
-        gradient_backend=
-        compile_kwargs={"mode":
+        gradient_backend=gradient_backend,
+        compile_kwargs={"mode": mode},
         optimizer_kwargs=dict(maxiter=100_000, gtol=1e-100),
         random_seed=173300,
     )
```
```diff
@@ -111,8 +125,11 @@ def test_laplace_only_fit():
     [True, False],
     ids=["transformed", "untransformed"],
 )
-@pytest.mark.parametrize(
-def test_fit_laplace_coords(rng, transform_samples, mode):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace_coords(rng, transform_samples, mode, gradient_backend: GradientBackend):
     coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)}
     with pm.Model(coords=coords) as model:
         mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"])
```
```diff
@@ -131,13 +148,13 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
         use_hessp=True,
         progressbar=False,
         compile_kwargs=dict(mode=mode),
-        gradient_backend=
+        gradient_backend=gradient_backend,
     )
 
     for value in optimized_point.values():
         assert value.shape == (3,)
 
-    mu, H_inv = fit_mvn_to_MAP(
+    mu, H_inv = fit_mvn_at_MAP(
         optimized_point=optimized_point,
         model=model,
         transform_samples=transform_samples,
```
```diff
@@ -163,7 +180,11 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
     ]
 
 
-def test_fit_laplace_ragged_coords(rng):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace_ragged_coords(mode, gradient_backend: GradientBackend, rng):
     coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
     with pm.Model(coords=coords) as ragged_dim_model:
         X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"])
```
```diff
@@ -188,8 +209,8 @@ def test_fit_laplace_ragged_coords(rng):
         progressbar=False,
         use_grad=True,
         use_hessp=True,
-        gradient_backend=
-        compile_kwargs={"mode":
+        gradient_backend=gradient_backend,
+        compile_kwargs={"mode": mode},
     )
 
     assert idata["posterior"].beta.shape[-2:] == (3, 2)
```
```diff
@@ -206,7 +227,11 @@ def test_fit_laplace_ragged_coords(rng):
     [True, False],
     ids=["transformed", "untransformed"],
 )
-def test_fit_laplace(fit_in_unconstrained_space):
+@pytest.mark.parametrize(
+    "mode, gradient_backend",
+    [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")],
+)
+def test_fit_laplace(fit_in_unconstrained_space, mode, gradient_backend: GradientBackend):
     with pm.Model() as simp_model:
         mu = pm.Normal("mu", mu=3, sigma=0.5)
         sigma = pm.Exponential("sigma", 1)
```
```diff
@@ -223,6 +248,8 @@ def test_fit_laplace(fit_in_unconstrained_space):
         use_hessp=True,
         fit_in_unconstrained_space=fit_in_unconstrained_space,
         optimizer_kwargs=dict(maxiter=100_000, tol=1e-100),
+        compile_kwargs={"mode": mode},
+        gradient_backend=gradient_backend,
     )
 
     np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2,), 3), atol=0.1)
```
{pymc_extras-0.2.2.dist-info → pymc_extras-0.2.3.dist-info}/LICENSE
File without changes

{pymc_extras-0.2.2.dist-info → pymc_extras-0.2.3.dist-info}/WHEEL
File without changes

{pymc_extras-0.2.2.dist-info → pymc_extras-0.2.3.dist-info}/top_level.txt
File without changes