pymc-extras 0.2.1__tar.gz → 0.2.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/PKG-INFO +13 -2
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/inference/fit.py +0 -4
- pymc_extras-0.2.2/pymc_extras/inference/pathfinder/__init__.py +3 -0
- pymc_extras-0.2.2/pymc_extras/inference/pathfinder/importance_sampling.py +139 -0
- pymc_extras-0.2.2/pymc_extras/inference/pathfinder/lbfgs.py +190 -0
- pymc_extras-0.2.2/pymc_extras/inference/pathfinder/pathfinder.py +1746 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/model/model_api.py +18 -2
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/core/statespace.py +79 -36
- pymc_extras-0.2.2/pymc_extras/version.txt +1 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras.egg-info/PKG-INFO +13 -2
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras.egg-info/SOURCES.txt +4 -1
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/model/test_model_api.py +9 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/statespace/test_statespace.py +54 -0
- pymc_extras-0.2.2/tests/test_pathfinder.py +173 -0
- pymc_extras-0.2.1/pymc_extras/inference/pathfinder.py +0 -134
- pymc_extras-0.2.1/pymc_extras/version.txt +0 -1
- pymc_extras-0.2.1/tests/test_pathfinder.py +0 -45
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/CODE_OF_CONDUCT.md +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/CONTRIBUTING.md +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/LICENSE +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/MANIFEST.in +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/README.md +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/distributions/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/distributions/continuous.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/distributions/discrete.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/distributions/histogram_utils.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/distributions/multivariate/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/distributions/multivariate/r2d2m2cp.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/distributions/timeseries.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/gp/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/gp/latent_approx.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/inference/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/inference/find_map.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/inference/laplace.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/inference/smc/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/inference/smc/sampling.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/linearmodel.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/model/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/model/marginal/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/model/marginal/distributions.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/model/marginal/graph_analysis.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/model/marginal/marginal_model.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/model/transforms/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/model/transforms/autoreparam.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/model_builder.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/preprocessing/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/preprocessing/standard_scaler.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/printing.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/core/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/core/compile.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/core/representation.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/filters/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/filters/distributions.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/filters/kalman_filter.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/filters/kalman_smoother.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/filters/utilities.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/models/ETS.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/models/SARIMAX.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/models/VARMAX.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/models/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/models/structural.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/models/utilities.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/utils/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/utils/constants.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/utils/coord_tools.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/statespace/utils/data_tools.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/utils/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/utils/linear_cg.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/utils/model_equivalence.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/utils/pivoted_cholesky.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/utils/prior.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/utils/spline.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/version.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras.egg-info/dependency_links.txt +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras.egg-info/requires.txt +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras.egg-info/top_level.txt +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/pyproject.toml +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/requirements-dev.txt +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/requirements-docs.txt +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/requirements.txt +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/setup.cfg +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/setup.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/distributions/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/distributions/test_continuous.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/distributions/test_discrete.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/distributions/test_discrete_markov_chain.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/distributions/test_multivariate.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/model/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/model/marginal/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/model/marginal/test_distributions.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/model/marginal/test_graph_analysis.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/model/marginal/test_marginal_model.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/statespace/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/statespace/test_ETS.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/statespace/test_SARIMAX.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/statespace/test_VARMAX.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/statespace/test_coord_assignment.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/statespace/test_distributions.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/statespace/test_kalman_filter.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/statespace/test_representation.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/statespace/test_statespace_JAX.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/statespace/test_structural.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/statespace/utilities/__init__.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/statespace/utilities/shared_fixtures.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/statespace/utilities/statsmodel_local_level.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/statespace/utilities/test_helpers.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/test_blackjax_smc.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/test_find_map.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/test_histogram_approximation.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/test_laplace.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/test_linearmodel.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/test_model_builder.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/test_pivoted_cholesky.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/test_printing.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/test_prior_from_trace.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/test_splines.py +0 -0
- {pymc_extras-0.2.1 → pymc_extras-0.2.2}/tests/utils.py +0 -0
{pymc_extras-0.2.1 → pymc_extras-0.2.2}/PKG-INFO:

```diff
@@ -1,6 +1,6 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.2
 Name: pymc-extras
-Version: 0.2.1
+Version: 0.2.2
 Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
 Home-page: http://github.com/pymc-devs/pymc-extras
 Maintainer: PyMC Developers
@@ -34,6 +34,17 @@ Provides-Extra: dev
 Requires-Dist: dask[all]; extra == "dev"
 Requires-Dist: blackjax; extra == "dev"
 Requires-Dist: statsmodels; extra == "dev"
+Dynamic: classifier
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: home-page
+Dynamic: license
+Dynamic: maintainer
+Dynamic: maintainer-email
+Dynamic: provides-extra
+Dynamic: requires-dist
+Dynamic: requires-python
+Dynamic: summary
 
 # Welcome to `pymc-extras`
 <a href="https://gitpod.io/#https://github.com/pymc-devs/pymc-extras">
```
{pymc_extras-0.2.1 → pymc_extras-0.2.2}/pymc_extras/inference/fit.py:

```diff
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from importlib.util import find_spec
 
 
 def fit(method, **kwargs):
@@ -31,9 +30,6 @@ def fit(method, **kwargs):
     arviz.InferenceData
     """
     if method == "pathfinder":
-        if find_spec("blackjax") is None:
-            raise RuntimeError("Need BlackJAX to use `pathfinder`")
-
         from pymc_extras.inference.pathfinder import fit_pathfinder
 
         return fit_pathfinder(**kwargs)
```
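The dropped BlackJAX guard reflects that 0.2.2 ships a native Pathfinder implementation (the new `pymc_extras/inference/pathfinder/` package listed above). A minimal usage sketch follows; the toy model and keyword values are illustrative and not taken from the diff:

```python
import pymc as pm

from pymc_extras.inference.fit import fit

# Toy model for illustration only.
with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu=mu, sigma=1.0, observed=[0.1, -0.3, 0.2])

# All keyword arguments are forwarded to fit_pathfinder; `model` and
# `random_seed` are assumed here to be among its accepted parameters.
idata = fit(method="pathfinder", model=model, random_seed=42)
```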
pymc_extras-0.2.2/pymc_extras/inference/pathfinder/importance_sampling.py (new file, +139 lines):

```python
import logging
import warnings as _warnings

from dataclasses import dataclass, field
from typing import Literal

import arviz as az
import numpy as np

from numpy.typing import NDArray
from scipy.special import logsumexp

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class ImportanceSamplingResult:
    """container for importance sampling results"""

    samples: NDArray
    pareto_k: float | None = None
    warnings: list[str] = field(default_factory=list)
    method: str = "none"


def importance_sampling(
    samples: NDArray,
    logP: NDArray,
    logQ: NDArray,
    num_draws: int,
    method: Literal["psis", "psir", "identity", "none"] | None,
    random_seed: int | None = None,
) -> ImportanceSamplingResult:
    """Pareto Smoothed Importance Resampling (PSIR)

    This implements the Pareto Smoothed Importance Resampling (PSIR) method, as described in
    Algorithm 5 of Zhang et al. (2022). PSIR follows a similar approach to the Algorithm 1
    PSIS diagnostic from Yao et al. (2018). However, before computing the importance ratio
    r_s, logP and logQ are adjusted to account for the number of estimators (or paths). The
    process involves resampling from the original sample with replacement, with probabilities
    proportional to the importance weights computed by PSIS.

    Parameters
    ----------
    samples : NDArray
        samples from the proposal distribution, shape (L, M, N)
    logP : NDArray
        log probability values of the target distribution, shape (L, M)
    logQ : NDArray
        log probability values of the proposal distribution, shape (L, M)
    num_draws : int
        number of draws to return, where num_draws <= samples.shape[0]
    method : str, optional
        importance sampling method to use. Options are "psis" (default), "psir", "identity",
        and "none". Pareto Smoothed Importance Sampling ("psis") is recommended in many cases
        for more stable results than Pareto Smoothed Importance Resampling ("psir").
        "identity" applies the log importance weights directly without resampling. "none"
        applies no importance sampling weights and returns the samples as is, of size
        num_draws_per_path * num_paths.
    random_seed : int | None

    Returns
    -------
    ImportanceSamplingResult
        importance sampled draws and other info based on the specified method

    Future work
    -----------
    - Implement the 3 sampling approaches and 5 weighting functions from Elvira et al. (2019)
    - Implement Algorithm 2 VSBC marginal diagnostics from Yao et al. (2018)
    - Incorporate these various diagnostics, sampling approaches and weighting functions into
      VI algorithms.

    References
    ----------
    Elvira, V., Martino, L., Luengo, D., & Bugallo, M. F. (2019). Generalized Multiple
    Importance Sampling. Statistical Science, 34(1), 129-155.
    https://doi.org/10.1214/18-STS668

    Yao, Y., Vehtari, A., Simpson, D., & Gelman, A. (2018). Yes, but Did It Work?: Evaluating
    Variational Inference. arXiv:1802.02538 [stat]. http://arxiv.org/abs/1802.02538

    Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel
    quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
    """

    warnings = []
    num_paths, _, N = samples.shape

    if method == "none":
        warnings.append(
            "Importance sampling is disabled. The samples are returned as is, which may "
            "include samples from failed paths with non-finite logP or logQ values. It is "
            "recommended to use importance_sampling='psis' for better stability."
        )
        return ImportanceSamplingResult(samples=samples, warnings=warnings)
    else:
        samples = samples.reshape(-1, N)
        logP = logP.ravel()
        logQ = logQ.ravel()

        # adjust log densities for the number of paths
        log_I = np.log(num_paths)
        logP -= log_I
        logQ -= log_I
        logiw = logP - logQ

        with _warnings.catch_warnings():
            _warnings.filterwarnings(
                "ignore", category=RuntimeWarning, message="overflow encountered in exp"
            )
            if method == "psis":
                replace = False
                logiw, pareto_k = az.psislw(logiw)
            elif method == "psir":
                replace = True
                logiw, pareto_k = az.psislw(logiw)
            elif method == "identity":
                replace = False
                pareto_k = None
            else:
                raise ValueError(f"Invalid importance sampling method: {method}")

    # NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the
    # NUTS posterior or closer to NUTS than ADVI. Pareto k may not be a good diagnostic for
    # Pathfinder.
    # TODO: Find replacement diagnostics for Pathfinder.

    p = np.exp(logiw - logsumexp(logiw))
    rng = np.random.default_rng(random_seed)

    try:
        resampled = rng.choice(samples, size=num_draws, replace=replace, p=p, shuffle=False, axis=0)
        return ImportanceSamplingResult(
            samples=resampled, pareto_k=pareto_k, warnings=warnings, method=method
        )
    except ValueError as e1:
        if "Fewer non-zero entries in p than size" in str(e1):
            num_nonzero = np.count_nonzero(p)  # number of samples with non-zero weight
            warnings.append(
                f"Not enough valid samples: {num_nonzero} available out of {num_draws} "
                "requested. Switching to psir importance sampling."
            )
            try:
                resampled = rng.choice(
                    samples, size=num_draws, replace=True, p=p, shuffle=False, axis=0
                )
                return ImportanceSamplingResult(
                    samples=resampled, pareto_k=pareto_k, warnings=warnings, method=method
                )
            except ValueError as e2:
                logger.error(
                    "Importance sampling failed even with psir importance sampling. "
                    "This might indicate invalid probability weights or insufficient valid "
                    "samples."
                )
                raise ValueError(
                    "Importance sampling failed for both with and without replacement"
                ) from e2
        raise
```
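To make the expected shapes concrete, here is a small synthetic sketch of calling the new function; the arrays are made-up stand-ins for per-path Pathfinder draws and their target/proposal log-densities, not output from a real model:

```python
import numpy as np

from pymc_extras.inference.pathfinder.importance_sampling import importance_sampling

rng = np.random.default_rng(0)
L, M, N = 4, 250, 3  # paths, draws per path, parameters

samples = rng.normal(size=(L, M, N))
logP = -0.5 * (samples**2).sum(-1)           # stand-in target log-density, shape (L, M)
logQ = logP + 0.1 * rng.normal(size=(L, M))  # stand-in proposal log-density, shape (L, M)

result = importance_sampling(samples, logP, logQ, num_draws=500, method="psis", random_seed=0)
print(result.samples.shape)  # (500, 3): draws pooled across paths, then resampled
print(result.pareto_k)       # PSIS shape diagnostic from az.psislw
```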
pymc_extras-0.2.2/pymc_extras/inference/pathfinder/lbfgs.py (new file, +190 lines):

```python
import logging

from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum, auto

import numpy as np

from numpy.typing import NDArray
from scipy.optimize import minimize

logger = logging.getLogger(__name__)


@dataclass(slots=True)
class LBFGSHistory:
    """History of LBFGS iterations."""

    x: NDArray[np.float64]
    g: NDArray[np.float64]
    count: int

    def __post_init__(self):
        self.x = np.ascontiguousarray(self.x, dtype=np.float64)
        self.g = np.ascontiguousarray(self.g, dtype=np.float64)


@dataclass(slots=True)
class LBFGSHistoryManager:
    """manages and stores the history of lbfgs optimisation iterations.

    Parameters
    ----------
    value_grad_fn : Callable
        function that returns a tuple of (value, gradient) given input x
    x0 : NDArray
        initial position
    maxiter : int
        maximum number of iterations to store
    """

    value_grad_fn: Callable[[NDArray[np.float64]], tuple[np.float64, NDArray[np.float64]]]
    x0: NDArray[np.float64]
    maxiter: int
    x_history: NDArray[np.float64] = field(init=False)
    g_history: NDArray[np.float64] = field(init=False)
    count: int = field(init=False)

    def __post_init__(self) -> None:
        self.x_history = np.empty((self.maxiter + 1, self.x0.shape[0]), dtype=np.float64)
        self.g_history = np.empty((self.maxiter + 1, self.x0.shape[0]), dtype=np.float64)
        self.count = 0

        value, grad = self.value_grad_fn(self.x0)
        if np.all(np.isfinite(grad)) and np.isfinite(value):
            self.add_entry(self.x0, grad)

    def add_entry(self, x: NDArray[np.float64], g: NDArray[np.float64]) -> None:
        """adds a new position and gradient to the history.

        Parameters
        ----------
        x : NDArray
            position vector
        g : NDArray
            gradient vector
        """
        self.x_history[self.count] = x
        self.g_history[self.count] = g
        self.count += 1

    def get_history(self) -> LBFGSHistory:
        """returns the history of optimisation iterations."""
        return LBFGSHistory(
            x=self.x_history[: self.count], g=self.g_history[: self.count], count=self.count
        )

    def __call__(self, x: NDArray[np.float64]) -> None:
        value, grad = self.value_grad_fn(x)
        if np.all(np.isfinite(grad)) and np.isfinite(value) and self.count < self.maxiter + 1:
            self.add_entry(x, grad)


class LBFGSStatus(Enum):
    CONVERGED = auto()
    MAX_ITER_REACHED = auto()
    DIVERGED = auto()
    # Statuses that lead to Exceptions:
    INIT_FAILED = auto()
    LBFGS_FAILED = auto()


class LBFGSException(Exception):
    DEFAULT_MESSAGE = "LBFGS failed."

    def __init__(self, message=None, status: LBFGSStatus = LBFGSStatus.LBFGS_FAILED):
        super().__init__(message or self.DEFAULT_MESSAGE)
        self.status = status


class LBFGSInitFailed(LBFGSException):
    DEFAULT_MESSAGE = "LBFGS failed to initialise."

    def __init__(self, message=None):
        super().__init__(message or self.DEFAULT_MESSAGE, LBFGSStatus.INIT_FAILED)


class LBFGS:
    """L-BFGS optimizer wrapper around scipy's implementation.

    Parameters
    ----------
    value_grad_fn : Callable
        function that returns a tuple of (value, gradient) given input x
    maxcor : int
        maximum number of variable metric corrections
    maxiter : int, optional
        maximum number of iterations, defaults to 1000
    ftol : float, optional
        function tolerance for convergence, defaults to 1e-5
    gtol : float, optional
        gradient tolerance for convergence, defaults to 1e-8
    maxls : int, optional
        maximum number of line search steps, defaults to 1000
    """

    def __init__(
        self, value_grad_fn, maxcor, maxiter=1000, ftol=1e-5, gtol=1e-8, maxls=1000
    ) -> None:
        self.value_grad_fn = value_grad_fn
        self.maxcor = maxcor
        self.maxiter = maxiter
        self.ftol = ftol
        self.gtol = gtol
        self.maxls = maxls

    def minimize(self, x0) -> tuple[NDArray, NDArray, int, LBFGSStatus]:
        """minimizes the objective function starting from an initial position.

        Parameters
        ----------
        x0 : array_like
            initial position

        Returns
        -------
        x : NDArray
            history of positions
        g : NDArray
            history of gradients
        count : int
            number of iterations
        status : LBFGSStatus
            final status of the optimisation
        """

        x0 = np.array(x0, dtype=np.float64)

        history_manager = LBFGSHistoryManager(
            value_grad_fn=self.value_grad_fn, x0=x0, maxiter=self.maxiter
        )

        result = minimize(
            self.value_grad_fn,
            x0,
            method="L-BFGS-B",
            jac=True,
            callback=history_manager,
            options={
                "maxcor": self.maxcor,
                "maxiter": self.maxiter,
                "ftol": self.ftol,
                "gtol": self.gtol,
                "maxls": self.maxls,
            },
        )
        history = history_manager.get_history()

        # warnings and suggestions for LBFGSStatus are displayed at the end
        if result.status == 1:
            lbfgs_status = LBFGSStatus.MAX_ITER_REACHED
        elif (result.status == 2) or (history.count <= 1):
            if result.nit <= 1:
                lbfgs_status = LBFGSStatus.INIT_FAILED
            elif result.fun == np.inf:
                lbfgs_status = LBFGSStatus.DIVERGED
            else:
                # neither an init failure nor a divergence: generic LBFGS failure,
                # so that lbfgs_status is always bound
                lbfgs_status = LBFGSStatus.LBFGS_FAILED
        else:
            lbfgs_status = LBFGSStatus.CONVERGED

        return history.x, history.g, history.count, lbfgs_status
```