arviz 0.21.0__tar.gz → 0.23.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arviz-0.21.0 → arviz-0.23.0}/CHANGELOG.md +26 -1
- {arviz-0.21.0 → arviz-0.23.0}/CONTRIBUTING.md +2 -1
- {arviz-0.21.0 → arviz-0.23.0}/PKG-INFO +36 -3
- {arviz-0.21.0 → arviz-0.23.0}/arviz/__init__.py +49 -4
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/converters.py +11 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/inference_data.py +46 -24
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_datatree.py +2 -2
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_numpyro.py +116 -5
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_pyjags.py +1 -1
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/autocorrplot.py +12 -2
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/hdiplot.py +7 -6
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/lmplot.py +19 -3
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/pairplot.py +18 -48
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/khatplot.py +8 -1
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/lmplot.py +13 -7
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/pairplot.py +14 -22
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/bpvplot.py +1 -1
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/dotplot.py +2 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/forestplot.py +16 -4
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/lmplot.py +41 -14
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/pairplot.py +10 -3
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/ppcplot.py +1 -1
- arviz-0.23.0/arviz/preview.py +58 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/rcparams.py +2 -2
- {arviz-0.21.0 → arviz-0.23.0}/arviz/stats/density_utils.py +1 -1
- {arviz-0.21.0 → arviz-0.23.0}/arviz/stats/stats.py +31 -34
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_data.py +25 -4
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_plots_bokeh.py +60 -2
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_plots_matplotlib.py +94 -1
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_stats.py +42 -1
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_stats_ecdf_utils.py +2 -2
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/external_tests/test_data_numpyro.py +154 -4
- {arviz-0.21.0 → arviz-0.23.0}/arviz/wrappers/base.py +1 -1
- {arviz-0.21.0 → arviz-0.23.0}/arviz/wrappers/wrap_stan.py +1 -1
- {arviz-0.21.0 → arviz-0.23.0}/arviz.egg-info/PKG-INFO +36 -3
- {arviz-0.21.0 → arviz-0.23.0}/arviz.egg-info/requires.txt +6 -6
- {arviz-0.21.0 → arviz-0.23.0}/requirements-dev.txt +1 -0
- {arviz-0.21.0 → arviz-0.23.0}/requirements-optional.txt +1 -1
- {arviz-0.21.0 → arviz-0.23.0}/requirements.txt +5 -5
- arviz-0.21.0/arviz/preview.py +0 -48
- {arviz-0.21.0 → arviz-0.23.0}/CODE_OF_CONDUCT.md +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/GOVERNANCE.md +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/LICENSE +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/MANIFEST.in +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/README.md +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/__init__.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/base.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/datasets.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/example_data/code/radon/radon.json +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/example_data/data/centered_eight.nc +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/example_data/data/non_centered_eight.nc +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/example_data/data_local.json +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/example_data/data_remote.json +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_beanmachine.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_cmdstan.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_cmdstanpy.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_dict.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_emcee.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_json.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_netcdf.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_pyro.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_pystan.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_zarr.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/data/utils.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/labels.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/__init__.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/__init__.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/__init__.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/autocorrplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/bfplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/bpvplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/compareplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/densityplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/distcomparisonplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/distplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/dotplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/ecdfplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/elpdplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/energyplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/essplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/forestplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/kdeplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/khatplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/loopitplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/mcseplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/parallelplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/posteriorplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/ppcplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/rankplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/separationplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/traceplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/violinplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/__init__.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/autocorrplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/bfplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/bpvplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/compareplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/densityplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/distplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/dotplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/ecdfplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/elpdplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/energyplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/essplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/forestplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/hdiplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/kdeplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/loopitplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/mcseplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/parallelplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/posteriorplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/ppcplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/rankplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/separationplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/traceplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/tsplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/violinplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/bfplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/compareplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/densityplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/distcomparisonplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/distplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/ecdfplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/elpdplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/energyplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/essplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/hdiplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/kdeplot.py +4 -4
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/khatplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/loopitplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/mcseplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/parallelplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/plot_utils.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/posteriorplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/rankplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/separationplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-bluish.mplstyle +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-brownish.mplstyle +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-colors.mplstyle +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-cyanish.mplstyle +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-darkgrid.mplstyle +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-doc.mplstyle +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-docgrid.mplstyle +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-grayscale.mplstyle +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-greenish.mplstyle +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-orangish.mplstyle +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-plasmish.mplstyle +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-purplish.mplstyle +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-redish.mplstyle +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-royish.mplstyle +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-viridish.mplstyle +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-white.mplstyle +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-whitegrid.mplstyle +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/traceplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/tsplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/violinplot.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/py.typed +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/sel_utils.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/static/css/style.css +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/static/html/icons-svg-inline.html +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/stats/__init__.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/stats/diagnostics.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/stats/ecdf_utils.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/stats/stats_refitting.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/stats/stats_utils.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/__init__.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/__init__.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_data_zarr.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_diagnostics.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_diagnostics_numba.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_helpers.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_labels.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_plot_utils.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_rcparams.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_stats_numba.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_stats_utils.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_utils.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_utils_numba.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/conftest.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/external_tests/__init__.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/external_tests/test_data_beanmachine.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/external_tests/test_data_cmdstan.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/external_tests/test_data_cmdstanpy.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/external_tests/test_data_emcee.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/external_tests/test_data_pyjags.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/external_tests/test_data_pyro.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/external_tests/test_data_pystan.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/helpers.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/utils.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/wrappers/__init__.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz/wrappers/wrap_pymc.py +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz.egg-info/SOURCES.txt +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz.egg-info/dependency_links.txt +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/arviz.egg-info/top_level.txt +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/pyproject.toml +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/requirements-docs.txt +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/requirements-external.txt +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/requirements-test.txt +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/setup.cfg +0 -0
- {arviz-0.21.0 → arviz-0.23.0}/setup.py +0 -0
|
@@ -1,5 +1,30 @@
|
|
|
1
1
|
# Change Log
|
|
2
2
|
|
|
3
|
+
## v0.23.0 (2025 Dec 9)
|
|
4
|
+
|
|
5
|
+
### Maintenance and fixes
|
|
6
|
+
- Fix numpyro jax incompatibility. ([2465](https://github.com/arviz-devs/arviz/pull/2465))
|
|
7
|
+
- Avoid closing unloaded files in `from_netcdf()` ([2463](https://github.com/arviz-devs/arviz/issues/2463))
|
|
8
|
+
- Fix sign error in lp parsed in from_numpyro ([2468](https://github.com/arviz-devs/arviz/issues/2468))
|
|
9
|
+
- Fix attrs persistence in idata-datatree conversions ([2476](https://github.com/arviz-devs/arviz/issues/2476))
|
|
10
|
+
|
|
11
|
+
## v0.22.0 (2025 Jul 9)
|
|
12
|
+
|
|
13
|
+
### New features
|
|
14
|
+
- `plot_pair` now has more flexible support for `reference_values` ([2438](https://github.com/arviz-devs/arviz/pull/2438))
|
|
15
|
+
- Make `arviz.from_numpyro(..., dims=None)` automatically infer dims from the numpyro model based on its numpyro.plate structure
|
|
16
|
+
|
|
17
|
+
### Maintenance and fixes
|
|
18
|
+
- `reference_values` and `labeller` now work together in `plot_pair` ([2437](https://github.com/arviz-devs/arviz/issues/2437))
|
|
19
|
+
- Fix `plot_lm` for multidimensional data ([2408](https://github.com/arviz-devs/arviz/issues/2408))
|
|
20
|
+
- Add [`scipy-stubs`](https://github.com/scipy/scipy-stubs) as a development dependency ([2445](https://github.com/arviz-devs/arviz/pull/2445))
|
|
21
|
+
- Test compare dataframe stays consistent independently of input order ([2407](https://github.com/arviz-devs/arviz/pull/2407))
|
|
22
|
+
- Fix hdi_probs behaviour in 2d `plot_kde` ([2460](https://github.com/arviz-devs/arviz/pull/2460))
|
|
23
|
+
|
|
24
|
+
### Documentation
|
|
25
|
+
- Added documentation for `reference_values` ([2438](https://github.com/arviz-devs/arviz/pull/2438))
|
|
26
|
+
- Add migration guide page to help switch over to the new `arviz-xyz` libraries ([2459](https://github.com/arviz-devs/arviz/pull/2459))
|
|
27
|
+
|
|
3
28
|
## v0.21.0 (2025 Mar 06)
|
|
4
29
|
|
|
5
30
|
### New features
|
|
@@ -8,7 +33,7 @@
|
|
|
8
33
|
- Make `arviz.data.generate_dims_coords` handle `dims` and `default_dims` consistently ([2395](https://github.com/arviz-devs/arviz/pull/2395))
|
|
9
34
|
- Only emit a warning for custom groups in `InferenceData` when explicitly requested ([2401](https://github.com/arviz-devs/arviz/pull/2401))
|
|
10
35
|
- Splits Bayes Factor computation out from `az.plot_bf` into `az.bayes_factor` ([2402](https://github.com/arviz-devs/arviz/issues/2402))
|
|
11
|
-
- Update `method="sd"` of `mcse` to not use normality assumption ([2167](https://github.com/arviz-devs/arviz/pull/2167))
|
|
36
|
+
- Update `method="sd"` of `mcse` to not use normality assumption ([2167](https://github.com/arviz-devs/arviz/pull/2167))
|
|
12
37
|
- Add exception in `az.plot_hdi` for `x` of type `str` ([2413](https://github.com/arviz-devs/arviz/pull/2413))
|
|
13
38
|
|
|
14
39
|
### Documentation
|
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
# Contributing to ArviZ
|
|
2
2
|
This document outlines only the most common contributions.
|
|
3
3
|
Please see the [Contributing guide](https://python.arviz.org/en/latest/contributing/index.html)
|
|
4
|
-
on our documentation for a better view of how can
|
|
4
|
+
on our documentation for a better view of how you can contribute to ArviZ.
|
|
5
5
|
We welcome a wide range of contributions, not only code!
|
|
6
|
+
Even improving documentation or fixing typos is a valuable contribution to ArviZ.
|
|
6
7
|
|
|
7
8
|
## Reporting issues
|
|
8
9
|
If you encounter any bug or incorrect behaviour while using ArviZ,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: arviz
|
|
3
|
-
Version: 0.21.0
|
|
3
|
+
Version: 0.23.0
|
|
4
4
|
Summary: Exploratory analysis of Bayesian models
|
|
5
5
|
Home-page: http://github.com/arviz-devs/arviz
|
|
6
6
|
Author: ArviZ Developers
|
|
@@ -20,9 +20,42 @@ Classifier: Topic :: Scientific/Engineering :: Visualization
|
|
|
20
20
|
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
21
21
|
Requires-Python: >=3.10
|
|
22
22
|
Description-Content-Type: text/markdown
|
|
23
|
+
License-File: LICENSE
|
|
24
|
+
Requires-Dist: setuptools>=60.0.0
|
|
25
|
+
Requires-Dist: matplotlib>=3.8
|
|
26
|
+
Requires-Dist: numpy>=1.26.0
|
|
27
|
+
Requires-Dist: scipy>=1.11.0
|
|
28
|
+
Requires-Dist: packaging
|
|
29
|
+
Requires-Dist: pandas>=2.1.0
|
|
30
|
+
Requires-Dist: xarray>=2023.7.0
|
|
31
|
+
Requires-Dist: h5netcdf>=1.0.2
|
|
32
|
+
Requires-Dist: typing_extensions>=4.1.0
|
|
33
|
+
Requires-Dist: xarray-einstats>=0.3
|
|
23
34
|
Provides-Extra: all
|
|
35
|
+
Requires-Dist: numba; extra == "all"
|
|
36
|
+
Requires-Dist: netcdf4; extra == "all"
|
|
37
|
+
Requires-Dist: bokeh>=3; extra == "all"
|
|
38
|
+
Requires-Dist: contourpy; extra == "all"
|
|
39
|
+
Requires-Dist: ujson; extra == "all"
|
|
40
|
+
Requires-Dist: dask[distributed]; extra == "all"
|
|
41
|
+
Requires-Dist: zarr<3,>=2.5.0; extra == "all"
|
|
42
|
+
Requires-Dist: xarray>=2024.11.0; extra == "all"
|
|
43
|
+
Requires-Dist: dm-tree>=0.1.8; extra == "all"
|
|
24
44
|
Provides-Extra: preview
|
|
25
|
-
|
|
45
|
+
Requires-Dist: arviz-base[h5netcdf]; extra == "preview"
|
|
46
|
+
Requires-Dist: arviz-stats[xarray]; extra == "preview"
|
|
47
|
+
Requires-Dist: arviz-plots; extra == "preview"
|
|
48
|
+
Dynamic: author
|
|
49
|
+
Dynamic: classifier
|
|
50
|
+
Dynamic: description
|
|
51
|
+
Dynamic: description-content-type
|
|
52
|
+
Dynamic: home-page
|
|
53
|
+
Dynamic: license
|
|
54
|
+
Dynamic: license-file
|
|
55
|
+
Dynamic: provides-extra
|
|
56
|
+
Dynamic: requires-dist
|
|
57
|
+
Dynamic: requires-python
|
|
58
|
+
Dynamic: summary
|
|
26
59
|
|
|
27
60
|
<img src="https://raw.githubusercontent.com/arviz-devs/arviz-project/main/arviz_logos/ArviZ.png#gh-light-mode-only" width=200></img>
|
|
28
61
|
<img src="https://raw.githubusercontent.com/arviz-devs/arviz-project/main/arviz_logos/ArviZ_white.png#gh-dark-mode-only" width=200></img>
|
|
@@ -1,13 +1,54 @@
|
|
|
1
1
|
# pylint: disable=wildcard-import,invalid-name,wrong-import-position
|
|
2
2
|
"""ArviZ is a library for exploratory analysis of Bayesian models."""
|
|
3
|
-
__version__ = "0.
|
|
3
|
+
__version__ = "0.23.0"
|
|
4
4
|
|
|
5
5
|
import logging
|
|
6
6
|
import os
|
|
7
|
+
import re
|
|
7
8
|
|
|
8
9
|
from matplotlib.colors import LinearSegmentedColormap
|
|
9
10
|
from matplotlib.pyplot import style
|
|
10
11
|
import matplotlib as mpl
|
|
12
|
+
from packaging import version
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _warn_once_per_day():
|
|
16
|
+
from .preview import info
|
|
17
|
+
|
|
18
|
+
# skip warning if all 3 arviz subpackages are already installed
|
|
19
|
+
pat = re.compile(r"arviz_(base|stats|plots) available")
|
|
20
|
+
if len(pat.findall(info)) == 3:
|
|
21
|
+
return
|
|
22
|
+
|
|
23
|
+
import datetime
|
|
24
|
+
from warnings import warn
|
|
25
|
+
from pathlib import Path
|
|
26
|
+
|
|
27
|
+
warning_dir = Path.home() / "arviz_data"
|
|
28
|
+
warning_dir.mkdir(exist_ok=True)
|
|
29
|
+
|
|
30
|
+
stamp_file = warning_dir / "daily_warning"
|
|
31
|
+
today = datetime.date.today()
|
|
32
|
+
|
|
33
|
+
if stamp_file.exists():
|
|
34
|
+
last_date = datetime.date.fromisoformat(stamp_file.read_text().strip())
|
|
35
|
+
else:
|
|
36
|
+
last_date = None
|
|
37
|
+
|
|
38
|
+
if last_date != today:
|
|
39
|
+
warn(
|
|
40
|
+
"\nArviZ is undergoing a major refactor to improve flexibility and extensibility "
|
|
41
|
+
"while maintaining a user-friendly interface."
|
|
42
|
+
"\nSome upcoming changes may be backward incompatible."
|
|
43
|
+
"\nFor details and migration guidance, visit: "
|
|
44
|
+
"https://python.arviz.org/en/latest/user_guide/migration_guide.html",
|
|
45
|
+
FutureWarning,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
stamp_file.write_text(today.isoformat())
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
_warn_once_per_day()
|
|
11
52
|
|
|
12
53
|
|
|
13
54
|
class Logger(logging.Logger):
|
|
@@ -41,8 +82,12 @@ from . import preview
|
|
|
41
82
|
|
|
42
83
|
# add ArviZ's styles to matplotlib's styles
|
|
43
84
|
_arviz_style_path = os.path.join(os.path.dirname(__file__), "plots", "styles")
|
|
44
|
-
|
|
45
|
-
style.
|
|
85
|
+
if version.parse(mpl.__version__) >= version.parse("3.11.0.dev0"):
|
|
86
|
+
style.USER_LIBRARY_PATHS.append(_arviz_style_path)
|
|
87
|
+
style.reload_library()
|
|
88
|
+
else:
|
|
89
|
+
style.core.USER_LIBRARY_PATHS.append(_arviz_style_path)
|
|
90
|
+
style.core.reload_library()
|
|
46
91
|
|
|
47
92
|
|
|
48
93
|
if not logging.root.handlers:
|
|
@@ -328,4 +373,4 @@ except ModuleNotFoundError:
|
|
|
328
373
|
|
|
329
374
|
|
|
330
375
|
# clean namespace
|
|
331
|
-
del os, logging, LinearSegmentedColormap, Logger, mpl
|
|
376
|
+
del os, re, logging, version, LinearSegmentedColormap, Logger, mpl
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import xarray as xr
|
|
5
|
+
import pandas as pd
|
|
5
6
|
|
|
6
7
|
try:
|
|
7
8
|
from tree import is_nested
|
|
@@ -44,6 +45,8 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
|
|
|
44
45
|
| dict: creates an xarray dataset as the only group
|
|
45
46
|
| numpy array: creates an xarray dataset as the only group, gives the
|
|
46
47
|
array an arbitrary name
|
|
48
|
+
| object with __array__: converts to numpy array, then creates an xarray dataset as
|
|
49
|
+
the only group, gives the array an arbitrary name
|
|
47
50
|
group : str
|
|
48
51
|
If `obj` is a dict or numpy array, assigns the resulting xarray
|
|
49
52
|
dataset to this group. Default: "posterior".
|
|
@@ -115,6 +118,13 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
|
|
|
115
118
|
dataset = dict_to_dataset(obj, coords=coords, dims=dims)
|
|
116
119
|
elif isinstance(obj, np.ndarray):
|
|
117
120
|
dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
|
|
121
|
+
elif (
|
|
122
|
+
hasattr(obj, "__array__")
|
|
123
|
+
and callable(getattr(obj, "__array__"))
|
|
124
|
+
and (not isinstance(obj, pd.DataFrame))
|
|
125
|
+
):
|
|
126
|
+
obj = obj.__array__()
|
|
127
|
+
dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
|
|
118
128
|
elif isinstance(obj, (list, tuple)) and isinstance(obj[0], str) and obj[0].endswith(".csv"):
|
|
119
129
|
if group == "sample_stats":
|
|
120
130
|
kwargs["posterior"] = kwargs.pop(group)
|
|
@@ -129,6 +139,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
|
|
|
129
139
|
"pytree (if 'dm-tree' is installed)",
|
|
130
140
|
"netcdf filename",
|
|
131
141
|
"numpy array",
|
|
142
|
+
"object with __array__",
|
|
132
143
|
"pystan fit",
|
|
133
144
|
"emcee fit",
|
|
134
145
|
"pyro mcmc fit",
|
|
@@ -430,11 +430,12 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
430
430
|
if re.search(key, group):
|
|
431
431
|
group_kws = kws
|
|
432
432
|
group_kws.setdefault("engine", engine)
|
|
433
|
-
|
|
434
|
-
|
|
433
|
+
data = xr.open_dataset(filename, group=f"{base_group}/{group}", **group_kws)
|
|
434
|
+
if rcParams["data.load"] == "eager":
|
|
435
|
+
with data:
|
|
435
436
|
groups[group] = data.load()
|
|
436
|
-
|
|
437
|
-
|
|
437
|
+
else:
|
|
438
|
+
groups[group] = data
|
|
438
439
|
|
|
439
440
|
with xr.open_dataset(filename, engine=engine, group=base_group) as data:
|
|
440
441
|
attrs.update(data.load().attrs)
|
|
@@ -532,24 +533,30 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
532
533
|
return filename
|
|
533
534
|
|
|
534
535
|
def to_datatree(self):
|
|
535
|
-
"""Convert InferenceData object to a :class:`~
|
|
536
|
+
"""Convert InferenceData object to a :class:`~xarray.DataTree`."""
|
|
536
537
|
try:
|
|
537
|
-
from
|
|
538
|
-
except
|
|
539
|
-
raise
|
|
540
|
-
"
|
|
538
|
+
from xarray import DataTree
|
|
539
|
+
except ImportError as err:
|
|
540
|
+
raise ImportError(
|
|
541
|
+
"xarray must be have DataTree in order to use InferenceData.to_datatree. "
|
|
542
|
+
"Update to xarray>=2024.11.0"
|
|
541
543
|
) from err
|
|
542
|
-
|
|
544
|
+
dt = DataTree.from_dict({group: ds for group, ds in self.items()})
|
|
545
|
+
dt.attrs = self.attrs
|
|
546
|
+
return dt
|
|
543
547
|
|
|
544
548
|
@staticmethod
|
|
545
549
|
def from_datatree(datatree):
|
|
546
|
-
"""Create an InferenceData object from a :class:`~
|
|
550
|
+
"""Create an InferenceData object from a :class:`~xarray.DataTree`.
|
|
547
551
|
|
|
548
552
|
Parameters
|
|
549
553
|
----------
|
|
550
554
|
datatree : DataTree
|
|
551
555
|
"""
|
|
552
|
-
return InferenceData(
|
|
556
|
+
return InferenceData(
|
|
557
|
+
attrs=datatree.attrs,
|
|
558
|
+
**{group: child.to_dataset() for group, child in datatree.children.items()},
|
|
559
|
+
)
|
|
553
560
|
|
|
554
561
|
def to_dict(self, groups=None, filter_groups=None):
|
|
555
562
|
"""Convert InferenceData to a dictionary following xarray naming conventions.
|
|
@@ -797,12 +804,20 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
797
804
|
----------
|
|
798
805
|
https://zarr.readthedocs.io/
|
|
799
806
|
"""
|
|
800
|
-
try:
|
|
807
|
+
try:
|
|
801
808
|
import zarr
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
raise ImportError(
|
|
809
|
+
except ImportError as err:
|
|
810
|
+
raise ImportError("'to_zarr' method needs Zarr (>=2.5.0,<3) installed.") from err
|
|
811
|
+
if version.parse(zarr.__version__) < version.parse("2.5.0"):
|
|
812
|
+
raise ImportError(
|
|
813
|
+
"Found zarr<2.5.0, please upgrade to a zarr (>=2.5.0,<3) to use 'to_zarr'"
|
|
814
|
+
)
|
|
815
|
+
if version.parse(zarr.__version__) >= version.parse("3.0.0.dev0"):
|
|
816
|
+
raise ImportError(
|
|
817
|
+
"Found zarr>=3, which is not supported by ArviZ. Instead, you can use "
|
|
818
|
+
"'dt = InferenceData.to_datatree' followed by 'dt.to_zarr()' "
|
|
819
|
+
"(needs xarray>=2024.11.0)"
|
|
820
|
+
)
|
|
806
821
|
|
|
807
822
|
# Check store type and create store if necessary
|
|
808
823
|
if store is None:
|
|
@@ -851,10 +866,18 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
851
866
|
"""
|
|
852
867
|
try:
|
|
853
868
|
import zarr
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
raise ImportError(
|
|
869
|
+
except ImportError as err:
|
|
870
|
+
raise ImportError("'from_zarr' method needs Zarr (>=2.5.0,<3) installed.") from err
|
|
871
|
+
if version.parse(zarr.__version__) < version.parse("2.5.0"):
|
|
872
|
+
raise ImportError(
|
|
873
|
+
"Found zarr<2.5.0, please upgrade to a zarr (>=2.5.0,<3) to use 'from_zarr'"
|
|
874
|
+
)
|
|
875
|
+
if version.parse(zarr.__version__) >= version.parse("3.0.0.dev0"):
|
|
876
|
+
raise ImportError(
|
|
877
|
+
"Found zarr>=3, which is not supported by ArviZ. Instead, you can use "
|
|
878
|
+
"'xarray.open_datatree' followed by 'arviz.InferenceData.from_datatree' "
|
|
879
|
+
"(needs xarray>=2024.11.0)"
|
|
880
|
+
)
|
|
858
881
|
|
|
859
882
|
# Check store type and create store if necessary
|
|
860
883
|
if isinstance(store, str):
|
|
@@ -1531,9 +1554,8 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1531
1554
|
import xarray as xr
|
|
1532
1555
|
from xarray_einstats.stats import XrDiscreteRV
|
|
1533
1556
|
from scipy.stats import poisson
|
|
1534
|
-
dist = XrDiscreteRV(poisson)
|
|
1535
|
-
log_lik =
|
|
1536
|
-
log_lik["home_points"] = dist.logpmf(obs["home_points"], np.exp(post["atts"]))
|
|
1557
|
+
dist = XrDiscreteRV(poisson, np.exp(post["atts"]))
|
|
1558
|
+
log_lik = dist.logpmf(obs["home_points"]).to_dataset(name="home_points")
|
|
1537
1559
|
idata2.add_groups({"log_likelihood": log_lik})
|
|
1538
1560
|
idata2
|
|
1539
1561
|
|
|
@@ -4,7 +4,7 @@ from .inference_data import InferenceData
|
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
def to_datatree(data):
|
|
7
|
-
"""Convert InferenceData object to a :class:`~
|
|
7
|
+
"""Convert InferenceData object to a :class:`~xarray.DataTree`.
|
|
8
8
|
|
|
9
9
|
Parameters
|
|
10
10
|
----------
|
|
@@ -14,7 +14,7 @@ def to_datatree(data):
|
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
def from_datatree(datatree):
|
|
17
|
-
"""Create an InferenceData object from a :class:`~
|
|
17
|
+
"""Create an InferenceData object from a :class:`~xarray.DataTree`.
|
|
18
18
|
|
|
19
19
|
Parameters
|
|
20
20
|
----------
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""NumPyro-specific conversion code."""
|
|
2
2
|
|
|
3
|
+
from collections import defaultdict
|
|
3
4
|
import logging
|
|
4
|
-
from typing import Callable, Optional
|
|
5
|
+
from typing import Any, Callable, Optional, Dict, List, Tuple
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
|
|
@@ -13,6 +14,70 @@ from .inference_data import InferenceData
|
|
|
13
14
|
_log = logging.getLogger(__name__)
|
|
14
15
|
|
|
15
16
|
|
|
17
|
+
def _add_dims(dims_a: Dict[str, List[str]], dims_b: Dict[str, List[str]]) -> Dict[str, List[str]]:
|
|
18
|
+
merged = defaultdict(list)
|
|
19
|
+
|
|
20
|
+
for k, v in dims_a.items():
|
|
21
|
+
merged[k].extend(v)
|
|
22
|
+
|
|
23
|
+
for k, v in dims_b.items():
|
|
24
|
+
merged[k].extend(v)
|
|
25
|
+
|
|
26
|
+
# Convert back to a regular dict
|
|
27
|
+
return dict(merged)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def infer_dims(
|
|
31
|
+
model: Callable,
|
|
32
|
+
model_args: Optional[Tuple[Any, ...]] = None,
|
|
33
|
+
model_kwargs: Optional[Dict[str, Any]] = None,
|
|
34
|
+
) -> Dict[str, List[str]]:
|
|
35
|
+
|
|
36
|
+
from numpyro import handlers, distributions as dist
|
|
37
|
+
from numpyro.ops.pytree import PytreeTrace
|
|
38
|
+
from numpyro.infer.initialization import init_to_sample
|
|
39
|
+
import jax
|
|
40
|
+
|
|
41
|
+
model_args = tuple() if model_args is None else model_args
|
|
42
|
+
model_kwargs = dict() if model_kwargs is None else model_kwargs  # test model_kwargs itself; model_args is never None at this point
|
|
43
|
+
|
|
44
|
+
def _get_dist_name(fn):
|
|
45
|
+
if isinstance(fn, (dist.Independent, dist.ExpandedDistribution, dist.MaskedDistribution)):
|
|
46
|
+
return _get_dist_name(fn.base_dist)
|
|
47
|
+
return type(fn).__name__
|
|
48
|
+
|
|
49
|
+
def get_trace():
|
|
50
|
+
# We use `init_to_sample` to get around ImproperUniform distribution,
|
|
51
|
+
# which does not have `sample` method.
|
|
52
|
+
subs_model = handlers.substitute(
|
|
53
|
+
handlers.seed(model, 0),
|
|
54
|
+
substitute_fn=init_to_sample,
|
|
55
|
+
)
|
|
56
|
+
trace = handlers.trace(subs_model).get_trace(*model_args, **model_kwargs)
|
|
57
|
+
# Work around an issue where jax.eval_shape does not work
|
|
58
|
+
# for distribution output (e.g. the function `lambda: dist.Normal(0, 1)`)
|
|
59
|
+
# Here we will remove `fn` and store its name in the trace.
|
|
60
|
+
for _, site in trace.items():
|
|
61
|
+
if site["type"] == "sample":
|
|
62
|
+
site["fn_name"] = _get_dist_name(site.pop("fn"))
|
|
63
|
+
elif site["type"] == "deterministic":
|
|
64
|
+
site["fn_name"] = "Deterministic"
|
|
65
|
+
return PytreeTrace(trace)
|
|
66
|
+
|
|
67
|
+
# We use eval_shape to avoid any array computation.
|
|
68
|
+
trace = jax.eval_shape(get_trace).trace
|
|
69
|
+
|
|
70
|
+
named_dims = {}
|
|
71
|
+
|
|
72
|
+
for name, site in trace.items():
|
|
73
|
+
batch_dims = [frame.name for frame in sorted(site["cond_indep_stack"], key=lambda x: x.dim)]
|
|
74
|
+
event_dims = list(site.get("infer", {}).get("event_dims", []))
|
|
75
|
+
if site["type"] in ["sample", "deterministic"] and (batch_dims or event_dims):
|
|
76
|
+
named_dims[name] = batch_dims + event_dims
|
|
77
|
+
|
|
78
|
+
return named_dims
|
|
79
|
+
|
|
80
|
+
|
|
16
81
|
class NumPyroConverter:
|
|
17
82
|
"""Encapsulate NumPyro specific logic."""
|
|
18
83
|
|
|
@@ -36,6 +101,7 @@ class NumPyroConverter:
|
|
|
36
101
|
coords=None,
|
|
37
102
|
dims=None,
|
|
38
103
|
pred_dims=None,
|
|
104
|
+
extra_event_dims=None,
|
|
39
105
|
num_chains=1,
|
|
40
106
|
):
|
|
41
107
|
"""Convert NumPyro data into an InferenceData object.
|
|
@@ -58,9 +124,12 @@ class NumPyroConverter:
|
|
|
58
124
|
coords : dict[str] -> list[str]
|
|
59
125
|
Map of dimensions to coordinates
|
|
60
126
|
dims : dict[str] -> list[str]
|
|
61
|
-
Map variable names to their coordinates
|
|
127
|
+
Map variable names to their coordinates. Will be inferred if they are not provided.
|
|
62
128
|
pred_dims: dict
|
|
63
129
|
Dims for predictions data. Map variable names to their coordinates.
|
|
130
|
+
extra_event_dims: dict
|
|
131
|
+
Extra event dims for deterministic sites. Maps event dims that couldn't be inferred to
|
|
132
|
+
their coordinates.
|
|
64
133
|
num_chains: int
|
|
65
134
|
Number of chains used for sampling. Ignored if posterior is present.
|
|
66
135
|
"""
|
|
@@ -80,6 +149,7 @@ class NumPyroConverter:
|
|
|
80
149
|
self.coords = coords
|
|
81
150
|
self.dims = dims
|
|
82
151
|
self.pred_dims = pred_dims
|
|
152
|
+
self.extra_event_dims = extra_event_dims
|
|
83
153
|
self.numpyro = numpyro
|
|
84
154
|
|
|
85
155
|
def arbitrary_element(dct):
|
|
@@ -107,6 +177,10 @@ class NumPyroConverter:
|
|
|
107
177
|
# model arguments and keyword arguments
|
|
108
178
|
self._args = self.posterior._args # pylint: disable=protected-access
|
|
109
179
|
self._kwargs = self.posterior._kwargs # pylint: disable=protected-access
|
|
180
|
+
self.dims = self.dims if self.dims is not None else self.infer_dims()
|
|
181
|
+
self.pred_dims = (
|
|
182
|
+
self.pred_dims if self.pred_dims is not None else self.infer_pred_dims()
|
|
183
|
+
)
|
|
110
184
|
else:
|
|
111
185
|
self.nchains = num_chains
|
|
112
186
|
get_from = None
|
|
@@ -167,7 +241,10 @@ class NumPyroConverter:
|
|
|
167
241
|
continue
|
|
168
242
|
name = rename_key.get(stat, stat)
|
|
169
243
|
value = value.copy()
|
|
170
|
-
|
|
244
|
+
if stat == "potential_energy":
|
|
245
|
+
data[name] = -value
|
|
246
|
+
else:
|
|
247
|
+
data[name] = value
|
|
171
248
|
if stat == "num_steps":
|
|
172
249
|
data["tree_depth"] = np.log2(value).astype(int) + 1
|
|
173
250
|
return dict_to_dataset(
|
|
@@ -325,6 +402,23 @@ class NumPyroConverter:
|
|
|
325
402
|
}
|
|
326
403
|
)
|
|
327
404
|
|
|
405
|
+
@requires("posterior")
|
|
406
|
+
@requires("model")
|
|
407
|
+
def infer_dims(self) -> Dict[str, List[str]]:
|
|
408
|
+
dims = infer_dims(self.model, self._args, self._kwargs)
|
|
409
|
+
if self.extra_event_dims:
|
|
410
|
+
dims = _add_dims(dims, self.extra_event_dims)
|
|
411
|
+
return dims
|
|
412
|
+
|
|
413
|
+
@requires("posterior")
|
|
414
|
+
@requires("model")
|
|
415
|
+
@requires("predictions")
|
|
416
|
+
def infer_pred_dims(self) -> Dict[str, List[str]]:
|
|
417
|
+
dims = infer_dims(self.model, self._args, self._kwargs)
|
|
418
|
+
if self.extra_event_dims:
|
|
419
|
+
dims = _add_dims(dims, self.extra_event_dims)
|
|
420
|
+
return dims
|
|
421
|
+
|
|
328
422
|
|
|
329
423
|
def from_numpyro(
|
|
330
424
|
posterior=None,
|
|
@@ -339,10 +433,25 @@ def from_numpyro(
|
|
|
339
433
|
coords=None,
|
|
340
434
|
dims=None,
|
|
341
435
|
pred_dims=None,
|
|
436
|
+
extra_event_dims=None,
|
|
342
437
|
num_chains=1,
|
|
343
438
|
):
|
|
344
439
|
"""Convert NumPyro data into an InferenceData object.
|
|
345
440
|
|
|
441
|
+
If no dims are provided, this will infer batch dim names from NumPyro model plates.
|
|
442
|
+
For event dim names, such as with the ZeroSumNormal, `infer={"event_dims":dim_names}`
|
|
443
|
+
can be provided in numpyro.sample, i.e.::
|
|
444
|
+
|
|
445
|
+
# equivalent to dims entry, {"gamma": ["groups"]}
|
|
446
|
+
gamma = numpyro.sample(
|
|
447
|
+
"gamma",
|
|
448
|
+
dist.ZeroSumNormal(1, event_shape=(n_groups,)),
|
|
449
|
+
infer={"event_dims":["groups"]}
|
|
450
|
+
)
|
|
451
|
+
|
|
452
|
+
There is also an additional `extra_event_dims` input to cover any edge cases, for instance
|
|
453
|
+
deterministic sites with event dims (which don't have an `infer` argument to provide metadata).
|
|
454
|
+
|
|
346
455
|
For a usage example read the
|
|
347
456
|
:ref:`Creating InferenceData section on from_numpyro <creating_InferenceData>`
|
|
348
457
|
|
|
@@ -364,9 +473,10 @@ def from_numpyro(
|
|
|
364
473
|
coords : dict[str] -> list[str]
|
|
365
474
|
Map of dimensions to coordinates
|
|
366
475
|
dims : dict[str] -> list[str]
|
|
367
|
-
Map variable names to their coordinates
|
|
476
|
+
Map variable names to their coordinates. Will be inferred if they are not provided.
|
|
368
477
|
pred_dims: dict
|
|
369
|
-
Dims for predictions data. Map variable names to their coordinates.
|
|
478
|
+
Dims for predictions data. Map variable names to their coordinates. Default behavior is to
|
|
479
|
+
infer dims if this is not provided
|
|
370
480
|
num_chains: int
|
|
371
481
|
Number of chains used for sampling. Ignored if posterior is present.
|
|
372
482
|
"""
|
|
@@ -382,5 +492,6 @@ def from_numpyro(
|
|
|
382
492
|
coords=coords,
|
|
383
493
|
dims=dims,
|
|
384
494
|
pred_dims=pred_dims,
|
|
495
|
+
extra_event_dims=extra_event_dims,
|
|
385
496
|
num_chains=num_chains,
|
|
386
497
|
).to_inference_data()
|
|
@@ -277,7 +277,7 @@ def _extract_arviz_dict_from_inference_data(
|
|
|
277
277
|
|
|
278
278
|
|
|
279
279
|
def _convert_arviz_dict_to_pyjags_dict(
|
|
280
|
-
samples: tp.Mapping[str, np.ndarray]
|
|
280
|
+
samples: tp.Mapping[str, np.ndarray],
|
|
281
281
|
) -> tp.Mapping[str, np.ndarray]:
|
|
282
282
|
"""
|
|
283
283
|
Convert an ArviZ dictionary to a PyJAGS dictionary.
|
|
@@ -4,7 +4,7 @@ from ..data import convert_to_dataset
|
|
|
4
4
|
from ..labels import BaseLabeller
|
|
5
5
|
from ..sel_utils import xarray_var_iter
|
|
6
6
|
from ..rcparams import rcParams
|
|
7
|
-
from ..utils import _var_names
|
|
7
|
+
from ..utils import _var_names, get_coords
|
|
8
8
|
from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
|
|
9
9
|
|
|
10
10
|
|
|
@@ -14,6 +14,7 @@ def plot_autocorr(
|
|
|
14
14
|
filter_vars=None,
|
|
15
15
|
max_lag=None,
|
|
16
16
|
combined=False,
|
|
17
|
+
coords=None,
|
|
17
18
|
grid=None,
|
|
18
19
|
figsize=None,
|
|
19
20
|
textsize=None,
|
|
@@ -42,6 +43,8 @@ def plot_autocorr(
|
|
|
42
43
|
interpret `var_names` as substrings of the real variables names. If "regex",
|
|
43
44
|
interpret `var_names` as regular expressions on the real variables names. See
|
|
44
45
|
:ref:`this section <common_filter_vars>` for usage examples.
|
|
46
|
+
coords: mapping, optional
|
|
47
|
+
Coordinates of var_names to be plotted. Passed to :meth:`xarray.Dataset.sel`
|
|
45
48
|
max_lag : int, optional
|
|
46
49
|
Maximum lag to calculate autocorrelation. By default, the plot displays the
|
|
47
50
|
first 100 lags or the total number of draws, whichever is smaller.
|
|
@@ -124,11 +127,18 @@ def plot_autocorr(
|
|
|
124
127
|
if max_lag is None:
|
|
125
128
|
max_lag = min(100, data["draw"].shape[0])
|
|
126
129
|
|
|
130
|
+
if coords is None:
|
|
131
|
+
coords = {}
|
|
132
|
+
|
|
127
133
|
if labeller is None:
|
|
128
134
|
labeller = BaseLabeller()
|
|
129
135
|
|
|
130
136
|
plotters = filter_plotters_list(
|
|
131
|
-
list(
|
|
137
|
+
list(
|
|
138
|
+
xarray_var_iter(
|
|
139
|
+
get_coords(data, coords), var_names, combined, dim_order=["chain", "draw"]
|
|
140
|
+
)
|
|
141
|
+
),
|
|
132
142
|
"plot_autocorr",
|
|
133
143
|
)
|
|
134
144
|
rows, cols = default_grid(len(plotters), grid=grid)
|
|
@@ -21,9 +21,13 @@ def plot_hdi(ax, x_data, y_data, color, figsize, plot_kwargs, fill_kwargs, backe
|
|
|
21
21
|
plot_kwargs["color"] = vectorized_to_hex(plot_kwargs.get("color", color))
|
|
22
22
|
plot_kwargs.setdefault("alpha", 0)
|
|
23
23
|
|
|
24
|
-
fill_kwargs = {} if fill_kwargs is None else fill_kwargs
|
|
25
|
-
|
|
26
|
-
fill_kwargs
|
|
24
|
+
fill_kwargs = {} if fill_kwargs is None else fill_kwargs.copy()
|
|
25
|
+
# Convert matplotlib color to bokeh fill_color if needed
|
|
26
|
+
if "color" in fill_kwargs and "fill_color" not in fill_kwargs:
|
|
27
|
+
fill_kwargs["fill_color"] = vectorized_to_hex(fill_kwargs.pop("color"))
|
|
28
|
+
else:
|
|
29
|
+
fill_kwargs["fill_color"] = vectorized_to_hex(fill_kwargs.get("fill_color", color))
|
|
30
|
+
fill_kwargs.setdefault("fill_alpha", fill_kwargs.pop("alpha", 0.5))
|
|
27
31
|
|
|
28
32
|
figsize, *_ = _scale_fig_size(figsize, None)
|
|
29
33
|
|
|
@@ -38,9 +42,6 @@ def plot_hdi(ax, x_data, y_data, color, figsize, plot_kwargs, fill_kwargs, backe
|
|
|
38
42
|
plot_kwargs.setdefault("line_color", plot_kwargs.pop("color"))
|
|
39
43
|
plot_kwargs.setdefault("line_alpha", plot_kwargs.pop("alpha", 0))
|
|
40
44
|
|
|
41
|
-
fill_kwargs.setdefault("fill_color", fill_kwargs.pop("color"))
|
|
42
|
-
fill_kwargs.setdefault("fill_alpha", fill_kwargs.pop("alpha", 0))
|
|
43
|
-
|
|
44
45
|
ax.patch(
|
|
45
46
|
np.concatenate((x_data, x_data[::-1])),
|
|
46
47
|
np.concatenate((y_data[:, 0], y_data[:, 1][::-1])),
|