pymc-extras 0.2.6__tar.gz → 0.3.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/.readthedocs.yaml +1 -1
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/PKG-INFO +6 -4
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/_version.py +2 -2
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/conda-envs/environment-test.yml +2 -1
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/docs/api_reference.rst +26 -0
- pymc_extras-0.3.1/pymc_extras/deserialize.py +224 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/inference/__init__.py +2 -2
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/inference/fit.py +1 -1
- pymc_extras-0.3.1/pymc_extras/inference/laplace_approx/find_map.py +347 -0
- pymc_extras-0.3.1/pymc_extras/inference/laplace_approx/idata.py +392 -0
- pymc_extras-0.3.1/pymc_extras/inference/laplace_approx/laplace.py +451 -0
- pymc_extras-0.3.1/pymc_extras/inference/laplace_approx/scipy_interface.py +242 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/inference/pathfinder/pathfinder.py +2 -2
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/linearmodel.py +3 -1
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/model/marginal/graph_analysis.py +4 -0
- pymc_extras-0.3.1/pymc_extras/prior.py +1388 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/core/statespace.py +78 -52
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/filters/kalman_smoother.py +1 -1
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pyproject.toml +5 -3
- pymc_extras-0.3.1/tests/inference/laplace_approx/test_find_map.py +322 -0
- pymc_extras-0.3.1/tests/inference/laplace_approx/test_idata.py +297 -0
- pymc_extras-0.3.1/tests/inference/laplace_approx/test_laplace.py +329 -0
- pymc_extras-0.3.1/tests/inference/laplace_approx/test_scipy_interface.py +118 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/model/marginal/test_graph_analysis.py +8 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/statespace/core/test_statespace.py +96 -5
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/statespace/core/test_statespace_JAX.py +5 -2
- pymc_extras-0.3.1/tests/statespace/models/__init__.py +0 -0
- pymc_extras-0.3.1/tests/statespace/utils/__init__.py +0 -0
- pymc_extras-0.3.1/tests/test_deserialize.py +59 -0
- pymc_extras-0.3.1/tests/test_prior.py +1205 -0
- pymc_extras-0.3.1/tests/utils.py +0 -0
- pymc_extras-0.2.6/pymc_extras/inference/find_map.py +0 -496
- pymc_extras-0.2.6/pymc_extras/inference/laplace.py +0 -583
- pymc_extras-0.2.6/pymc_extras/utils/pivoted_cholesky.py +0 -69
- pymc_extras-0.2.6/tests/test_find_map.py +0 -158
- pymc_extras-0.2.6/tests/test_laplace.py +0 -281
- pymc_extras-0.2.6/tests/test_pivoted_cholesky.py +0 -24
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/.gitignore +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/.gitpod.yml +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/.pre-commit-config.yaml +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/CODE_OF_CONDUCT.md +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/CONTRIBUTING.md +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/LICENSE +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/README.md +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/codecov.yml +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/docs/.nojekyll +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/docs/Makefile +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/docs/_templates/autosummary/base.rst +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/docs/_templates/autosummary/class.rst +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/docs/conf.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/docs/index.rst +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/docs/make.bat +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/docs/statespace/core.rst +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/docs/statespace/filters.rst +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/docs/statespace/models/structural.rst +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/docs/statespace/models.rst +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/distributions/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/distributions/continuous.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/distributions/discrete.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/distributions/histogram_utils.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/distributions/multivariate/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/distributions/multivariate/r2d2m2cp.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/distributions/timeseries.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/distributions/transforms/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/distributions/transforms/partial_order.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/gp/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/gp/latent_approx.py +0 -0
- {pymc_extras-0.2.6/pymc_extras/model → pymc_extras-0.3.1/pymc_extras/inference/laplace_approx}/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/inference/pathfinder/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/inference/pathfinder/importance_sampling.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/inference/pathfinder/lbfgs.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/inference/smc/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/inference/smc/sampling.py +0 -0
- {pymc_extras-0.2.6/pymc_extras/model/marginal → pymc_extras-0.3.1/pymc_extras/model}/__init__.py +0 -0
- {pymc_extras-0.2.6/pymc_extras/model/transforms → pymc_extras-0.3.1/pymc_extras/model/marginal}/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/model/marginal/distributions.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/model/marginal/marginal_model.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/model/model_api.py +0 -0
- {pymc_extras-0.2.6/pymc_extras/preprocessing → pymc_extras-0.3.1/pymc_extras/model/transforms}/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/model/transforms/autoreparam.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/model_builder.py +0 -0
- {pymc_extras-0.2.6/pymc_extras/statespace/utils → pymc_extras-0.3.1/pymc_extras/preprocessing}/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/preprocessing/standard_scaler.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/printing.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/core/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/core/compile.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/core/representation.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/filters/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/filters/distributions.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/filters/kalman_filter.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/filters/utilities.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/models/ETS.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/models/SARIMAX.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/models/VARMAX.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/models/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/models/structural.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/models/utilities.py +0 -0
- {pymc_extras-0.2.6/tests/model → pymc_extras-0.3.1/pymc_extras/statespace/utils}/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/utils/constants.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/utils/coord_tools.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/statespace/utils/data_tools.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/utils/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/utils/linear_cg.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/utils/model_equivalence.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/utils/prior.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/utils/spline.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/conftest.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/distributions/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/distributions/test_continuous.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/distributions/test_discrete.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/distributions/test_discrete_markov_chain.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/distributions/test_multivariate.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/distributions/test_transform.py +0 -0
- {pymc_extras-0.2.6/tests/model/marginal → pymc_extras-0.3.1/tests/inference}/__init__.py +0 -0
- {pymc_extras-0.2.6/tests/statespace → pymc_extras-0.3.1/tests/inference/laplace_approx}/__init__.py +0 -0
- {pymc_extras-0.2.6/tests/statespace/core → pymc_extras-0.3.1/tests/model}/__init__.py +0 -0
- {pymc_extras-0.2.6/tests/statespace/filters → pymc_extras-0.3.1/tests/model/marginal}/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/model/marginal/test_distributions.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/model/marginal/test_marginal_model.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/model/test_model_api.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/model/transforms/test_autoreparam.py +0 -0
- {pymc_extras-0.2.6/tests/statespace/models → pymc_extras-0.3.1/tests/statespace}/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/statespace/_data/airpass.csv +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/statespace/_data/airpassangers.csv +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/statespace/_data/nile.csv +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/statespace/_data/statsmodels_macrodata_processed.csv +0 -0
- {pymc_extras-0.2.6/tests/statespace/utils → pymc_extras-0.3.1/tests/statespace/core}/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/statespace/core/test_representation.py +0 -0
- /pymc_extras-0.2.6/tests/utils.py → /pymc_extras-0.3.1/tests/statespace/filters/__init__.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/statespace/filters/test_distributions.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/statespace/filters/test_kalman_filter.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/statespace/models/test_ETS.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/statespace/models/test_SARIMAX.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/statespace/models/test_VARMAX.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/statespace/models/test_structural.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/statespace/shared_fixtures.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/statespace/statsmodel_local_level.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/statespace/test_utilities.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/statespace/utils/test_coord_assignment.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/test_blackjax_smc.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/test_histogram_approximation.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/test_linearmodel.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/test_model_builder.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/test_pathfinder.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/test_printing.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/test_prior_from_trace.py +0 -0
- {pymc_extras-0.2.6 → pymc_extras-0.3.1}/tests/test_splines.py +0 -0
{pymc_extras-0.2.6 → pymc_extras-0.3.1}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pymc-extras
-Version: 0.2.6
+Version: 0.3.1
 Summary: A home for new additions to PyMC, which may include unusual probability distributions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
 Project-URL: Documentation, https://pymc-extras.readthedocs.io/
 Project-URL: Repository, https://github.com/pymc-devs/pymc-extras.git
@@ -226,14 +226,14 @@ Classifier: License :: OSI Approved :: Apache Software License
 Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python
 Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
 Classifier: Topic :: Scientific/Engineering
 Classifier: Topic :: Scientific/Engineering :: Mathematics
-Requires-Python: >=3.10
-Requires-Dist: better-optimize>=0.1.
+Requires-Python: >=3.11
+Requires-Dist: better-optimize>=0.1.4
+Requires-Dist: pydantic>=2.0.0
 Requires-Dist: pymc>=5.21.1
 Requires-Dist: scikit-learn
 Provides-Extra: complete
@@ -245,6 +245,8 @@ Requires-Dist: xhistogram; extra == 'dask-histogram'
 Provides-Extra: dev
 Requires-Dist: blackjax; extra == 'dev'
 Requires-Dist: dask[all]<2025.1.1; extra == 'dev'
+Requires-Dist: preliz>=0.5.0; extra == 'dev'
+Requires-Dist: pytest-mock; extra == 'dev'
 Requires-Dist: pytest>=6.0; extra == 'dev'
 Requires-Dist: statsmodels; extra == 'dev'
 Provides-Extra: docs
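
The practical upshot of the metadata changes: 0.3.1 drops Python 3.10 support, raises the better-optimize floor to 0.1.4, and adds pydantic as a hard runtime dependency. A minimal, illustrative pre-upgrade check (not part of the package itself):

    import sys

    # pymc-extras 0.3.1 sets Requires-Python: >=3.11, so 3.10 environments
    # must move to a newer interpreter before upgrading.
    assert sys.version_info >= (3, 11), "pymc-extras 0.3.1 requires Python >= 3.11"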
{pymc_extras-0.2.6 → pymc_extras-0.3.1}/docs/api_reference.rst

@@ -46,6 +46,32 @@ Distributions
     Skellam
     histogram_approximation
 
+Prior
+=====
+
+.. currentmodule:: pymc_extras.prior
+.. autosummary::
+    :toctree: generated/
+
+    create_dim_handler
+    handle_dims
+    Prior
+    VariableFactory
+    sample_prior
+    Censored
+    Scaled
+
+Deserialize
+===========
+
+.. currentmodule:: pymc_extras.deserialize
+.. autosummary::
+    :toctree: generated/
+
+    deserialize
+    register_deserialization
+    Deserializer
+
 
 Transforms
 ==========
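
These two documentation sections correspond to the `pymc_extras.prior` and `pymc_extras.deserialize` modules added in this release. A small sketch of how they connect, based only on the dict format documented in the deserialize.py docstrings below (the exact `Prior` behavior is as described there, not verified here):

    from pymc_extras.deserialize import deserialize

    # Dict format for Prior objects, per pymc_extras/deserialize.py:
    prior = deserialize({"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}})
    # prior is equivalent to Prior("Normal", mu=0, sigma=1)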
pymc_extras-0.3.1/pymc_extras/deserialize.py (new file)

@@ -0,0 +1,224 @@
+"""Deserialize dictionaries into Python objects.
+
+This is a two-step process:
+
+1. Determine if the data is of the correct type.
+2. Deserialize the data into a Python object.
+
+Examples
+--------
+Make use of the already registered deserializers:
+
+.. code-block:: python
+
+    from pymc_extras.deserialize import deserialize
+
+    prior_class_data = {
+        "dist": "Normal",
+        "kwargs": {"mu": 0, "sigma": 1}
+    }
+    prior = deserialize(prior_class_data)
+    # Prior("Normal", mu=0, sigma=1)
+
+Register custom class deserialization:
+
+.. code-block:: python
+
+    from pymc_extras.deserialize import register_deserialization
+
+    class MyClass:
+        def __init__(self, value: int):
+            self.value = value
+
+        def to_dict(self) -> dict:
+            # Example of what the to_dict method might look like.
+            return {"value": self.value}
+
+    register_deserialization(
+        is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
+        deserialize=lambda data: MyClass(value=data["value"]),
+    )
+
+Deserialize data into that custom class:
+
+.. code-block:: python
+
+    from pymc_extras.deserialize import deserialize
+
+    data = {"value": 42}
+    obj = deserialize(data)
+    assert isinstance(obj, MyClass)
+
+
+"""
+
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import Any
+
+IsType = Callable[[Any], bool]
+Deserialize = Callable[[Any], Any]
+
+
+@dataclass
+class Deserializer:
+    """Object to store information required for deserialization.
+
+    All deserializers should be stored via the :func:`register_deserialization` function
+    instead of creating this object directly.
+
+    Attributes
+    ----------
+    is_type : IsType
+        Function to determine if the data is of the correct type.
+    deserialize : Deserialize
+        Function to deserialize the data.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        from typing import Any
+
+        class MyClass:
+            def __init__(self, value: int):
+                self.value = value
+
+        from pymc_extras.deserialize import Deserializer
+
+        def is_type(data: Any) -> bool:
+            return data.keys() == {"value"} and isinstance(data["value"], int)
+
+        def deserialize(data: dict) -> MyClass:
+            return MyClass(value=data["value"])
+
+        deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize)
+
+    """
+
+    is_type: IsType
+    deserialize: Deserialize
+
+
+DESERIALIZERS: list[Deserializer] = []
+
+
+class DeserializableError(Exception):
+    """Error raised when data cannot be deserialized."""
+
+    def __init__(self, data: Any):
+        self.data = data
+        super().__init__(
+            f"Couldn't deserialize {data}. Use register_deserialization to add a deserialization mapping."
+        )
+
+
+def deserialize(data: Any) -> Any:
+    """Deserialize a dictionary into a Python object.
+
+    Use the :func:`register_deserialization` function to add custom deserializations.
+
+    Deserialization is a two-step process due to the dynamic nature of the data:
+
+    1. Determine if the data is of the correct type.
+    2. Deserialize the data into a Python object.
+
+    Each registered deserialization is checked in order until one is found that can
+    deserialize the data. If no deserialization is found, a :class:`DeserializableError` is raised.
+
+    A :class:`DeserializableError` is raised when the data fails to be deserialized
+    by any of the registered deserializers.
+
+    Parameters
+    ----------
+    data : Any
+        The data to deserialize.
+
+    Returns
+    -------
+    Any
+        The deserialized object.
+
+    Raises
+    ------
+    DeserializableError
+        Raised when the data doesn't match any registered deserializations
+        or fails to be deserialized.
+
+    Examples
+    --------
+    Deserialize a :class:`pymc_extras.prior.Prior` object:
+
+    .. code-block:: python
+
+        from pymc_extras.deserialize import deserialize
+
+        data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
+        prior = deserialize(data)
+        # Prior("Normal", mu=0, sigma=1)
+
+    """
+    for mapping in DESERIALIZERS:
+        try:
+            is_type = mapping.is_type(data)
+        except Exception:
+            is_type = False
+
+        if not is_type:
+            continue
+
+        try:
+            return mapping.deserialize(data)
+        except Exception as e:
+            raise DeserializableError(data) from e
+    else:
+        raise DeserializableError(data)
+
+
+def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
+    """Register an arbitrary deserialization.
+
+    Use the :func:`deserialize` function to then deserialize data using all registered
+    deserialize functions.
+
+    Parameters
+    ----------
+    is_type : Callable[[Any], bool]
+        Function to determine if the data is of the correct type.
+    deserialize : Callable[[dict], Any]
+        Function to deserialize the data of that type.
+
+    Examples
+    --------
+    Register a custom class deserialization:
+
+    .. code-block:: python
+
+        from pymc_extras.deserialize import register_deserialization
+
+        class MyClass:
+            def __init__(self, value: int):
+                self.value = value
+
+            def to_dict(self) -> dict:
+                # Example of what the to_dict method might look like.
+                return {"value": self.value}
+
+        register_deserialization(
+            is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
+            deserialize=lambda data: MyClass(value=data["value"]),
+        )
+
+    Use that custom class deserialization:
+
+    .. code-block:: python
+
+        from pymc_extras.deserialize import deserialize
+
+        data = {"value": 42}
+        obj = deserialize(data)
+        assert isinstance(obj, MyClass)
+
+    """
+    mapping = Deserializer(is_type=is_type, deserialize=deserialize)
+    DESERIALIZERS.append(mapping)
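
Note the error handling in `deserialize` above: an exception raised inside an `is_type` check is swallowed and treated as a non-match, while an exception inside a matching `deserialize` callable is re-raised as `DeserializableError`. A small sketch exercising both paths (the "point" format here is made up for illustration):

    from pymc_extras.deserialize import (
        DeserializableError,
        deserialize,
        register_deserialization,
    )

    register_deserialization(
        is_type=lambda data: isinstance(data, dict) and data.get("kind") == "point",
        deserialize=lambda data: (data["x"], data["y"]),
    )

    print(deserialize({"kind": "point", "x": 1, "y": 2}))  # (1, 2)

    try:
        # Matches is_type, but the deserialize callable raises KeyError,
        # which is wrapped in DeserializableError.
        deserialize({"kind": "point"})
    except DeserializableError as err:
        print(err)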
{pymc_extras-0.2.6 → pymc_extras-0.3.1}/pymc_extras/inference/__init__.py

@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from pymc_extras.inference.find_map import find_MAP
 from pymc_extras.inference.fit import fit
-from pymc_extras.inference.laplace import fit_laplace
+from pymc_extras.inference.laplace_approx.find_map import find_MAP
+from pymc_extras.inference.laplace_approx.laplace import fit_laplace
 from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
 
 __all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]
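
For downstream code, this module move only matters for deep imports; the public names re-exported from `pymc_extras.inference` are unchanged:

    # Stable across 0.2.6 -> 0.3.1:
    from pymc_extras.inference import find_MAP, fit_laplace

    # Deep imports must follow the new laplace_approx package layout in 0.3.1:
    # 0.2.6: from pymc_extras.inference.find_map import find_MAP
    from pymc_extras.inference.laplace_approx.find_map import find_MAP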
pymc_extras-0.3.1/pymc_extras/inference/laplace_approx/find_map.py (new file)

@@ -0,0 +1,347 @@
+import logging
+
+from collections.abc import Callable
+from typing import Literal, cast
+
+import numpy as np
+import pymc as pm
+
+from better_optimize import basinhopping, minimize
+from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
+from pymc.blocking import DictToArrayBijection, RaveledVars
+from pymc.initial_point import make_initial_point_fn
+from pymc.model.transform.optimization import freeze_dims_and_data
+from pymc.util import get_default_varnames
+from pytensor.tensor import TensorVariable
+from scipy.optimize import OptimizeResult
+
+from pymc_extras.inference.laplace_approx.idata import (
+    add_data_to_inference_data,
+    add_fit_to_inference_data,
+    add_optimizer_result_to_inference_data,
+    map_results_to_inference_data,
+)
+from pymc_extras.inference.laplace_approx.scipy_interface import (
+    GradientBackend,
+    scipy_optimize_funcs_from_loss,
+)
+
+_log = logging.getLogger(__name__)
+
+
+def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
+    method_info = MINIMIZE_MODE_KWARGS[method].copy()
+
+    if use_hess and use_hessp:
+        _log.warning(
+            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
+            'same time. When possible "use_hessp" is preferred because it is computationally more efficient. '
+            'Setting "use_hess" to False.'
+        )
+        use_hess = False
+
+    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+
+    if use_hessp is not None and use_hess is None:
+        use_hess = not use_hessp
+
+    elif use_hess is not None and use_hessp is None:
+        use_hessp = not use_hess
+
+    elif use_hessp is None and use_hess is None:
+        use_hessp = method_info["uses_hessp"]
+        use_hess = method_info["uses_hess"]
+        if use_hessp and use_hess:
+            # If a method could use either hess or hessp, we default to using hessp
+            use_hess = False
+
+    return use_grad, use_hess, use_hessp
+
+
+def get_nearest_psd(A: np.ndarray) -> np.ndarray:
+    """
+    Compute the nearest positive semi-definite matrix to a given matrix.
+
+    This function takes a square matrix and returns the nearest positive semi-definite matrix using
+    eigenvalue decomposition. It ensures all eigenvalues are non-negative. The "nearest" matrix is defined in terms
+    of the Frobenius norm.
+
+    Parameters
+    ----------
+    A : np.ndarray
+        Input square matrix.
+
+    Returns
+    -------
+    np.ndarray
+        The nearest positive semi-definite matrix to the input matrix.
+    """
+    C = (A + A.T) / 2
+    eigval, eigvec = np.linalg.eigh(C)
+    eigval[eigval < 0] = 0
+
+    return eigvec @ np.diag(eigval) @ eigvec.T
+
+
+def _make_initial_point(model, initvals=None, random_seed=None, jitter_rvs=None):
+    jitter_rvs = [] if jitter_rvs is None else jitter_rvs
+
+    ipfn = make_initial_point_fn(
+        model=model,
+        jitter_rvs=set(jitter_rvs),
+        return_transformed=True,
+        overrides=initvals,
+    )
+
+    start_dict = ipfn(random_seed)
+    vars_dict = {var.name: var for var in model.continuous_value_vars}
+    initial_params = DictToArrayBijection.map(
+        {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict}
+    )
+
+    return initial_params
+
+
+def _compute_inverse_hessian(
+    optimizer_result: OptimizeResult | None,
+    optimal_point: np.ndarray | None,
+    f_fused: Callable | None,
+    f_hessp: Callable | None,
+    use_hess: bool,
+    method: minimize_method | Literal["BFGS", "L-BFGS-B"],
+):
+    """
+    Compute the Hessian matrix or its inverse based on the optimization result and the method used.
+
+    Downstream functions (e.g. laplace approximation) will need the inverse Hessian matrix. This function computes it
+    in the cheapest way possible, depending on the optimization method used and the available compiled functions.
+
+    Parameters
+    ----------
+    optimizer_result: OptimizeResult, optional
+        The result of the optimization, containing the optimized parameters and possibly an approximate inverse Hessian.
+    optimal_point: np.ndarray, optional
+        The optimal point found by the optimizer, used to compute the Hessian if necessary. If not provided, it will be
+        extracted from the optimizer result.
+    f_fused: callable, optional
+        The compiled function representing the loss and possibly its gradient and Hessian.
+    f_hessp: callable, optional
+        The compiled function for Hessian-vector products, if available.
+    use_hess: bool
+        Whether the Hessian matrix was used in the optimization.
+    method: minimize_method
+        The optimization method used, which determines how the Hessian is computed.
+
+    Returns
+    -------
+    H_inv: np.ndarray
+        The inverse Hessian matrix, computed based on the optimization method and available functions.
+    """
+    if optimal_point is None and optimizer_result is None:
+        raise ValueError("At least one of `optimal_point` or `optimizer_result` must be provided.")
+
+    x_star = optimizer_result.x if optimizer_result is not None else optimal_point
+    n_vars = len(x_star)
+
+    if method == "BFGS" and optimizer_result is not None:
+        # If we used BFGS, the optimizer result will contain the inverse Hessian -- we can just use that rather than
+        # re-computing something
+        if hasattr(optimizer_result, "lowest_optimization_result"):
+            # We did basinhopping, need to get the inner optimizer results
+            H_inv = getattr(optimizer_result.lowest_optimization_result, "hess_inv", None)
+        else:
+            H_inv = getattr(optimizer_result, "hess_inv", None)
+
+    elif method == "L-BFGS-B" and optimizer_result is not None:
+        # Here we will have a LinearOperator representing the inverse Hessian-Vector product.
+        if hasattr(optimizer_result, "lowest_optimization_result"):
+            # We did basinhopping, need to get the inner optimizer results
+            f_hessp_inv = getattr(optimizer_result.lowest_optimization_result, "hess_inv", None)
+        else:
+            f_hessp_inv = getattr(optimizer_result, "hess_inv", None)
+
+        if f_hessp_inv is not None:
+            basis = np.eye(n_vars)
+            H_inv = np.stack([f_hessp_inv(basis[:, i]) for i in range(n_vars)], axis=-1)
+        else:
+            H_inv = None
+
+    elif f_hessp is not None:
+        # In the case that hessp was used, the results object will not save the inverse Hessian, so we can compute it from
+        # the hessp function, using euclidean basis vectors.
+        basis = np.eye(n_vars)
+        H = np.stack([f_hessp(x_star, basis[:, i]) for i in range(n_vars)], axis=-1)
+        H_inv = np.linalg.inv(get_nearest_psd(H))
+
+    elif use_hess and f_fused is not None:
+        # If we compiled a hessian function, just use it
+        _, _, H = f_fused(x_star)
+        H_inv = np.linalg.inv(get_nearest_psd(H))
+
+    else:
+        H_inv = None
+
+    return H_inv
+
+
+def find_MAP(
+    method: minimize_method | Literal["basinhopping"] = "L-BFGS-B",
+    *,
+    model: pm.Model | None = None,
+    use_grad: bool | None = None,
+    use_hessp: bool | None = None,
+    use_hess: bool | None = None,
+    initvals: dict | None = None,
+    random_seed: int | np.random.Generator | None = None,
+    jitter_rvs: list[TensorVariable] | None = None,
+    progressbar: bool = True,
+    include_transformed: bool = True,
+    gradient_backend: GradientBackend = "pytensor",
+    compile_kwargs: dict | None = None,
+    **optimizer_kwargs,
+) -> (
+    dict[str, np.ndarray]
+    | tuple[dict[str, np.ndarray], np.ndarray]
+    | tuple[dict[str, np.ndarray], OptimizeResult]
+    | tuple[dict[str, np.ndarray], OptimizeResult, np.ndarray]
+):
+    """
+    Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.optimize.
+
+    Parameters
+    ----------
+    model : pm.Model
+        The PyMC model to be fit. If None, the current model context is used.
+    method : str
+        The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
+        trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
+
+        See scipy.optimize.minimize documentation for details.
+    use_grad : bool | None, optional
+        Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on
+        the ``method``.
+    use_hessp : bool | None, optional
+        Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this automatically based on
+        the ``method``.
+    use_hess : bool | None, optional
+        Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically based on
+        the ``method``.
+    initvals : None | dict, optional
+        Initial values for the model parameters, as str:ndarray key-value pairs. Partial initialization is permitted.
+        If None, the model's default initial values are used.
+    random_seed : None | int | np.random.Generator, optional
+        Seed for the random number generator or a numpy Generator for reproducibility
+    jitter_rvs : list of TensorVariables, optional
+        Variables whose initial values should be jittered. If None, all variables are jittered.
+    progressbar : bool, optional
+        Whether to display a progress bar during optimization. Defaults to True.
+    include_transformed: bool, optional
+        Whether to include transformed variable values in the returned dictionary. Defaults to True.
+    gradient_backend: str, default "pytensor"
+        Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
+    compile_kwargs: dict, optional
+        Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
+    **optimizer_kwargs
+        Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
+        ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
+        ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
+
+    Returns
+    -------
+    map_result: az.InferenceData
+        Results of Maximum A Posteriori (MAP) estimation, including the optimized point, inverse Hessian, transformed
+        latent variables, and optimizer results.
+    """
+    model = pm.modelcontext(model) if model is None else model
+    frozen_model = freeze_dims_and_data(model)
+    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+
+    initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)
+
+    do_basinhopping = method == "basinhopping"
+    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+    if do_basinhopping:
+        # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+        # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+        # if one isn't provided.
+
+        method = minimizer_kwargs.pop("method", "L-BFGS-B")
+        minimizer_kwargs["method"] = method
+
+    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+        method, use_grad, use_hess, use_hessp
+    )
+
+    f_fused, f_hessp = scipy_optimize_funcs_from_loss(
+        loss=-frozen_model.logp(),
+        inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
+        initial_point_dict=DictToArrayBijection.rmap(initial_params),
+        use_grad=use_grad,
+        use_hess=use_hess,
+        use_hessp=use_hessp,
+        gradient_backend=gradient_backend,
+        compile_kwargs=compile_kwargs,
+    )
+
+    args = optimizer_kwargs.pop("args", ())
+
+    # better_optimize.minimize will check if f_fused is a fused loss+grad Op, and automatically assign the jac argument
+    # if so. That is why the jac argument is not passed here in either branch.
+
+    if do_basinhopping:
+        if "args" not in minimizer_kwargs:
+            minimizer_kwargs["args"] = args
+        if "hessp" not in minimizer_kwargs:
+            minimizer_kwargs["hessp"] = f_hessp
+        if "method" not in minimizer_kwargs:
+            minimizer_kwargs["method"] = method
+
+        optimizer_result = basinhopping(
+            func=f_fused,
+            x0=cast(np.ndarray[float], initial_params.data),
+            progressbar=progressbar,
+            minimizer_kwargs=minimizer_kwargs,
+            **optimizer_kwargs,
+        )
+
+    else:
+        optimizer_result = minimize(
+            f=f_fused,
+            x0=cast(np.ndarray[float], initial_params.data),
+            args=args,
+            hessp=f_hessp,
+            progressbar=progressbar,
+            method=method,
+            **optimizer_kwargs,
+        )
+
+    H_inv = _compute_inverse_hessian(
+        optimizer_result=optimizer_result,
+        optimal_point=None,
+        f_fused=f_fused,
+        f_hessp=f_hessp,
+        use_hess=use_hess,
+        method=method,
+    )
+
+    raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
+    unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
+    unobserved_vars_values = model.compile_fn(unobserved_vars, mode="FAST_COMPILE")(
+        DictToArrayBijection.rmap(raveled_optimized)
+    )
+
+    optimized_point = {
+        var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)
+    }
+
+    idata = map_results_to_inference_data(optimized_point, frozen_model)
+    idata = add_fit_to_inference_data(idata, raveled_optimized, H_inv)
+    idata = add_optimizer_result_to_inference_data(
+        idata, optimizer_result, method, raveled_optimized, model
+    )
+    idata = add_data_to_inference_data(
+        idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
+    )
+
+    return idata