pymc-extras 0.5.0__tar.gz → 0.7.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pymc_extras-0.7.0/CONTRIBUTING.md +24 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/PKG-INFO +4 -4
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/_version.py +2 -2
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/conda-envs/environment-test.yml +3 -3
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/api_reference.rst +1 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/statespace/models.rst +2 -1
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/deserialize.py +10 -4
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/continuous.py +1 -1
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/histogram_utils.py +6 -4
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/timeseries.py +14 -12
- pymc_extras-0.7.0/pymc_extras/inference/dadvi/dadvi.py +282 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/laplace_approx/find_map.py +16 -39
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/laplace_approx/idata.py +22 -4
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/laplace_approx/laplace.py +196 -151
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
- pymc_extras-0.7.0/pymc_extras/inference/pathfinder/idata.py +517 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/pathfinder/pathfinder.py +71 -12
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/smc/sampling.py +2 -2
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/model/marginal/distributions.py +4 -2
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/model/marginal/graph_analysis.py +2 -2
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/model/marginal/marginal_model.py +12 -2
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/model_builder.py +9 -4
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/prior.py +203 -8
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/core/compile.py +1 -1
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/core/statespace.py +2 -1
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/filters/distributions.py +15 -13
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/filters/kalman_filter.py +24 -22
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/filters/kalman_smoother.py +3 -5
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/filters/utilities.py +2 -5
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/DFM.py +12 -27
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/ETS.py +190 -198
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/SARIMAX.py +5 -17
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/VARMAX.py +15 -67
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/components/regression.py +4 -26
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/utilities.py +7 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/utils/model_equivalence.py +2 -2
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/utils/prior.py +10 -14
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/utils/spline.py +4 -10
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pyproject.toml +19 -15
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/distributions/test_continuous.py +4 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/distributions/test_discrete.py +8 -5
- pymc_extras-0.7.0/tests/inference/dadvi/test_dadvi.py +177 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/inference/laplace_approx/test_laplace.py +33 -21
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/model/marginal/test_distributions.py +1 -1
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/model/marginal/test_graph_analysis.py +1 -1
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/model/marginal/test_marginal_model.py +21 -8
- pymc_extras-0.7.0/tests/pathfinder/test_idata.py +489 -0
- {pymc_extras-0.5.0/tests → pymc_extras-0.7.0/tests/pathfinder}/test_pathfinder.py +14 -15
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/core/test_statespace.py +3 -5
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/core/test_statespace_JAX.py +9 -9
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/filters/test_distributions.py +2 -2
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/filters/test_kalman_filter.py +47 -42
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/structural/components/test_autoregressive.py +9 -1
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/structural/components/test_cycle.py +1 -1
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/structural/components/test_measurement_error.py +1 -1
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/structural/components/test_seasonality.py +1 -1
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/test_DFM.py +6 -13
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/test_ETS.py +14 -10
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/test_SARIMAX.py +11 -10
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/test_VARMAX.py +108 -197
- pymc_extras-0.7.0/tests/statespace/utils/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/test_histogram_approximation.py +1 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/test_splines.py +17 -1
- pymc_extras-0.7.0/tests/utils.py +0 -0
- pymc_extras-0.5.0/CONTRIBUTING.md +0 -3
- pymc_extras-0.5.0/pymc_extras/inference/dadvi/dadvi.py +0 -261
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/.gitignore +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/.gitpod.yml +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/.pre-commit-config.yaml +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/.readthedocs.yaml +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/CODE_OF_CONDUCT.md +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/LICENSE +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/README.md +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/codecov.yml +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/.nojekyll +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/Makefile +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/_templates/autosummary/base.rst +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/_templates/autosummary/class.rst +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/conf.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/index.rst +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/make.bat +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/statespace/core.rst +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/statespace/filters.rst +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/statespace/models/structural.rst +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/discrete.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/multivariate/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/transforms/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/transforms/partial_order.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/gp/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/gp/latent_approx.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/dadvi/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/fit.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/laplace_approx/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/pathfinder/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/pathfinder/importance_sampling.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/pathfinder/lbfgs.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/smc/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/linearmodel.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/model/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/model/marginal/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/model/model_api.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/model/transforms/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/model/transforms/autoreparam.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/preprocessing/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/preprocessing/standard_scaler.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/printing.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/core/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/core/representation.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/filters/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/components/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/components/cycle.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/components/level_trend.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/components/measurement_error.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/components/seasonality.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/core.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/utils.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/utils/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/utils/constants.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/utils/coord_tools.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/utils/data_tools.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/utils/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/utils/linear_cg.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/conftest.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/distributions/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/distributions/test_discrete_markov_chain.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/distributions/test_multivariate.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/distributions/test_transform.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/inference/__init__.py +0 -0
- {pymc_extras-0.5.0/tests/inference/laplace_approx → pymc_extras-0.7.0/tests/inference/dadvi}/__init__.py +0 -0
- {pymc_extras-0.5.0/tests/model → pymc_extras-0.7.0/tests/inference/laplace_approx}/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/inference/laplace_approx/test_find_map.py +1 -1
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/inference/laplace_approx/test_idata.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/inference/laplace_approx/test_scipy_interface.py +0 -0
- {pymc_extras-0.5.0/tests/model/marginal → pymc_extras-0.7.0/tests/model}/__init__.py +0 -0
- {pymc_extras-0.5.0/tests/statespace → pymc_extras-0.7.0/tests/model/marginal}/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/model/test_model_api.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/model/transforms/test_autoreparam.py +0 -0
- {pymc_extras-0.5.0/tests/statespace/core → pymc_extras-0.7.0/tests/pathfinder}/__init__.py +0 -0
- {pymc_extras-0.5.0/tests/statespace/filters → pymc_extras-0.7.0/tests/statespace}/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/_data/airpass.csv +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/_data/airpassangers.csv +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/_data/nile.csv +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/_data/statsmodels_macrodata_processed.csv +0 -0
- {pymc_extras-0.5.0/tests/statespace/models → pymc_extras-0.7.0/tests/statespace/core}/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/core/test_representation.py +0 -0
- {pymc_extras-0.5.0/tests/statespace/models/structural → pymc_extras-0.7.0/tests/statespace/filters}/__init__.py +0 -0
- {pymc_extras-0.5.0/tests/statespace/models/structural/components → pymc_extras-0.7.0/tests/statespace/models}/__init__.py +0 -0
- {pymc_extras-0.5.0/tests/statespace/utils → pymc_extras-0.7.0/tests/statespace/models/structural}/__init__.py +0 -0
- /pymc_extras-0.5.0/tests/utils.py → /pymc_extras-0.7.0/tests/statespace/models/structural/components/__init__.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/structural/components/test_level_trend.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/structural/components/test_regression.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/structural/conftest.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/structural/test_against_statsmodels.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/structural/test_core.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/test_utilities.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/shared_fixtures.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/statsmodel_local_level.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/test_utilities.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/utils/test_coord_assignment.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/test_blackjax_smc.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/test_deserialize.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/test_linearmodel.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/test_model_builder.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/test_printing.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/test_prior.py +0 -0
- {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/test_prior_from_trace.py +0 -0
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# Contributing guide
|
|
2
|
+
|
|
3
|
+
Page in construction, for now go to https://github.com/pymc-devs/pymc-extras#questions.
|
|
4
|
+
|
|
5
|
+
## Building the documentation
|
|
6
|
+
|
|
7
|
+
To build the documentation locally, you need to install the necessary
|
|
8
|
+
dependencies and then use `make` to build the HTML files.
|
|
9
|
+
|
|
10
|
+
First, install the package with the optional documentation dependencies:
|
|
11
|
+
|
|
12
|
+
```bash
|
|
13
|
+
pip install ".[docs]"
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
Then, navigate to the `docs` directory and run `make html`:
|
|
17
|
+
|
|
18
|
+
```bash
|
|
19
|
+
cd docs
|
|
20
|
+
make html
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
The generated HTML files will be in the `docs/_build/html` directory. You can
|
|
24
|
+
open the `index.html` file in that directory to view the documentation.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pymc-extras
|
|
3
|
-
Version: 0.5.0
|
|
3
|
+
Version: 0.7.0
|
|
4
4
|
Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
|
|
5
5
|
Project-URL: Documentation, https://pymc-extras.readthedocs.io/
|
|
6
6
|
Project-URL: Repository, https://github.com/pymc-devs/pymc-extras.git
|
|
@@ -235,8 +235,8 @@ Requires-Python: >=3.11
|
|
|
235
235
|
Requires-Dist: better-optimize>=0.1.5
|
|
236
236
|
Requires-Dist: preliz>=0.20.0
|
|
237
237
|
Requires-Dist: pydantic>=2.0.0
|
|
238
|
-
Requires-Dist: pymc>=5.
|
|
239
|
-
Requires-Dist: pytensor>=2.
|
|
238
|
+
Requires-Dist: pymc>=5.27.0
|
|
239
|
+
Requires-Dist: pytensor>=2.36.3
|
|
240
240
|
Requires-Dist: scikit-learn
|
|
241
241
|
Provides-Extra: complete
|
|
242
242
|
Requires-Dist: dask[complete]<2025.1.1; extra == 'complete'
|
|
@@ -245,7 +245,7 @@ Provides-Extra: dask-histogram
|
|
|
245
245
|
Requires-Dist: dask[complete]<2025.1.1; extra == 'dask-histogram'
|
|
246
246
|
Requires-Dist: xhistogram; extra == 'dask-histogram'
|
|
247
247
|
Provides-Extra: dev
|
|
248
|
-
Requires-Dist: blackjax; extra == 'dev'
|
|
248
|
+
Requires-Dist: blackjax>=0.12; extra == 'dev'
|
|
249
249
|
Requires-Dist: dask[all]<2025.1.1; extra == 'dev'
|
|
250
250
|
Requires-Dist: pytest-mock; extra == 'dev'
|
|
251
251
|
Requires-Dist: pytest>=6.0; extra == 'dev'
|
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.5.0'
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 5, 0)
|
|
31
|
+
__version__ = version = '0.7.0'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 7, 0)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -13,10 +13,7 @@ Make use of the already registered deserializers:
|
|
|
13
13
|
|
|
14
14
|
from pymc_extras.deserialize import deserialize
|
|
15
15
|
|
|
16
|
-
prior_class_data = {
|
|
17
|
-
"dist": "Normal",
|
|
18
|
-
"kwargs": {"mu": 0, "sigma": 1}
|
|
19
|
-
}
|
|
16
|
+
prior_class_data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
|
|
20
17
|
prior = deserialize(prior_class_data)
|
|
21
18
|
# Prior("Normal", mu=0, sigma=1)
|
|
22
19
|
|
|
@@ -26,6 +23,7 @@ Register custom class deserialization:
|
|
|
26
23
|
|
|
27
24
|
from pymc_extras.deserialize import register_deserialization
|
|
28
25
|
|
|
26
|
+
|
|
29
27
|
class MyClass:
|
|
30
28
|
def __init__(self, value: int):
|
|
31
29
|
self.value = value
|
|
@@ -34,6 +32,7 @@ Register custom class deserialization:
|
|
|
34
32
|
# Example of what the to_dict method might look like.
|
|
35
33
|
return {"value": self.value}
|
|
36
34
|
|
|
35
|
+
|
|
37
36
|
register_deserialization(
|
|
38
37
|
is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
|
|
39
38
|
deserialize=lambda data: MyClass(value=data["value"]),
|
|
@@ -80,18 +79,23 @@ class Deserializer:
|
|
|
80
79
|
|
|
81
80
|
from typing import Any
|
|
82
81
|
|
|
82
|
+
|
|
83
83
|
class MyClass:
|
|
84
84
|
def __init__(self, value: int):
|
|
85
85
|
self.value = value
|
|
86
86
|
|
|
87
|
+
|
|
87
88
|
from pymc_extras.deserialize import Deserializer
|
|
88
89
|
|
|
90
|
+
|
|
89
91
|
def is_type(data: Any) -> bool:
|
|
90
92
|
return data.keys() == {"value"} and isinstance(data["value"], int)
|
|
91
93
|
|
|
94
|
+
|
|
92
95
|
def deserialize(data: dict) -> MyClass:
|
|
93
96
|
return MyClass(value=data["value"])
|
|
94
97
|
|
|
98
|
+
|
|
95
99
|
deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize)
|
|
96
100
|
|
|
97
101
|
"""
|
|
@@ -196,6 +200,7 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
|
|
|
196
200
|
|
|
197
201
|
from pymc_extras.deserialize import register_deserialization
|
|
198
202
|
|
|
203
|
+
|
|
199
204
|
class MyClass:
|
|
200
205
|
def __init__(self, value: int):
|
|
201
206
|
self.value = value
|
|
@@ -204,6 +209,7 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
|
|
|
204
209
|
# Example of what the to_dict method might look like.
|
|
205
210
|
return {"value": self.value}
|
|
206
211
|
|
|
212
|
+
|
|
207
213
|
register_deserialization(
|
|
208
214
|
is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
|
|
209
215
|
deserialize=lambda data: MyClass(value=data["value"]),
|
|
@@ -130,8 +130,7 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
|
|
|
130
130
|
... m = pm.Normal("m", dims="tests")
|
|
131
131
|
... s = pm.LogNormal("s", dims="tests")
|
|
132
132
|
... pot = pmx.distributions.histogram_approximation(
|
|
133
|
-
... "pot", pm.Normal.dist(m, s),
|
|
134
|
-
... observed=measurements, n_quantiles=50
|
|
133
|
+
... "pot", pm.Normal.dist(m, s), observed=measurements, n_quantiles=50
|
|
135
134
|
... )
|
|
136
135
|
|
|
137
136
|
For special cases like Zero Inflation in Continuous variables there is a flag.
|
|
@@ -143,8 +142,11 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
|
|
|
143
142
|
... m = pm.Normal("m", dims="tests")
|
|
144
143
|
... s = pm.LogNormal("s", dims="tests")
|
|
145
144
|
... pot = pmx.distributions.histogram_approximation(
|
|
146
|
-
... "pot",
|
|
147
|
-
...
|
|
145
|
+
... "pot",
|
|
146
|
+
... pm.Normal.dist(m, s),
|
|
147
|
+
... observed=measurements,
|
|
148
|
+
... n_quantiles=50,
|
|
149
|
+
... zero_inflation=True,
|
|
148
150
|
... )
|
|
149
151
|
"""
|
|
150
152
|
try:
|
|
@@ -305,6 +305,7 @@ def R2D2M2CP(
|
|
|
305
305
|
import pymc_extras as pmx
|
|
306
306
|
import pymc as pm
|
|
307
307
|
import numpy as np
|
|
308
|
+
|
|
308
309
|
X = np.random.randn(10, 3)
|
|
309
310
|
b = np.random.randn(3)
|
|
310
311
|
y = X @ b + np.random.randn(10) * 0.04 + 5
|
|
@@ -339,7 +340,7 @@ def R2D2M2CP(
|
|
|
339
340
|
# "c" - a must have in the relation
|
|
340
341
|
variables_importance=[10, 1, 34],
|
|
341
342
|
# NOTE: try both
|
|
342
|
-
centered=True
|
|
343
|
+
centered=True,
|
|
343
344
|
)
|
|
344
345
|
# intercept prior centering should be around prior predictive mean
|
|
345
346
|
intercept = y.mean()
|
|
@@ -365,7 +366,7 @@ def R2D2M2CP(
|
|
|
365
366
|
r2_std=0.2,
|
|
366
367
|
# NOTE: if you know where a variable should go
|
|
367
368
|
# if you do not know, leave as 0.5
|
|
368
|
-
centered=False
|
|
369
|
+
centered=False,
|
|
369
370
|
)
|
|
370
371
|
# intercept prior centering should be around prior predictive mean
|
|
371
372
|
intercept = y.mean()
|
|
@@ -394,7 +395,7 @@ def R2D2M2CP(
|
|
|
394
395
|
# if you do not know, leave as 0.5
|
|
395
396
|
positive_probs=[0.8, 0.5, 0.1],
|
|
396
397
|
# NOTE: try both
|
|
397
|
-
centered=True
|
|
398
|
+
centered=True,
|
|
398
399
|
)
|
|
399
400
|
intercept = y.mean()
|
|
400
401
|
obs = pm.Normal("obs", intercept + X @ beta, eps, observed=y)
|
|
@@ -113,8 +113,10 @@ class DiscreteMarkovChain(Distribution):
|
|
|
113
113
|
|
|
114
114
|
with pm.Model() as markov_chain:
|
|
115
115
|
P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
|
|
116
|
-
init_dist = pm.Categorical.dist(p
|
|
117
|
-
markov_chain = pmx.DiscreteMarkovChain(
|
|
116
|
+
init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
|
|
117
|
+
markov_chain = pmx.DiscreteMarkovChain(
|
|
118
|
+
"markov_chain", P=P, init_dist=init_dist, shape=(100,)
|
|
119
|
+
)
|
|
118
120
|
|
|
119
121
|
"""
|
|
120
122
|
|
|
@@ -194,21 +196,20 @@ class DiscreteMarkovChain(Distribution):
|
|
|
194
196
|
state_rng = pytensor.shared(np.random.default_rng())
|
|
195
197
|
|
|
196
198
|
def transition(*args):
|
|
197
|
-
*states, transition_probs
|
|
199
|
+
old_rng, *states, transition_probs = args
|
|
198
200
|
p = transition_probs[tuple(states)]
|
|
199
201
|
next_rng, next_state = pm.Categorical.dist(p=p, rng=old_rng).owner.outputs
|
|
200
|
-
return
|
|
202
|
+
return next_rng, next_state
|
|
201
203
|
|
|
202
|
-
|
|
204
|
+
state_next_rng, markov_chain = pytensor.scan(
|
|
203
205
|
transition,
|
|
204
|
-
|
|
205
|
-
|
|
206
|
+
outputs_info=[state_rng, *_make_outputs_info(n_lags, init_dist_)],
|
|
207
|
+
non_sequences=[P_],
|
|
206
208
|
n_steps=steps_,
|
|
207
209
|
strict=True,
|
|
210
|
+
return_updates=False,
|
|
208
211
|
)
|
|
209
212
|
|
|
210
|
-
(state_next_rng,) = tuple(state_updates.values())
|
|
211
|
-
|
|
212
213
|
discrete_mc_ = pt.moveaxis(pt.concatenate([init_dist_, markov_chain], axis=0), 0, -1)
|
|
213
214
|
|
|
214
215
|
discrete_mc_op = DiscreteMarkovChainRV(
|
|
@@ -237,16 +238,17 @@ def discrete_mc_moment(op, rv, P, steps, init_dist, state_rng):
|
|
|
237
238
|
n_lags = op.n_lags
|
|
238
239
|
|
|
239
240
|
def greedy_transition(*args):
|
|
240
|
-
*states, transition_probs
|
|
241
|
+
*states, transition_probs = args
|
|
241
242
|
p = transition_probs[tuple(states)]
|
|
242
243
|
return pt.argmax(p)
|
|
243
244
|
|
|
244
|
-
chain_moment
|
|
245
|
+
chain_moment = pytensor.scan(
|
|
245
246
|
greedy_transition,
|
|
246
|
-
non_sequences=[P
|
|
247
|
+
non_sequences=[P],
|
|
247
248
|
outputs_info=_make_outputs_info(n_lags, init_dist),
|
|
248
249
|
n_steps=steps,
|
|
249
250
|
strict=True,
|
|
251
|
+
return_updates=False,
|
|
250
252
|
)
|
|
251
253
|
chain_moment = pt.concatenate([init_dist_moment, chain_moment])
|
|
252
254
|
return chain_moment
|
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
import arviz as az
|
|
2
|
+
import numpy as np
|
|
3
|
+
import pymc
|
|
4
|
+
import pytensor
|
|
5
|
+
import pytensor.tensor as pt
|
|
6
|
+
|
|
7
|
+
from arviz import InferenceData
|
|
8
|
+
from better_optimize import basinhopping, minimize
|
|
9
|
+
from better_optimize.constants import minimize_method
|
|
10
|
+
from pymc import DictToArrayBijection, Model, join_nonshared_inputs
|
|
11
|
+
from pymc.blocking import RaveledVars
|
|
12
|
+
from pymc.util import RandomSeed
|
|
13
|
+
from pytensor.tensor.variable import TensorVariable
|
|
14
|
+
|
|
15
|
+
from pymc_extras.inference.laplace_approx.idata import (
|
|
16
|
+
add_data_to_inference_data,
|
|
17
|
+
add_optimizer_result_to_inference_data,
|
|
18
|
+
)
|
|
19
|
+
from pymc_extras.inference.laplace_approx.laplace import draws_from_laplace_approx
|
|
20
|
+
from pymc_extras.inference.laplace_approx.scipy_interface import (
|
|
21
|
+
scipy_optimize_funcs_from_loss,
|
|
22
|
+
set_optimizer_function_defaults,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def fit_dadvi(
|
|
27
|
+
model: Model | None = None,
|
|
28
|
+
n_fixed_draws: int = 30,
|
|
29
|
+
n_draws: int = 1000,
|
|
30
|
+
include_transformed: bool = False,
|
|
31
|
+
optimizer_method: minimize_method = "trust-ncg",
|
|
32
|
+
use_grad: bool | None = None,
|
|
33
|
+
use_hessp: bool | None = None,
|
|
34
|
+
use_hess: bool | None = None,
|
|
35
|
+
gradient_backend: str = "pytensor",
|
|
36
|
+
compile_kwargs: dict | None = None,
|
|
37
|
+
random_seed: RandomSeed = None,
|
|
38
|
+
progressbar: bool = True,
|
|
39
|
+
**optimizer_kwargs,
|
|
40
|
+
) -> az.InferenceData:
|
|
41
|
+
"""
|
|
42
|
+
Does inference using Deterministic ADVI (Automatic Differentiation Variational Inference), DADVI for short.
|
|
43
|
+
|
|
44
|
+
For full details see the paper cited in the references: https://www.jmlr.org/papers/v25/23-1015.html
|
|
45
|
+
|
|
46
|
+
Parameters
|
|
47
|
+
----------
|
|
48
|
+
model : pm.Model
|
|
49
|
+
The PyMC model to be fit. If None, the current model context is used.
|
|
50
|
+
|
|
51
|
+
n_fixed_draws : int
|
|
52
|
+
The number of fixed draws to use for the optimisation. More draws will result in more accurate estimates, but
|
|
53
|
+
also increase inference time. Usually, the default of 30 is a good tradeoff between speed and accuracy.
|
|
54
|
+
|
|
55
|
+
random_seed: int
|
|
56
|
+
The random seed to use for the fixed draws. Running the optimisation twice with the same seed should arrive at
|
|
57
|
+
the same result.
|
|
58
|
+
|
|
59
|
+
n_draws: int
|
|
60
|
+
The number of draws to return from the variational approximation.
|
|
61
|
+
|
|
62
|
+
include_transformed: bool
|
|
63
|
+
Whether or not to keep the unconstrained variables (such as logs of positive-constrained parameters) in the
|
|
64
|
+
output.
|
|
65
|
+
|
|
66
|
+
optimizer_method: str
|
|
67
|
+
Which optimization method to use. The function calls ``scipy.optimize.minimize``, so any of the methods there
|
|
68
|
+
can be used. The default is trust-ncg, which uses second-order information and is generally very reliable.
|
|
69
|
+
Other methods such as L-BFGS-B might be faster but potentially more brittle and may not converge exactly to
|
|
70
|
+
the optimum.
|
|
71
|
+
|
|
72
|
+
gradient_backend: str
|
|
73
|
+
Which backend to use to compute gradients. Must be one of "jax" or "pytensor". Default is "pytensor".
|
|
74
|
+
|
|
75
|
+
compile_kwargs: dict, optional
|
|
76
|
+
Additional keyword arguments to pass to `pytensor.function`
|
|
77
|
+
|
|
78
|
+
use_grad: bool, optional
|
|
79
|
+
If True, pass the gradient function to `scipy.optimize.minimize` (where it is referred to as `jac`).
|
|
80
|
+
|
|
81
|
+
use_hessp: bool, optional
|
|
82
|
+
If True, pass the hessian vector product to `scipy.optimize.minimize`.
|
|
83
|
+
|
|
84
|
+
use_hess: bool, optional
|
|
85
|
+
If True, pass the hessian to `scipy.optimize.minimize`. Note that this is generally not recommended since its
|
|
86
|
+
computation can be slow and memory-intensive if there are many parameters.
|
|
87
|
+
|
|
88
|
+
progressbar: bool
|
|
89
|
+
Whether or not to show a progress bar during optimization. Default is True.
|
|
90
|
+
|
|
91
|
+
optimizer_kwargs:
|
|
92
|
+
Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. See the documentation of
|
|
93
|
+
that function for details.
|
|
94
|
+
|
|
95
|
+
Returns
|
|
96
|
+
-------
|
|
97
|
+
:class:`~arviz.InferenceData`
|
|
98
|
+
The inference data containing the results of the DADVI algorithm.
|
|
99
|
+
|
|
100
|
+
References
|
|
101
|
+
----------
|
|
102
|
+
Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational Inference with a Deterministic Objective:
|
|
103
|
+
Faster, More Accurate, and Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.
|
|
104
|
+
"""
|
|
105
|
+
|
|
106
|
+
model = pymc.modelcontext(model) if model is None else model
|
|
107
|
+
do_basinhopping = optimizer_method == "basinhopping"
|
|
108
|
+
minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
|
|
109
|
+
|
|
110
|
+
if do_basinhopping:
|
|
111
|
+
# For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
|
|
112
|
+
# another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
|
|
113
|
+
# if one isn't provided.
|
|
114
|
+
|
|
115
|
+
optimizer_method = minimizer_kwargs.pop("method", "L-BFGS-B")
|
|
116
|
+
minimizer_kwargs["method"] = optimizer_method
|
|
117
|
+
|
|
118
|
+
initial_point_dict = model.initial_point()
|
|
119
|
+
initial_point = DictToArrayBijection.map(initial_point_dict)
|
|
120
|
+
n_params = initial_point.data.shape[0]
|
|
121
|
+
|
|
122
|
+
var_params, objective = create_dadvi_graph(
|
|
123
|
+
model,
|
|
124
|
+
n_fixed_draws=n_fixed_draws,
|
|
125
|
+
random_seed=random_seed,
|
|
126
|
+
n_params=n_params,
|
|
127
|
+
)
|
|
128
|
+
|
|
129
|
+
use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
|
|
130
|
+
optimizer_method, use_grad, use_hess, use_hessp
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
f_fused, f_hessp = scipy_optimize_funcs_from_loss(
|
|
134
|
+
loss=objective,
|
|
135
|
+
inputs=[var_params],
|
|
136
|
+
initial_point_dict=None,
|
|
137
|
+
use_grad=use_grad,
|
|
138
|
+
use_hessp=use_hessp,
|
|
139
|
+
use_hess=use_hess,
|
|
140
|
+
gradient_backend=gradient_backend,
|
|
141
|
+
compile_kwargs=compile_kwargs,
|
|
142
|
+
inputs_are_flat=True,
|
|
143
|
+
)
|
|
144
|
+
|
|
145
|
+
dadvi_initial_point = {
|
|
146
|
+
f"{var_name}_mu": np.zeros_like(value).ravel()
|
|
147
|
+
for var_name, value in initial_point_dict.items()
|
|
148
|
+
}
|
|
149
|
+
dadvi_initial_point.update(
|
|
150
|
+
{
|
|
151
|
+
f"{var_name}_sigma__log": np.zeros_like(value).ravel()
|
|
152
|
+
for var_name, value in initial_point_dict.items()
|
|
153
|
+
}
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
dadvi_initial_point = DictToArrayBijection.map(dadvi_initial_point)
|
|
157
|
+
args = optimizer_kwargs.pop("args", ())
|
|
158
|
+
|
|
159
|
+
if do_basinhopping:
|
|
160
|
+
if "args" not in minimizer_kwargs:
|
|
161
|
+
minimizer_kwargs["args"] = args
|
|
162
|
+
if "hessp" not in minimizer_kwargs:
|
|
163
|
+
minimizer_kwargs["hessp"] = f_hessp
|
|
164
|
+
if "method" not in minimizer_kwargs:
|
|
165
|
+
minimizer_kwargs["method"] = optimizer_method
|
|
166
|
+
|
|
167
|
+
result = basinhopping(
|
|
168
|
+
func=f_fused,
|
|
169
|
+
x0=dadvi_initial_point.data,
|
|
170
|
+
progressbar=progressbar,
|
|
171
|
+
minimizer_kwargs=minimizer_kwargs,
|
|
172
|
+
**optimizer_kwargs,
|
|
173
|
+
)
|
|
174
|
+
|
|
175
|
+
else:
|
|
176
|
+
result = minimize(
|
|
177
|
+
f=f_fused,
|
|
178
|
+
x0=dadvi_initial_point.data,
|
|
179
|
+
args=args,
|
|
180
|
+
method=optimizer_method,
|
|
181
|
+
hessp=f_hessp,
|
|
182
|
+
progressbar=progressbar,
|
|
183
|
+
**optimizer_kwargs,
|
|
184
|
+
)
|
|
185
|
+
|
|
186
|
+
raveled_optimized = RaveledVars(result.x, dadvi_initial_point.point_map_info)
|
|
187
|
+
|
|
188
|
+
opt_var_params = result.x
|
|
189
|
+
opt_means, opt_log_sds = np.split(opt_var_params, 2)
|
|
190
|
+
|
|
191
|
+
posterior, unconstrained_posterior = draws_from_laplace_approx(
|
|
192
|
+
mean=opt_means,
|
|
193
|
+
standard_deviation=np.exp(opt_log_sds),
|
|
194
|
+
draws=n_draws,
|
|
195
|
+
model=model,
|
|
196
|
+
vectorize_draws=False,
|
|
197
|
+
return_unconstrained=include_transformed,
|
|
198
|
+
random_seed=random_seed,
|
|
199
|
+
)
|
|
200
|
+
idata = InferenceData(posterior=posterior)
|
|
201
|
+
if include_transformed:
|
|
202
|
+
idata.add_groups(unconstrained_posterior=unconstrained_posterior)
|
|
203
|
+
|
|
204
|
+
var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
|
|
205
|
+
var_name_to_model_var.update(
|
|
206
|
+
{f"{var_name}_sigma__log": var_name for var_name in initial_point_dict.keys()}
|
|
207
|
+
)
|
|
208
|
+
|
|
209
|
+
idata = add_optimizer_result_to_inference_data(
|
|
210
|
+
idata=idata,
|
|
211
|
+
result=result,
|
|
212
|
+
method=optimizer_method,
|
|
213
|
+
mu=raveled_optimized,
|
|
214
|
+
model=model,
|
|
215
|
+
var_name_to_model_var=var_name_to_model_var,
|
|
216
|
+
)
|
|
217
|
+
|
|
218
|
+
idata = add_data_to_inference_data(
|
|
219
|
+
idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
|
|
220
|
+
)
|
|
221
|
+
|
|
222
|
+
return idata
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def create_dadvi_graph(
    model: Model,
    n_params: int,
    n_fixed_draws: int = 30,
    random_seed: RandomSeed = None,
) -> tuple[TensorVariable, TensorVariable]:
    """
    Sets up the DADVI graph in pytensor and returns it.

    The variational family is a Gaussian with a separate mean and log standard
    deviation for every entry of the model's flattened value variables. The
    expected log density is replaced by an average over ``n_fixed_draws``
    standard-normal draws that are generated once here and embedded in the
    graph as a constant. The returned objective is therefore a deterministic
    function of the variational parameters.

    Parameters
    ----------
    model : pm.Model
        The PyMC model to be fit.

    n_params : int
        The total number of (flattened) parameters in the model. Must equal
        the combined size of the model's value variables at its initial point.

    n_fixed_draws : int
        The number of fixed draws to use. Must be at least 1.

    random_seed : RandomSeed
        Seed passed to ``np.random.default_rng`` to generate the fixed draws.

    Returns
    -------
    Tuple[TensorVariable, TensorVariable]
        A tuple whose first element contains the variational parameters
        (a vector of length ``2 * n_params``: all means, then all log standard
        deviations), and whose second contains the DADVI objective, to be
        minimized.

    Raises
    ------
    ValueError
        If ``n_fixed_draws`` is less than 1, or if ``n_params`` does not match
        the flattened size of the model's value variables.
    """
    # An empty set of draws would make the objective the mean of nothing (NaN).
    if n_fixed_draws < 1:
        raise ValueError(f"n_fixed_draws must be at least 1, got {n_fixed_draws}.")

    inputs = model.continuous_value_vars + model.discrete_value_vars
    initial_point_dict = model.initial_point()

    # join_nonshared_inputs unpacks the flat input by slicing. A larger n_params
    # would leave trailing variational parameters disconnected from logp while
    # still contributing to the entropy term, which makes the objective
    # unbounded. A smaller one fails later with an opaque reshape error.
    # Check the size up front instead.
    flat_size = int(sum(np.size(value) for value in initial_point_dict.values()))
    if n_params != flat_size:
        raise ValueError(
            f"n_params ({n_params}) does not match the flattened size of the "
            f"model's value variables ({flat_size})."
        )

    # Make the fixed draws
    generator = np.random.default_rng(seed=random_seed)
    draws = generator.standard_normal(size=(n_fixed_draws, n_params))

    logp = model.logp()

    # Graph in terms of a flat input
    [logp], flat_input = join_nonshared_inputs(
        point=initial_point_dict, outputs=[logp], inputs=inputs
    )

    var_params = pt.vector(name="eta", shape=(2 * n_params,))

    # First half: means; second half: log standard deviations.
    means, log_sds = pt.split(var_params, axis=0, splits_size=[n_params, n_params], n_splits=2)

    # Reparameterized samples, one row per fixed draw: mean + sd * eps.
    draw_matrix = pt.constant(draws)
    samples = means + pt.exp(log_sds) * draw_matrix

    # Evaluate the model log density once per row of ``samples``.
    logp_vectorized_draws = pytensor.graph.vectorize_graph(logp, replace={flat_input: samples})

    mean_log_density = pt.mean(logp_vectorized_draws)
    # Entropy of a Gaussian with independent components, up to an additive
    # constant that does not depend on the variational parameters.
    entropy = pt.sum(log_sds)

    # Negative of (sample-average log density + entropy); minimize this.
    objective = -mean_log_density - entropy

    return var_params, objective
|