pymc-extras 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras/__init__.py +2 -0
- pymc_extras/inference/find_map.py +36 -16
- pymc_extras/inference/fit.py +0 -4
- pymc_extras/inference/laplace.py +17 -10
- pymc_extras/inference/pathfinder/__init__.py +3 -0
- pymc_extras/inference/pathfinder/importance_sampling.py +139 -0
- pymc_extras/inference/pathfinder/lbfgs.py +190 -0
- pymc_extras/inference/pathfinder/pathfinder.py +1746 -0
- pymc_extras/model/marginal/marginal_model.py +2 -1
- pymc_extras/model/model_api.py +18 -2
- pymc_extras/statespace/core/compile.py +1 -1
- pymc_extras/statespace/core/statespace.py +79 -36
- pymc_extras/version.txt +1 -1
- {pymc_extras-0.2.1.dist-info → pymc_extras-0.2.3.dist-info}/METADATA +16 -4
- {pymc_extras-0.2.1.dist-info → pymc_extras-0.2.3.dist-info}/RECORD +23 -20
- {pymc_extras-0.2.1.dist-info → pymc_extras-0.2.3.dist-info}/WHEEL +1 -1
- tests/model/test_model_api.py +9 -0
- tests/statespace/test_statespace.py +54 -0
- tests/test_find_map.py +19 -14
- tests/test_laplace.py +42 -15
- tests/test_pathfinder.py +135 -7
- pymc_extras/inference/pathfinder.py +0 -134
- {pymc_extras-0.2.1.dist-info → pymc_extras-0.2.3.dist-info}/LICENSE +0 -0
- {pymc_extras-0.2.1.dist-info → pymc_extras-0.2.3.dist-info}/top_level.txt +0 -0
pymc_extras/__init__.py
CHANGED
@@ -15,7 +15,9 @@ import logging

 from pymc_extras import gp, statespace, utils
 from pymc_extras.distributions import *
+from pymc_extras.inference.find_map import find_MAP
 from pymc_extras.inference.fit import fit
+from pymc_extras.inference.laplace import fit_laplace
 from pymc_extras.model.marginal.marginal_model import (
     MarginalModel,
     marginalize,
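With these re-exports, `find_MAP` and `fit_laplace` become importable directly from the package namespace. A minimal usage sketch (the toy model and keyword arguments are illustrative, not taken from the diff; exact signatures live in `find_map.py` and `laplace.py`):

    import pymc as pm
    import pymc_extras as pmx

    with pm.Model():
        mu = pm.Normal("mu", 0, 1)
        pm.Normal("y", mu, 1, observed=[0.1, -0.3, 0.2])

        map_point = pmx.find_MAP(method="L-BFGS-B")  # newly re-exported
        idata = pmx.fit_laplace()                    # newly re-exported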
pymc_extras/inference/find_map.py
CHANGED

@@ -1,9 +1,9 @@
 import logging

 from collections.abc import Callable
+from importlib.util import find_spec
 from typing import Literal, cast, get_args

-import jax
 import numpy as np
 import pymc as pm
 import pytensor
@@ -30,13 +30,29 @@ VALID_BACKENDS = get_args(GradientBackend)
 def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
     method_info = MINIMIZE_MODE_KWARGS[method].copy()

-    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-    use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
-    use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]
-
     if use_hess and use_hessp:
+        _log.warning(
+            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
+            'same time. When possible "use_hessp" is preferred because its is computationally more efficient. '
+            'Setting "use_hess" to False.'
+        )
         use_hess = False

+    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+
+    if use_hessp is not None and use_hess is None:
+        use_hess = not use_hessp
+
+    elif use_hess is not None and use_hessp is None:
+        use_hessp = not use_hess
+
+    elif use_hessp is None and use_hess is None:
+        use_hessp = method_info["uses_hessp"]
+        use_hess = method_info["uses_hess"]
+        if use_hessp and use_hess:
+            # If a method could use either hess or hessp, we default to using hessp
+            use_hess = False
+
     return use_grad, use_hess, use_hessp

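Read as a decision rule, the new defaults are: an explicit choice for one of `use_hess`/`use_hessp` implies the opposite for the other, and when neither is given the method's capabilities decide, preferring the hessian-vector product. A standalone sketch mirroring the hunk above; the `method_info` dict here stands in for an entry of the module's `MINIMIZE_MODE_KWARGS`:

    def resolve_hessian_defaults(use_hess, use_hessp, method_info):
        # mirrors the new precedence in set_optimizer_function_defaults
        if use_hessp is not None and use_hess is None:
            use_hess = not use_hessp
        elif use_hess is not None and use_hessp is None:
            use_hessp = not use_hess
        elif use_hessp is None and use_hess is None:
            use_hessp = method_info["uses_hessp"]
            use_hess = method_info["uses_hess"]
            if use_hessp and use_hess:
                use_hess = False  # prefer the cheaper hessian-vector product
        return use_hess, use_hessp

    # a method that could use either -> hessp wins
    print(resolve_hessian_defaults(None, None, {"uses_hess": True, "uses_hessp": True}))
    # (False, True)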
@@ -59,7 +75,7 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
         The nearest positive semi-definite matrix to the input matrix.
     """
     C = (A + A.T) / 2
-    eigval, eigvec = np.linalg.eig(C)
+    eigval, eigvec = np.linalg.eigh(C)
     eigval[eigval < 0] = 0

     return eigvec @ np.diag(eigval) @ eigvec.T
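The switch to `np.linalg.eigh` matches the symmetrization on the line above: `eigh` assumes a symmetric input and returns real, sorted eigenvalues, so the clipping step never sees complex output. A quick standalone check of the same recipe:

    import numpy as np

    def nearest_psd(A):
        C = (A + A.T) / 2
        eigval, eigvec = np.linalg.eigh(C)  # symmetric input -> real eigenpairs
        eigval[eigval < 0] = 0
        return eigvec @ np.diag(eigval) @ eigvec.T

    B = nearest_psd(np.array([[1.0, 2.0], [0.0, -1.0]]))
    print(np.linalg.eigvalsh(B))  # approximately [0, 1.414]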
@@ -97,7 +113,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
     return f_untransform(posterior_draws)


-def _compile_jax_gradients(
+def _compile_grad_and_hess_to_jax(
     f_loss: Function, use_hess: bool, use_hessp: bool
 ) -> tuple[Callable | None, Callable | None]:
     """
@@ -122,6 +138,8 @@ def _compile_jax_gradients(
     f_hessp: Callable | None
         The compiled hessian-vector product function, or None if use_hessp is False.
     """
+    import jax
+
     f_hess = None
     f_hessp = None

@@ -152,7 +170,7 @@
     return f_loss_and_grad, f_hess, f_hessp


-def _compile_functions(
+def _compile_functions_for_scipy_optimize(
     loss: TensorVariable,
     inputs: list[TensorVariable],
     compute_grad: bool,
@@ -177,7 +195,7 @@ def _compile_functions(
     compute_hessp: bool
         Whether to compile a function that computes the Hessian-vector product of the loss function.
     compile_kwargs: dict, optional
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.

     Returns
     -------
@@ -193,19 +211,19 @@ def _compile_functions(
     if compute_grad:
         grads = pytensor.gradient.grad(loss, inputs)
         grad = pt.concatenate([grad.ravel() for grad in grads])
-        f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs)
+        f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
     else:
-        f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs)
+        f_loss = pm.compile(inputs, loss, **compile_kwargs)
         return [f_loss]

     if compute_hess:
         hess = pytensor.gradient.jacobian(grad, inputs)[0]
-        f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs)
+        f_hess = pm.compile(inputs, hess, **compile_kwargs)

     if compute_hessp:
         p = pt.tensor("p", shape=inputs[0].type.shape)
         hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
-        f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs)
+        f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)

     return [f_loss_and_grad, f_hess, f_hessp]

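These call-site changes track the upstream PyMC rename of `pm.compile_pymc` to `pm.compile`. A small standalone sketch of the same pattern, assuming a PyMC version that already exposes `pm.compile` (older releases used `pm.compile_pymc`):

    import pymc as pm
    import pytensor.tensor as pt

    x = pt.dscalar("x")
    loss = x**2
    grad = pt.grad(loss, x)

    f_loss_and_grad = pm.compile([x], [loss, grad])
    print(f_loss_and_grad(3.0))  # [array(9.), array(6.)]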
@@ -240,7 +258,7 @@ def scipy_optimize_funcs_from_loss(
     gradient_backend: str, default "pytensor"
         Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
     compile_kwargs:
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.

     Returns
     -------
@@ -265,6 +283,8 @@
     )

     use_jax_gradients = (gradient_backend == "jax") and use_grad
+    if use_jax_gradients and not find_spec("jax"):
+        raise ImportError("JAX must be installed to use JAX gradients")

     mode = compile_kwargs.get("mode", None)
     if mode is None and use_jax_gradients:
@@ -285,7 +305,7 @@
     compute_hess = use_hess and not use_jax_gradients
     compute_hessp = use_hessp and not use_jax_gradients

-    funcs = _compile_functions(
+    funcs = _compile_functions_for_scipy_optimize(
         loss=loss,
         inputs=[flat_input],
         compute_grad=compute_grad,
@@ -301,7 +321,7 @@

     if use_jax_gradients:
         # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-        f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)
+        f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)

     return f_loss, f_hess, f_hessp

pymc_extras/inference/fit.py
CHANGED
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from importlib.util import find_spec


 def fit(method, **kwargs):

@@ -31,9 +30,6 @@ def fit(method, **kwargs):
     arviz.InferenceData
     """
     if method == "pathfinder":
-        if find_spec("blackjax") is None:
-            raise RuntimeError("Need BlackJAX to use `pathfinder`")
-
         from pymc_extras.inference.pathfinder import fit_pathfinder

         return fit_pathfinder(**kwargs)
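With the BlackJAX requirement dropped, `fit` now dispatches straight to the new PyTensor-based Pathfinder implementation. A hedged usage sketch (the toy model is illustrative; the accepted keyword arguments are whatever the new `pathfinder.py` defines):

    import pymc as pm
    import pymc_extras as pmx

    with pm.Model():
        x = pm.Normal("x", 0, 1)
        idata = pmx.fit(method="pathfinder", random_seed=1)  # no BlackJAX needed in 0.2.3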
pymc_extras/inference/laplace.py
CHANGED
@@ -16,6 +16,7 @@
 import logging

 from functools import reduce
+from importlib.util import find_spec
 from itertools import product
 from typing import Literal

@@ -231,7 +232,7 @@ def add_data_to_inferencedata(
     return idata


-def fit_mvn_to_MAP(
+def fit_mvn_at_MAP(
     optimized_point: dict[str, np.ndarray],
     model: pm.Model | None = None,
     on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
@@ -276,6 +277,9 @@ def fit_mvn_to_MAP(
     inverse_hessian: np.ndarray
         The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     """
+    if gradient_backend == "jax" and not find_spec("jax"):
+        raise ImportError("JAX must be installed to use JAX gradients")
+
     model = pm.modelcontext(model)
     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
     frozen_model = freeze_dims_and_data(model)
@@ -344,8 +348,10 @@ def sample_laplace_posterior(

     Parameters
     ----------
-    mu
-
+    mu: RaveledVars
+        The MAP estimate of the model parameters.
+    H_inv: np.ndarray
+        The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     model : Model
         A PyMC model
     chains : int
@@ -384,9 +390,7 @@ def sample_laplace_posterior(
             constrained_rvs, replace={unconstrained_vector: batched_values}
         )

-        f_constrain = pm.compile_pymc(
-            inputs=[batched_values], outputs=batched_rvs, **compile_kwargs
-        )
+        f_constrain = pm.compile(inputs=[batched_values], outputs=batched_rvs, **compile_kwargs)
         posterior_draws = f_constrain(posterior_draws)

     else:
@@ -472,15 +476,17 @@ def fit_laplace(
         and 1).

         .. warning::
-            This
+            This argument should be considered highly experimental. It has not been verified if this method produces
             valid draws from the posterior. **Use at your own risk**.

     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
     chains: int, default: 2
-        The number of
+        The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
+        because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
+        compatible with the ArviZ library.
     draws: int, default: 500
-        The number of samples to draw from the approximated posterior.
+        The number of samples to draw from the approximated posterior. Totals samples will be chains * draws.
     on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
         What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
         If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
@@ -547,11 +553,12 @@ def fit_laplace(
         **optimizer_kwargs,
     )

-    mu, H_inv = fit_mvn_to_MAP(
+    mu, H_inv = fit_mvn_at_MAP(
         optimized_point=optimized_point,
         model=model,
         on_bad_cov=on_bad_cov,
         transform_samples=fit_in_unconstrained_space,
+        gradient_backend=gradient_backend,
         zero_tol=zero_tol,
         diag_jitter=diag_jitter,
         compile_kwargs=compile_kwargs,
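The last hunk forwards `gradient_backend` into the renamed `fit_mvn_at_MAP`, and the earlier hunk makes `gradient_backend="jax"` fail with a clear ImportError when JAX is missing. A hedged usage sketch (toy model; argument names as documented above):

    import pymc as pm
    from pymc_extras.inference.laplace import fit_laplace

    with pm.Model():
        mu = pm.Normal("mu", 0, 1)
        pm.Normal("y", mu, 1, observed=[0.2, -0.1, 0.4])

        idata = fit_laplace(chains=2, draws=500, gradient_backend="pytensor")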
pymc_extras/inference/pathfinder/importance_sampling.py
ADDED

@@ -0,0 +1,139 @@
+import logging
+import warnings as _warnings
+
+from dataclasses import dataclass, field
+from typing import Literal
+
+import arviz as az
+import numpy as np
+
+from numpy.typing import NDArray
+from scipy.special import logsumexp
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class ImportanceSamplingResult:
+    """container for importance sampling results"""
+
+    samples: NDArray
+    pareto_k: float | None = None
+    warnings: list[str] = field(default_factory=list)
+    method: str = "none"
+
+
+def importance_sampling(
+    samples: NDArray,
+    logP: NDArray,
+    logQ: NDArray,
+    num_draws: int,
+    method: Literal["psis", "psir", "identity", "none"] | None,
+    random_seed: int | None = None,
+) -> ImportanceSamplingResult:
+    """Pareto Smoothed Importance Resampling (PSIR)
+    This implements the Pareto Smooth Importance Resampling (PSIR) method, as described in Algorithm 5 of Zhang et al. (2022). The PSIR follows a similar approach to Algorithm 1 PSIS diagnostic from Yao et al., (2018). However, before computing the the importance ratio r_s, the logP and logQ are adjusted to account for the number multiple estimators (or paths). The process involves resampling from the original sample with replacement, with probabilities proportional to the computed importance weights from PSIS.
+
+    Parameters
+    ----------
+    samples : NDArray
+        samples from proposal distribution, shape (L, M, N)
+    logP : NDArray
+        log probability values of target distribution, shape (L, M)
+    logQ : NDArray
+        log probability values of proposal distribution, shape (L, M)
+    num_draws : int
+        number of draws to return where num_draws <= samples.shape[0]
+    method : str, optional
+        importance sampling method to use. Options are "psis" (default), "psir", "identity", "none. Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths.
+    random_seed : int | None
+
+    Returns
+    -------
+    ImportanceSamplingResult
+        importance sampled draws and other info based on the specified method
+
+    Future work!
+    ----------
+    - Implement the 3 sampling approaches and 5 weighting functions from Elvira et al. (2019)
+    - Implement Algorithm 2 VSBC marginal diagnostics from Yao et al. (2018)
+    - Incorporate these various diagnostics, sampling approaches and weighting functions into VI algorithms.
+
+    References
+    ----------
+    Elvira, V., Martino, L., Luengo, D., & Bugallo, M. F. (2019). Generalized Multiple Importance Sampling. Statistical Science, 34(1), 129-155. https://doi.org/10.1214/18-STS668
+
+    Yao, Y., Vehtari, A., Simpson, D., & Gelman, A. (2018). Yes, but Did It Work?: Evaluating Variational Inference. arXiv:1802.02538 [Stat]. http://arxiv.org/abs/1802.02538
+
+    Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
+    """
+
+    warnings = []
+    num_paths, _, N = samples.shape
+
+    if method == "none":
+        warnings.append(
+            "Importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
+        )
+        return ImportanceSamplingResult(samples=samples, warnings=warnings)
+    else:
+        samples = samples.reshape(-1, N)
+        logP = logP.ravel()
+        logQ = logQ.ravel()
+
+        # adjust log densities
+        log_I = np.log(num_paths)
+        logP -= log_I
+        logQ -= log_I
+        logiw = logP - logQ
+
+        with _warnings.catch_warnings():
+            _warnings.filterwarnings(
+                "ignore", category=RuntimeWarning, message="overflow encountered in exp"
+            )
+            if method == "psis":
+                replace = False
+                logiw, pareto_k = az.psislw(logiw)
+            elif method == "psir":
+                replace = True
+                logiw, pareto_k = az.psislw(logiw)
+            elif method == "identity":
+                replace = False
+                pareto_k = None
+            else:
+                raise ValueError(f"Invalid importance sampling method: {method}")
+
+    # NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
+    # Pareto k may not be a good diagnostic for Pathfinder.
+    # TODO: Find replacement diagnostics for Pathfinder.
+
+    p = np.exp(logiw - logsumexp(logiw))
+    rng = np.random.default_rng(random_seed)
+
+    try:
+        resampled = rng.choice(samples, size=num_draws, replace=replace, p=p, shuffle=False, axis=0)
+        return ImportanceSamplingResult(
+            samples=resampled, pareto_k=pareto_k, warnings=warnings, method=method
+        )
+    except ValueError as e1:
+        if "Fewer non-zero entries in p than size" in str(e1):
+            num_nonzero = np.where(np.nonzero(p)[0], 1, 0).sum()
+            warnings.append(
+                f"Not enough valid samples: {num_nonzero} available out of {num_draws} requested. Switching to psir importance sampling."
+            )
+            try:
+                resampled = rng.choice(
+                    samples, size=num_draws, replace=True, p=p, shuffle=False, axis=0
+                )
+                return ImportanceSamplingResult(
+                    samples=resampled, pareto_k=pareto_k, warnings=warnings, method=method
+                )
+            except ValueError as e2:
+                logger.error(
+                    "Importance sampling failed even with psir importance sampling. "
+                    "This might indicate invalid probability weights or insufficient valid samples."
+                )
+                raise ValueError(
+                    "Importance sampling failed for both with and without replacement"
+                ) from e2
+        raise
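A toy invocation of the new module, using synthetic Gaussian draws so the shapes (L paths, M draws per path, N parameters) line up with the docstring above; the log-densities here are placeholders, not a real model:

    import numpy as np
    from pymc_extras.inference.pathfinder.importance_sampling import importance_sampling

    L, M, N = 4, 100, 3
    rng = np.random.default_rng(0)
    samples = rng.normal(size=(L, M, N))
    logP = -0.5 * (samples**2).sum(-1)                 # stand-in target log-density, shape (L, M)
    logQ = logP + rng.normal(scale=0.1, size=(L, M))   # stand-in proposal log-density

    result = importance_sampling(samples, logP, logQ, num_draws=200, method="psis", random_seed=0)
    print(result.samples.shape, result.pareto_k)       # (200, 3) and the PSIS shape diagnostic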
pymc_extras/inference/pathfinder/lbfgs.py
ADDED

@@ -0,0 +1,190 @@
+import logging
+
+from collections.abc import Callable
+from dataclasses import dataclass, field
+from enum import Enum, auto
+
+import numpy as np
+
+from numpy.typing import NDArray
+from scipy.optimize import minimize
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(slots=True)
+class LBFGSHistory:
+    """History of LBFGS iterations."""
+
+    x: NDArray[np.float64]
+    g: NDArray[np.float64]
+    count: int
+
+    def __post_init__(self):
+        self.x = np.ascontiguousarray(self.x, dtype=np.float64)
+        self.g = np.ascontiguousarray(self.g, dtype=np.float64)
+
+
+@dataclass(slots=True)
+class LBFGSHistoryManager:
+    """manages and stores the history of lbfgs optimisation iterations.
+
+    Parameters
+    ----------
+    value_grad_fn : Callable
+        function that returns tuple of (value, gradient) given input x
+    x0 : NDArray
+        initial position
+    maxiter : int
+        maximum number of iterations to store
+    """
+
+    value_grad_fn: Callable[[NDArray[np.float64]], tuple[np.float64, NDArray[np.float64]]]
+    x0: NDArray[np.float64]
+    maxiter: int
+    x_history: NDArray[np.float64] = field(init=False)
+    g_history: NDArray[np.float64] = field(init=False)
+    count: int = field(init=False)
+
+    def __post_init__(self) -> None:
+        self.x_history = np.empty((self.maxiter + 1, self.x0.shape[0]), dtype=np.float64)
+        self.g_history = np.empty((self.maxiter + 1, self.x0.shape[0]), dtype=np.float64)
+        self.count = 0
+
+        value, grad = self.value_grad_fn(self.x0)
+        if np.all(np.isfinite(grad)) and np.isfinite(value):
+            self.add_entry(self.x0, grad)
+
+    def add_entry(self, x: NDArray[np.float64], g: NDArray[np.float64]) -> None:
+        """adds new position and gradient to history.
+
+        Parameters
+        ----------
+        x : NDArray
+            position vector
+        g : NDArray
+            gradient vector
+        """
+        self.x_history[self.count] = x
+        self.g_history[self.count] = g
+        self.count += 1
+
+    def get_history(self) -> LBFGSHistory:
+        """returns history of optimisation iterations."""
+        return LBFGSHistory(
+            x=self.x_history[: self.count], g=self.g_history[: self.count], count=self.count
+        )
+
+    def __call__(self, x: NDArray[np.float64]) -> None:
+        value, grad = self.value_grad_fn(x)
+        if np.all(np.isfinite(grad)) and np.isfinite(value) and self.count < self.maxiter + 1:
+            self.add_entry(x, grad)
+
+
+class LBFGSStatus(Enum):
+    CONVERGED = auto()
+    MAX_ITER_REACHED = auto()
+    DIVERGED = auto()
+    # Statuses that lead to Exceptions:
+    INIT_FAILED = auto()
+    LBFGS_FAILED = auto()
+
+
+class LBFGSException(Exception):
+    DEFAULT_MESSAGE = "LBFGS failed."
+
+    def __init__(self, message=None, status: LBFGSStatus = LBFGSStatus.LBFGS_FAILED):
+        super().__init__(message or self.DEFAULT_MESSAGE)
+        self.status = status
+
+
+class LBFGSInitFailed(LBFGSException):
+    DEFAULT_MESSAGE = "LBFGS failed to initialise."
+
+    def __init__(self, message=None):
+        super().__init__(message or self.DEFAULT_MESSAGE, LBFGSStatus.INIT_FAILED)
+
+
+class LBFGS:
+    """L-BFGS optimizer wrapper around scipy's implementation.
+
+    Parameters
+    ----------
+    value_grad_fn : Callable
+        function that returns tuple of (value, gradient) given input x
+    maxcor : int
+        maximum number of variable metric corrections
+    maxiter : int, optional
+        maximum number of iterations, defaults to 1000
+    ftol : float, optional
+        function tolerance for convergence, defaults to 1e-5
+    gtol : float, optional
+        gradient tolerance for convergence, defaults to 1e-8
+    maxls : int, optional
+        maximum number of line search steps, defaults to 1000
+    """
+
+    def __init__(
+        self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000
+    ) -> None:
+        self.value_grad_fn = value_grad_fn
+        self.maxcor = maxcor
+        self.maxiter = maxiter
+        self.ftol = ftol
+        self.gtol = gtol
+        self.maxls = maxls
+
+    def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
+        """minimizes objective function starting from initial position.
+
+        Parameters
+        ----------
+        x0 : array_like
+            initial position
+
+        Returns
+        -------
+        x : NDArray
+            history of positions
+        g : NDArray
+            history of gradients
+        count : int
+            number of iterations
+        status : LBFGSStatus
+            final status of optimisation
+        """
+
+        x0 = np.array(x0, dtype=np.float64)
+
+        history_manager = LBFGSHistoryManager(
+            value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter
+        )
+
+        result = minimize(
+            self.value_grad_fn,
+            x0,
+            method="L-BFGS-B",
+            jac=True,
+            callback=history_manager,
+            options={
+                "maxcor": self.maxcor,
+                "maxiter": self.maxiter,
+                "ftol": self.ftol,
+                "gtol": self.gtol,
+                "maxls": self.maxls,
+            },
+        )
+        history = history_manager.get_history()
+
+        # warnings and suggestions for LBFGSStatus are displayed at the end
+        if result.status == 1:
+            lbfgs_status = LBFGSStatus.MAX_ITER_REACHED
+        elif (result.status == 2) or (history.count <= 1):
+            if result.nit <= 1:
+                lbfgs_status = LBFGSStatus.INIT_FAILED
+            elif result.fun == np.inf:
+                lbfgs_status = LBFGSStatus.DIVERGED
+        else:
+            lbfgs_status = LBFGSStatus.CONVERGED
+
+        return history.x, history.g, history.count, lbfgs_status
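The wrapper can be exercised on its own with a simple quadratic objective; `value_grad_fn` must return the value and gradient together because `minimize` is called with `jac=True`. A toy example, not from the package:

    import numpy as np
    from pymc_extras.inference.pathfinder.lbfgs import LBFGS, LBFGSStatus

    def value_grad(x):
        return 0.5 * np.sum(x**2), x  # value and gradient of a quadratic bowl

    opt = LBFGS(value_grad, maxcor=5, maxiter=100)
    x_hist, g_hist, count, status = opt.minimize(np.array([3.0, -2.0]))
    print(status is LBFGSStatus.CONVERGED, x_hist.shape, g_hist.shape)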