arviz 0.20.0__tar.gz → 0.22.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {arviz-0.20.0 → arviz-0.22.0}/CHANGELOG.md +35 -0
- {arviz-0.20.0 → arviz-0.22.0}/CONTRIBUTING.md +2 -1
- {arviz-0.20.0 → arviz-0.22.0}/PKG-INFO +1 -1
- {arviz-0.20.0 → arviz-0.22.0}/arviz/__init__.py +8 -3
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/base.py +2 -2
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/inference_data.py +57 -26
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/io_datatree.py +2 -2
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/io_numpyro.py +112 -4
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/autocorrplot.py +12 -2
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/__init__.py +8 -7
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/bpvplot.py +4 -3
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/densityplot.py +5 -1
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/dotplot.py +5 -2
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/essplot.py +4 -2
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/forestplot.py +11 -4
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/hdiplot.py +7 -6
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/khatplot.py +4 -2
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/lmplot.py +28 -6
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/mcseplot.py +2 -2
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/pairplot.py +27 -52
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/ppcplot.py +2 -1
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/rankplot.py +2 -1
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/traceplot.py +2 -1
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/violinplot.py +2 -1
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/bpvplot.py +2 -1
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/khatplot.py +8 -1
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/lmplot.py +13 -7
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/pairplot.py +14 -22
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/bfplot.py +9 -26
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/bpvplot.py +10 -1
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/hdiplot.py +5 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/lmplot.py +41 -14
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/pairplot.py +10 -3
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/plot_utils.py +5 -3
- arviz-0.22.0/arviz/preview.py +48 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/stats/__init__.py +1 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/stats/density_utils.py +1 -1
- {arviz-0.20.0 → arviz-0.22.0}/arviz/stats/diagnostics.py +18 -14
- {arviz-0.20.0 → arviz-0.22.0}/arviz/stats/stats.py +105 -7
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/base_tests/test_data.py +31 -11
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/base_tests/test_diagnostics.py +5 -4
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/base_tests/test_plots_bokeh.py +60 -2
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/base_tests/test_plots_matplotlib.py +103 -11
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/base_tests/test_stats.py +53 -1
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/external_tests/test_data_numpyro.py +130 -3
- {arviz-0.20.0 → arviz-0.22.0}/arviz/utils.py +4 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/wrappers/base.py +1 -1
- {arviz-0.20.0 → arviz-0.22.0}/arviz/wrappers/wrap_stan.py +1 -1
- {arviz-0.20.0 → arviz-0.22.0}/arviz.egg-info/PKG-INFO +1 -1
- {arviz-0.20.0 → arviz-0.22.0}/arviz.egg-info/requires.txt +6 -6
- {arviz-0.20.0 → arviz-0.22.0}/requirements-dev.txt +1 -0
- {arviz-0.20.0 → arviz-0.22.0}/requirements-optional.txt +1 -1
- {arviz-0.20.0 → arviz-0.22.0}/requirements.txt +5 -5
- arviz-0.20.0/arviz/preview.py +0 -17
- {arviz-0.20.0 → arviz-0.22.0}/CODE_OF_CONDUCT.md +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/GOVERNANCE.md +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/LICENSE +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/MANIFEST.in +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/README.md +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/__init__.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/converters.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/datasets.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/example_data/code/radon/radon.json +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/example_data/data/centered_eight.nc +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/example_data/data/non_centered_eight.nc +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/example_data/data_local.json +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/example_data/data_remote.json +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/io_beanmachine.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/io_cmdstan.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/io_cmdstanpy.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/io_dict.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/io_emcee.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/io_json.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/io_netcdf.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/io_pyjags.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/io_pyro.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/io_pystan.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/io_zarr.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/data/utils.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/labels.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/__init__.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/__init__.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/autocorrplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/bfplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/compareplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/distcomparisonplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/distplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/ecdfplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/elpdplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/energyplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/kdeplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/loopitplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/parallelplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/posteriorplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/separationplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/__init__.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/autocorrplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/bfplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/compareplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/densityplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/distplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/dotplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/ecdfplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/elpdplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/energyplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/essplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/forestplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/hdiplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/kdeplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/loopitplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/mcseplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/parallelplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/posteriorplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/ppcplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/rankplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/separationplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/traceplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/tsplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/violinplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/compareplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/densityplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/distcomparisonplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/distplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/dotplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/ecdfplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/elpdplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/energyplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/essplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/forestplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/kdeplot.py +4 -4
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/khatplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/loopitplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/mcseplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/parallelplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/posteriorplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/ppcplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/rankplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/separationplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/styles/arviz-bluish.mplstyle +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/styles/arviz-brownish.mplstyle +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/styles/arviz-colors.mplstyle +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/styles/arviz-cyanish.mplstyle +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/styles/arviz-darkgrid.mplstyle +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/styles/arviz-doc.mplstyle +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/styles/arviz-docgrid.mplstyle +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/styles/arviz-grayscale.mplstyle +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/styles/arviz-greenish.mplstyle +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/styles/arviz-orangish.mplstyle +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/styles/arviz-plasmish.mplstyle +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/styles/arviz-purplish.mplstyle +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/styles/arviz-redish.mplstyle +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/styles/arviz-royish.mplstyle +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/styles/arviz-viridish.mplstyle +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/styles/arviz-white.mplstyle +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/styles/arviz-whitegrid.mplstyle +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/traceplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/tsplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/plots/violinplot.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/py.typed +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/rcparams.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/sel_utils.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/static/css/style.css +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/static/html/icons-svg-inline.html +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/stats/ecdf_utils.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/stats/stats_refitting.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/stats/stats_utils.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/__init__.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/base_tests/__init__.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/base_tests/test_data_zarr.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/base_tests/test_diagnostics_numba.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/base_tests/test_helpers.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/base_tests/test_labels.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/base_tests/test_plot_utils.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/base_tests/test_rcparams.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/base_tests/test_stats_numba.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/base_tests/test_stats_utils.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/base_tests/test_utils.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/base_tests/test_utils_numba.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/conftest.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/external_tests/__init__.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/external_tests/test_data_beanmachine.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/external_tests/test_data_cmdstan.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/external_tests/test_data_cmdstanpy.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/external_tests/test_data_emcee.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/external_tests/test_data_pyjags.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/external_tests/test_data_pyro.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/external_tests/test_data_pystan.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/tests/helpers.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/wrappers/__init__.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz/wrappers/wrap_pymc.py +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz.egg-info/SOURCES.txt +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz.egg-info/dependency_links.txt +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/arviz.egg-info/top_level.txt +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/pyproject.toml +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/requirements-docs.txt +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/requirements-external.txt +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/requirements-test.txt +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/setup.cfg +0 -0
- {arviz-0.20.0 → arviz-0.22.0}/setup.py +0 -0
|
@@ -1,5 +1,40 @@
|
|
|
1
1
|
# Change Log
|
|
2
2
|
|
|
3
|
+
## v0.22.0 (2025 Jul 9)
|
|
4
|
+
|
|
5
|
+
### New features
|
|
6
|
+
- `plot_pair` now has more flexible support for `reference_values` ([2438](https://github.com/arviz-devs/arviz/pull/2438))
|
|
7
|
+
- Make `arviz.from_numpyro(..., dims=None)` automatically infer dims from the numpyro model based on its `numpyro.plate` structure
|
|
8
|
+
|
|
9
|
+
### Maintenance and fixes
|
|
10
|
+
- `reference_values` and `labeller` now work together in `plot_pair` ([2437](https://github.com/arviz-devs/arviz/issues/2437))
|
|
11
|
+
- Fix `plot_lm` for multidimensional data ([2408](https://github.com/arviz-devs/arviz/issues/2408))
|
|
12
|
+
- Add [`scipy-stubs`](https://github.com/scipy/scipy-stubs) as a development dependency ([2445](https://github.com/arviz-devs/arviz/pull/2445))
|
|
13
|
+
- Test that the compare dataframe stays consistent regardless of input order ([2407](https://github.com/arviz-devs/arviz/pull/2407))
|
|
14
|
+
- Fix `hdi_probs` behaviour in 2D `plot_kde` ([2460](https://github.com/arviz-devs/arviz/pull/2460))
|
|
15
|
+
|
|
16
|
+
### Documentation
|
|
17
|
+
- Add documentation for `reference_values` ([2438](https://github.com/arviz-devs/arviz/pull/2438))
|
|
18
|
+
- Add migration guide page to help switch over to the new `arviz-xyz` libraries ([2459](https://github.com/arviz-devs/arviz/pull/2459))
|
|
19
|
+
|
|
20
|
+
## v0.21.0 (2025 Mar 06)
|
|
21
|
+
|
|
22
|
+
### New features
|
|
23
|
+
|
|
24
|
+
### Maintenance and fixes
|
|
25
|
+
- Make `arviz.data.generate_dims_coords` handle `dims` and `default_dims` consistently ([2395](https://github.com/arviz-devs/arviz/pull/2395))
|
|
26
|
+
- Only emit a warning for custom groups in `InferenceData` when explicitly requested ([2401](https://github.com/arviz-devs/arviz/pull/2401))
|
|
27
|
+
- Split Bayes factor computation out of `az.plot_bf` into `az.bayes_factor` ([2402](https://github.com/arviz-devs/arviz/issues/2402))
|
|
28
|
+
- Update `method="sd"` of `mcse` to not use normality assumption ([2167](https://github.com/arviz-devs/arviz/pull/2167))
|
|
29
|
+
- Add exception in `az.plot_hdi` for `x` of type `str` ([2413](https://github.com/arviz-devs/arviz/pull/2413))
|
|
30
|
+
|
|
31
|
+
### Documentation
|
|
32
|
+
- Add example of ECDF comparison plot to gallery ([2178](https://github.com/arviz-devs/arviz/pull/2178))
|
|
33
|
+
- Change Twitter to X, including the icon ([2418](https://github.com/arviz-devs/arviz/pull/2418))
|
|
34
|
+
- Update Bokeh link in Installation.rst ([2425](https://github.com/arviz-devs/arviz/pull/2425))
|
|
35
|
+
- Add missing periods to the ArviZ community page ([2426](https://github.com/arviz-devs/arviz/pull/2426))
|
|
36
|
+
- Fix missing docstring ([2430](https://github.com/arviz-devs/arviz/pull/2430))
|
|
37
|
+
|
|
3
38
|
## v0.20.0 (2024 Sep 28)
|
|
4
39
|
|
|
5
40
|
### New features
|
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
# Contributing to ArviZ
|
|
2
2
|
This document outlines only the most common contributions.
|
|
3
3
|
Please see the [Contributing guide](https://python.arviz.org/en/latest/contributing/index.html)
|
|
4
|
-
on our documentation for a better view of how can
|
|
4
|
+
on our documentation for a better view of how you can contribute to ArviZ.
|
|
5
5
|
We welcome a wide range of contributions, not only code!
|
|
6
|
+
Even improving documentation or fixing typos is a valuable contribution to ArviZ.
|
|
6
7
|
|
|
7
8
|
## Reporting issues
|
|
8
9
|
If you encounter any bug or incorrect behaviour while using ArviZ,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# pylint: disable=wildcard-import,invalid-name,wrong-import-position
|
|
2
2
|
"""ArviZ is a library for exploratory analysis of Bayesian models."""
|
|
3
|
-
__version__ = "0.
|
|
3
|
+
__version__ = "0.22.0"
|
|
4
4
|
|
|
5
5
|
import logging
|
|
6
6
|
import os
|
|
@@ -8,6 +8,7 @@ import os
|
|
|
8
8
|
from matplotlib.colors import LinearSegmentedColormap
|
|
9
9
|
from matplotlib.pyplot import style
|
|
10
10
|
import matplotlib as mpl
|
|
11
|
+
from packaging import version
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class Logger(logging.Logger):
|
|
@@ -41,8 +42,12 @@ from . import preview
|
|
|
41
42
|
|
|
42
43
|
# add ArviZ's styles to matplotlib's styles
|
|
43
44
|
_arviz_style_path = os.path.join(os.path.dirname(__file__), "plots", "styles")
|
|
44
|
-
|
|
45
|
-
style.
|
|
45
|
+
if version.parse(mpl.__version__) >= version.parse("3.11.0.dev0"):
|
|
46
|
+
style.USER_LIBRARY_PATHS.append(_arviz_style_path)
|
|
47
|
+
style.reload_library()
|
|
48
|
+
else:
|
|
49
|
+
style.core.USER_LIBRARY_PATHS.append(_arviz_style_path)
|
|
50
|
+
style.core.reload_library()
|
|
46
51
|
|
|
47
52
|
|
|
48
53
|
if not logging.root.handlers:
|
|
@@ -201,10 +201,10 @@ def generate_dims_coords(
|
|
|
201
201
|
for i, dim_len in enumerate(shape):
|
|
202
202
|
idx = i + len([dim for dim in default_dims if dim in dims])
|
|
203
203
|
if len(dims) < idx + 1:
|
|
204
|
-
dim_name = f"{var_name}_dim_{
|
|
204
|
+
dim_name = f"{var_name}_dim_{i}"
|
|
205
205
|
dims.append(dim_name)
|
|
206
206
|
elif dims[idx] is None:
|
|
207
|
-
dim_name = f"{var_name}_dim_{
|
|
207
|
+
dim_name = f"{var_name}_dim_{i}"
|
|
208
208
|
dims[idx] = dim_name
|
|
209
209
|
dim_name = dims[idx]
|
|
210
210
|
if dim_name not in coords:
|
|
@@ -102,6 +102,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
102
102
|
def __init__(
|
|
103
103
|
self,
|
|
104
104
|
attrs: Union[None, Mapping[Any, Any]] = None,
|
|
105
|
+
warn_on_custom_groups: bool = False,
|
|
105
106
|
**kwargs: Union[xr.Dataset, List[xr.Dataset], Tuple[xr.Dataset, xr.Dataset]],
|
|
106
107
|
) -> None:
|
|
107
108
|
"""Initialize InferenceData object from keyword xarray datasets.
|
|
@@ -110,6 +111,9 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
110
111
|
----------
|
|
111
112
|
attrs : dict
|
|
112
113
|
sets global attribute for InferenceData object.
|
|
114
|
+
warn_on_custom_groups : bool, default False
|
|
115
|
+
Emit a warning when custom groups are present in the InferenceData.
|
|
116
|
+
"custom group" means any group whose name isn't defined in :ref:`schema`
|
|
113
117
|
kwargs :
|
|
114
118
|
Keyword arguments of xarray datasets
|
|
115
119
|
|
|
@@ -154,9 +158,10 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
154
158
|
for key in kwargs:
|
|
155
159
|
if key not in SUPPORTED_GROUPS_ALL:
|
|
156
160
|
key_list.append(key)
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
161
|
+
if warn_on_custom_groups:
|
|
162
|
+
warnings.warn(
|
|
163
|
+
f"{key} group is not defined in the InferenceData scheme", UserWarning
|
|
164
|
+
)
|
|
160
165
|
for key in key_list:
|
|
161
166
|
dataset = kwargs[key]
|
|
162
167
|
dataset_warmup = None
|
|
@@ -527,24 +532,27 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
527
532
|
return filename
|
|
528
533
|
|
|
529
534
|
def to_datatree(self):
|
|
530
|
-
"""Convert InferenceData object to a :class:`~
|
|
535
|
+
"""Convert InferenceData object to a :class:`~xarray.DataTree`."""
|
|
531
536
|
try:
|
|
532
|
-
from
|
|
533
|
-
except
|
|
534
|
-
raise
|
|
535
|
-
"
|
|
537
|
+
from xarray import DataTree
|
|
538
|
+
except ImportError as err:
|
|
539
|
+
raise ImportError(
|
|
540
|
+
"xarray must be have DataTree in order to use InferenceData.to_datatree. "
|
|
541
|
+
"Update to xarray>=2024.11.0"
|
|
536
542
|
) from err
|
|
537
543
|
return DataTree.from_dict({group: ds for group, ds in self.items()})
|
|
538
544
|
|
|
539
545
|
@staticmethod
|
|
540
546
|
def from_datatree(datatree):
|
|
541
|
-
"""Create an InferenceData object from a :class:`~
|
|
547
|
+
"""Create an InferenceData object from a :class:`~xarray.DataTree`.
|
|
542
548
|
|
|
543
549
|
Parameters
|
|
544
550
|
----------
|
|
545
551
|
datatree : DataTree
|
|
546
552
|
"""
|
|
547
|
-
return InferenceData(
|
|
553
|
+
return InferenceData(
|
|
554
|
+
**{group: child.to_dataset() for group, child in datatree.children.items()}
|
|
555
|
+
)
|
|
548
556
|
|
|
549
557
|
def to_dict(self, groups=None, filter_groups=None):
|
|
550
558
|
"""Convert InferenceData to a dictionary following xarray naming conventions.
|
|
@@ -792,12 +800,20 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
792
800
|
----------
|
|
793
801
|
https://zarr.readthedocs.io/
|
|
794
802
|
"""
|
|
795
|
-
try:
|
|
803
|
+
try:
|
|
796
804
|
import zarr
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
raise ImportError(
|
|
805
|
+
except ImportError as err:
|
|
806
|
+
raise ImportError("'to_zarr' method needs Zarr (>=2.5.0,<3) installed.") from err
|
|
807
|
+
if version.parse(zarr.__version__) < version.parse("2.5.0"):
|
|
808
|
+
raise ImportError(
|
|
809
|
+
"Found zarr<2.5.0, please upgrade to a zarr (>=2.5.0,<3) to use 'to_zarr'"
|
|
810
|
+
)
|
|
811
|
+
if version.parse(zarr.__version__) >= version.parse("3.0.0.dev0"):
|
|
812
|
+
raise ImportError(
|
|
813
|
+
"Found zarr>=3, which is not supported by ArviZ. Instead, you can use "
|
|
814
|
+
"'dt = InfereceData.to_datatree' followed by 'dt.to_zarr()' "
|
|
815
|
+
"(needs xarray>=2024.11.0)"
|
|
816
|
+
)
|
|
801
817
|
|
|
802
818
|
# Check store type and create store if necessary
|
|
803
819
|
if store is None:
|
|
@@ -846,10 +862,18 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
846
862
|
"""
|
|
847
863
|
try:
|
|
848
864
|
import zarr
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
raise ImportError(
|
|
865
|
+
except ImportError as err:
|
|
866
|
+
raise ImportError("'from_zarr' method needs Zarr (>=2.5.0,<3) installed.") from err
|
|
867
|
+
if version.parse(zarr.__version__) < version.parse("2.5.0"):
|
|
868
|
+
raise ImportError(
|
|
869
|
+
"Found zarr<2.5.0, please upgrade to a zarr (>=2.5.0,<3) to use 'from_zarr'"
|
|
870
|
+
)
|
|
871
|
+
if version.parse(zarr.__version__) >= version.parse("3.0.0.dev0"):
|
|
872
|
+
raise ImportError(
|
|
873
|
+
"Found zarr>=3, which is not supported by ArviZ. Instead, you can use "
|
|
874
|
+
"'xarray.open_datatree' followed by 'arviz.InferenceData.from_datatree' "
|
|
875
|
+
"(needs xarray>=2024.11.0)"
|
|
876
|
+
)
|
|
853
877
|
|
|
854
878
|
# Check store type and create store if necessary
|
|
855
879
|
if isinstance(store, str):
|
|
@@ -1467,7 +1491,9 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1467
1491
|
else:
|
|
1468
1492
|
return out
|
|
1469
1493
|
|
|
1470
|
-
def add_groups(
|
|
1494
|
+
def add_groups(
|
|
1495
|
+
self, group_dict=None, coords=None, dims=None, warn_on_custom_groups=False, **kwargs
|
|
1496
|
+
):
|
|
1471
1497
|
"""Add new groups to InferenceData object.
|
|
1472
1498
|
|
|
1473
1499
|
Parameters
|
|
@@ -1479,6 +1505,9 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1479
1505
|
dims : dict of {str : list of str}, optional
|
|
1480
1506
|
Dimensions of each variable. The keys are variable names, values are lists of
|
|
1481
1507
|
coordinates.
|
|
1508
|
+
warn_on_custom_groups : bool, default False
|
|
1509
|
+
Emit a warning when custom groups are present in the InferenceData.
|
|
1510
|
+
"custom group" means any group whose name isn't defined in :ref:`schema`
|
|
1482
1511
|
kwargs : dict, optional
|
|
1483
1512
|
The keyword arguments form of group_dict. One of group_dict or kwargs must be provided.
|
|
1484
1513
|
|
|
@@ -1521,9 +1550,8 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1521
1550
|
import xarray as xr
|
|
1522
1551
|
from xarray_einstats.stats import XrDiscreteRV
|
|
1523
1552
|
from scipy.stats import poisson
|
|
1524
|
-
dist = XrDiscreteRV(poisson)
|
|
1525
|
-
log_lik =
|
|
1526
|
-
log_lik["home_points"] = dist.logpmf(obs["home_points"], np.exp(post["atts"]))
|
|
1553
|
+
dist = XrDiscreteRV(poisson, np.exp(post["atts"]))
|
|
1554
|
+
log_lik = dist.logpmf(obs["home_points"]).to_dataset(name="home_points")
|
|
1527
1555
|
idata2.add_groups({"log_likelihood": log_lik})
|
|
1528
1556
|
idata2
|
|
1529
1557
|
|
|
@@ -1542,7 +1570,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1542
1570
|
if repeated_groups:
|
|
1543
1571
|
raise ValueError(f"{repeated_groups} group(s) already exists.")
|
|
1544
1572
|
for group, dataset in group_dict.items():
|
|
1545
|
-
if group not in SUPPORTED_GROUPS_ALL:
|
|
1573
|
+
if warn_on_custom_groups and group not in SUPPORTED_GROUPS_ALL:
|
|
1546
1574
|
warnings.warn(
|
|
1547
1575
|
f"The group {group} is not defined in the InferenceData scheme",
|
|
1548
1576
|
UserWarning,
|
|
@@ -1597,7 +1625,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1597
1625
|
else:
|
|
1598
1626
|
self._groups.append(group)
|
|
1599
1627
|
|
|
1600
|
-
def extend(self, other, join="left"):
|
|
1628
|
+
def extend(self, other, join="left", warn_on_custom_groups=False):
|
|
1601
1629
|
"""Extend InferenceData with groups from another InferenceData.
|
|
1602
1630
|
|
|
1603
1631
|
Parameters
|
|
@@ -1608,6 +1636,9 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1608
1636
|
Defines how the two decide which group to keep when the same group is
|
|
1609
1637
|
present in both objects. 'left' will discard the group in ``other`` whereas 'right'
|
|
1610
1638
|
will keep the group in ``other`` and discard the one in ``self``.
|
|
1639
|
+
warn_on_custom_groups : bool, default False
|
|
1640
|
+
Emit a warning when custom groups are present in the InferenceData.
|
|
1641
|
+
"custom group" means any group whose name isn't defined in :ref:`schema`
|
|
1611
1642
|
|
|
1612
1643
|
Examples
|
|
1613
1644
|
--------
|
|
@@ -1651,7 +1682,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1651
1682
|
for group in other._groups_all: # pylint: disable=protected-access
|
|
1652
1683
|
if hasattr(self, group) and join == "left":
|
|
1653
1684
|
continue
|
|
1654
|
-
if group not in SUPPORTED_GROUPS_ALL:
|
|
1685
|
+
if warn_on_custom_groups and group not in SUPPORTED_GROUPS_ALL:
|
|
1655
1686
|
warnings.warn(
|
|
1656
1687
|
f"{group} group is not defined in the InferenceData scheme", UserWarning
|
|
1657
1688
|
)
|
|
@@ -4,7 +4,7 @@ from .inference_data import InferenceData
|
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
def to_datatree(data):
|
|
7
|
-
"""Convert InferenceData object to a :class:`~
|
|
7
|
+
"""Convert InferenceData object to a :class:`~xarray.DataTree`.
|
|
8
8
|
|
|
9
9
|
Parameters
|
|
10
10
|
----------
|
|
@@ -14,7 +14,7 @@ def to_datatree(data):
|
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
def from_datatree(datatree):
|
|
17
|
-
"""Create an InferenceData object from a :class:`~
|
|
17
|
+
"""Create an InferenceData object from a :class:`~xarray.DataTree`.
|
|
18
18
|
|
|
19
19
|
Parameters
|
|
20
20
|
----------
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""NumPyro-specific conversion code."""
|
|
2
2
|
|
|
3
|
+
from collections import defaultdict
|
|
3
4
|
import logging
|
|
4
|
-
from typing import Callable, Optional
|
|
5
|
+
from typing import Any, Callable, Optional, Dict, List, Tuple
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
|
|
@@ -13,6 +14,70 @@ from .inference_data import InferenceData
|
|
|
13
14
|
_log = logging.getLogger(__name__)
|
|
14
15
|
|
|
15
16
|
|
|
17
|
+
def _add_dims(dims_a: Dict[str, List[str]], dims_b: Dict[str, List[str]]) -> Dict[str, List[str]]:
|
|
18
|
+
merged = defaultdict(list)
|
|
19
|
+
|
|
20
|
+
for k, v in dims_a.items():
|
|
21
|
+
merged[k].extend(v)
|
|
22
|
+
|
|
23
|
+
for k, v in dims_b.items():
|
|
24
|
+
merged[k].extend(v)
|
|
25
|
+
|
|
26
|
+
# Convert back to a regular dict
|
|
27
|
+
return dict(merged)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def infer_dims(
|
|
31
|
+
model: Callable,
|
|
32
|
+
model_args: Optional[Tuple[Any, ...]] = None,
|
|
33
|
+
model_kwargs: Optional[Dict[str, Any]] = None,
|
|
34
|
+
) -> Dict[str, List[str]]:
|
|
35
|
+
|
|
36
|
+
from numpyro import handlers, distributions as dist
|
|
37
|
+
from numpyro.ops.pytree import PytreeTrace
|
|
38
|
+
from numpyro.infer.initialization import init_to_sample
|
|
39
|
+
import jax
|
|
40
|
+
|
|
41
|
+
model_args = tuple() if model_args is None else model_args
|
|
42
|
+
model_kwargs = dict() if model_args is None else model_kwargs
|
|
43
|
+
|
|
44
|
+
def _get_dist_name(fn):
|
|
45
|
+
if isinstance(fn, (dist.Independent, dist.ExpandedDistribution, dist.MaskedDistribution)):
|
|
46
|
+
return _get_dist_name(fn.base_dist)
|
|
47
|
+
return type(fn).__name__
|
|
48
|
+
|
|
49
|
+
def get_trace():
|
|
50
|
+
# We use `init_to_sample` to get around ImproperUniform distribution,
|
|
51
|
+
# which does not have `sample` method.
|
|
52
|
+
subs_model = handlers.substitute(
|
|
53
|
+
handlers.seed(model, 0),
|
|
54
|
+
substitute_fn=init_to_sample,
|
|
55
|
+
)
|
|
56
|
+
trace = handlers.trace(subs_model).get_trace(*model_args, **model_kwargs)
|
|
57
|
+
# Work around an issue where jax.eval_shape does not work
|
|
58
|
+
# for distribution output (e.g. the function `lambda: dist.Normal(0, 1)`)
|
|
59
|
+
# Here we will remove `fn` and store its name in the trace.
|
|
60
|
+
for _, site in trace.items():
|
|
61
|
+
if site["type"] == "sample":
|
|
62
|
+
site["fn_name"] = _get_dist_name(site.pop("fn"))
|
|
63
|
+
elif site["type"] == "deterministic":
|
|
64
|
+
site["fn_name"] = "Deterministic"
|
|
65
|
+
return PytreeTrace(trace)
|
|
66
|
+
|
|
67
|
+
# We use eval_shape to avoid any array computation.
|
|
68
|
+
trace = jax.eval_shape(get_trace).trace
|
|
69
|
+
|
|
70
|
+
named_dims = {}
|
|
71
|
+
|
|
72
|
+
for name, site in trace.items():
|
|
73
|
+
batch_dims = [frame.name for frame in sorted(site["cond_indep_stack"], key=lambda x: x.dim)]
|
|
74
|
+
event_dims = list(site.get("infer", {}).get("event_dims", []))
|
|
75
|
+
if site["type"] in ["sample", "deterministic"] and (batch_dims or event_dims):
|
|
76
|
+
named_dims[name] = batch_dims + event_dims
|
|
77
|
+
|
|
78
|
+
return named_dims
|
|
79
|
+
|
|
80
|
+
|
|
16
81
|
class NumPyroConverter:
|
|
17
82
|
"""Encapsulate NumPyro specific logic."""
|
|
18
83
|
|
|
@@ -36,6 +101,7 @@ class NumPyroConverter:
|
|
|
36
101
|
coords=None,
|
|
37
102
|
dims=None,
|
|
38
103
|
pred_dims=None,
|
|
104
|
+
extra_event_dims=None,
|
|
39
105
|
num_chains=1,
|
|
40
106
|
):
|
|
41
107
|
"""Convert NumPyro data into an InferenceData object.
|
|
@@ -58,9 +124,12 @@ class NumPyroConverter:
|
|
|
58
124
|
coords : dict[str] -> list[str]
|
|
59
125
|
Map of dimensions to coordinates
|
|
60
126
|
dims : dict[str] -> list[str]
|
|
61
|
-
Map variable names to their coordinates
|
|
127
|
+
Map variable names to their coordinates. Will be inferred if they are not provided.
|
|
62
128
|
pred_dims: dict
|
|
63
129
|
Dims for predictions data. Map variable names to their coordinates.
|
|
130
|
+
extra_event_dims: dict
|
|
131
|
+
Extra event dims for deterministic sites. Maps event dims that couldnt be inferred to
|
|
132
|
+
their coordinates.
|
|
64
133
|
num_chains: int
|
|
65
134
|
Number of chains used for sampling. Ignored if posterior is present.
|
|
66
135
|
"""
|
|
@@ -80,6 +149,7 @@ class NumPyroConverter:
|
|
|
80
149
|
self.coords = coords
|
|
81
150
|
self.dims = dims
|
|
82
151
|
self.pred_dims = pred_dims
|
|
152
|
+
self.extra_event_dims = extra_event_dims
|
|
83
153
|
self.numpyro = numpyro
|
|
84
154
|
|
|
85
155
|
def arbitrary_element(dct):
|
|
@@ -107,6 +177,10 @@ class NumPyroConverter:
|
|
|
107
177
|
# model arguments and keyword arguments
|
|
108
178
|
self._args = self.posterior._args # pylint: disable=protected-access
|
|
109
179
|
self._kwargs = self.posterior._kwargs # pylint: disable=protected-access
|
|
180
|
+
self.dims = self.dims if self.dims is not None else self.infer_dims()
|
|
181
|
+
self.pred_dims = (
|
|
182
|
+
self.pred_dims if self.pred_dims is not None else self.infer_pred_dims()
|
|
183
|
+
)
|
|
110
184
|
else:
|
|
111
185
|
self.nchains = num_chains
|
|
112
186
|
get_from = None
|
|
@@ -325,6 +399,23 @@ class NumPyroConverter:
|
|
|
325
399
|
}
|
|
326
400
|
)
|
|
327
401
|
|
|
402
|
+
@requires("posterior")
|
|
403
|
+
@requires("model")
|
|
404
|
+
def infer_dims(self) -> Dict[str, List[str]]:
|
|
405
|
+
dims = infer_dims(self.model, self._args, self._kwargs)
|
|
406
|
+
if self.extra_event_dims:
|
|
407
|
+
dims = _add_dims(dims, self.extra_event_dims)
|
|
408
|
+
return dims
|
|
409
|
+
|
|
410
|
+
@requires("posterior")
|
|
411
|
+
@requires("model")
|
|
412
|
+
@requires("predictions")
|
|
413
|
+
def infer_pred_dims(self) -> Dict[str, List[str]]:
|
|
414
|
+
dims = infer_dims(self.model, self._args, self._kwargs)
|
|
415
|
+
if self.extra_event_dims:
|
|
416
|
+
dims = _add_dims(dims, self.extra_event_dims)
|
|
417
|
+
return dims
|
|
418
|
+
|
|
328
419
|
|
|
329
420
|
def from_numpyro(
|
|
330
421
|
posterior=None,
|
|
@@ -339,10 +430,25 @@ def from_numpyro(
|
|
|
339
430
|
coords=None,
|
|
340
431
|
dims=None,
|
|
341
432
|
pred_dims=None,
|
|
433
|
+
extra_event_dims=None,
|
|
342
434
|
num_chains=1,
|
|
343
435
|
):
|
|
344
436
|
"""Convert NumPyro data into an InferenceData object.
|
|
345
437
|
|
|
438
|
+
If no dims are provided, this will infer batch dim names from NumPyro model plates.
|
|
439
|
+
For event dim names, such as with the ZeroSumNormal, `infer={"event_dims":dim_names}`
|
|
440
|
+
can be provided in numpyro.sample, i.e.::
|
|
441
|
+
|
|
442
|
+
# equivalent to dims entry, {"gamma": ["groups"]}
|
|
443
|
+
gamma = numpyro.sample(
|
|
444
|
+
"gamma",
|
|
445
|
+
dist.ZeroSumNormal(1, event_shape=(n_groups,)),
|
|
446
|
+
infer={"event_dims":["groups"]}
|
|
447
|
+
)
|
|
448
|
+
|
|
449
|
+
There is also an additional `extra_event_dims` input to cover any edge cases, for instance
|
|
450
|
+
deterministic sites with event dims (which dont have an `infer` argument to provide metadata).
|
|
451
|
+
|
|
346
452
|
For a usage example read the
|
|
347
453
|
:ref:`Creating InferenceData section on from_numpyro <creating_InferenceData>`
|
|
348
454
|
|
|
@@ -364,9 +470,10 @@ def from_numpyro(
|
|
|
364
470
|
coords : dict[str] -> list[str]
|
|
365
471
|
Map of dimensions to coordinates
|
|
366
472
|
dims : dict[str] -> list[str]
|
|
367
|
-
Map variable names to their coordinates
|
|
473
|
+
Map variable names to their coordinates. Will be inferred if they are not provided.
|
|
368
474
|
pred_dims: dict
|
|
369
|
-
Dims for predictions data. Map variable names to their coordinates.
|
|
475
|
+
Dims for predictions data. Map variable names to their coordinates. Default behavior is to
|
|
476
|
+
infer dims if this is not provided
|
|
370
477
|
num_chains: int
|
|
371
478
|
Number of chains used for sampling. Ignored if posterior is present.
|
|
372
479
|
"""
|
|
@@ -382,5 +489,6 @@ def from_numpyro(
|
|
|
382
489
|
coords=coords,
|
|
383
490
|
dims=dims,
|
|
384
491
|
pred_dims=pred_dims,
|
|
492
|
+
extra_event_dims=extra_event_dims,
|
|
385
493
|
num_chains=num_chains,
|
|
386
494
|
).to_inference_data()
|
|
@@ -4,7 +4,7 @@ from ..data import convert_to_dataset
|
|
|
4
4
|
from ..labels import BaseLabeller
|
|
5
5
|
from ..sel_utils import xarray_var_iter
|
|
6
6
|
from ..rcparams import rcParams
|
|
7
|
-
from ..utils import _var_names
|
|
7
|
+
from ..utils import _var_names, get_coords
|
|
8
8
|
from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
|
|
9
9
|
|
|
10
10
|
|
|
@@ -14,6 +14,7 @@ def plot_autocorr(
|
|
|
14
14
|
filter_vars=None,
|
|
15
15
|
max_lag=None,
|
|
16
16
|
combined=False,
|
|
17
|
+
coords=None,
|
|
17
18
|
grid=None,
|
|
18
19
|
figsize=None,
|
|
19
20
|
textsize=None,
|
|
@@ -42,6 +43,8 @@ def plot_autocorr(
|
|
|
42
43
|
interpret `var_names` as substrings of the real variables names. If "regex",
|
|
43
44
|
interpret `var_names` as regular expressions on the real variables names. See
|
|
44
45
|
:ref:`this section <common_filter_vars>` for usage examples.
|
|
46
|
+
coords: mapping, optional
|
|
47
|
+
Coordinates of var_names to be plotted. Passed to :meth:`xarray.Dataset.sel`
|
|
45
48
|
max_lag : int, optional
|
|
46
49
|
Maximum lag to calculate autocorrelation. By Default, the plot displays the
|
|
47
50
|
first 100 lag or the total number of draws, whichever is smaller.
|
|
@@ -124,11 +127,18 @@ def plot_autocorr(
|
|
|
124
127
|
if max_lag is None:
|
|
125
128
|
max_lag = min(100, data["draw"].shape[0])
|
|
126
129
|
|
|
130
|
+
if coords is None:
|
|
131
|
+
coords = {}
|
|
132
|
+
|
|
127
133
|
if labeller is None:
|
|
128
134
|
labeller = BaseLabeller()
|
|
129
135
|
|
|
130
136
|
plotters = filter_plotters_list(
|
|
131
|
-
list(
|
|
137
|
+
list(
|
|
138
|
+
xarray_var_iter(
|
|
139
|
+
get_coords(data, coords), var_names, combined, dim_order=["chain", "draw"]
|
|
140
|
+
)
|
|
141
|
+
),
|
|
132
142
|
"plot_autocorr",
|
|
133
143
|
)
|
|
134
144
|
rows, cols = default_grid(len(plotters), grid=grid)
|
|
@@ -213,10 +213,11 @@ def _copy_docstring(lib, function):
|
|
|
213
213
|
|
|
214
214
|
|
|
215
215
|
# TODO: try copying substitutions too, or autoreplace them ourselves
|
|
216
|
-
output_notebook.__doc__
|
|
217
|
-
"
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
"
|
|
221
|
-
|
|
222
|
-
|
|
216
|
+
if output_notebook.__doc__ is not None: # if run with python -OO, __doc__ is stripped
|
|
217
|
+
output_notebook.__doc__ += "\n\n" + _copy_docstring(
|
|
218
|
+
"bokeh.plotting", "output_notebook"
|
|
219
|
+
).replace("|save|", "save").replace("|show|", "show")
|
|
220
|
+
output_file.__doc__ += "\n\n" + _copy_docstring("bokeh.plotting", "output_file").replace(
|
|
221
|
+
"|save|", "save"
|
|
222
|
+
).replace("|show|", "show")
|
|
223
|
+
ColumnDataSource.__doc__ += "\n\n" + _copy_docstring("bokeh.models", "ColumnDataSource")
|
|
@@ -41,6 +41,7 @@ def plot_bpv(
|
|
|
41
41
|
plot_ref_kwargs,
|
|
42
42
|
backend_kwargs,
|
|
43
43
|
show,
|
|
44
|
+
smoothing,
|
|
44
45
|
):
|
|
45
46
|
"""Bokeh bpv plot."""
|
|
46
47
|
if backend_kwargs is None:
|
|
@@ -90,6 +91,9 @@ def plot_bpv(
|
|
|
90
91
|
obs_vals = obs_vals.flatten()
|
|
91
92
|
pp_vals = pp_vals.reshape(total_pp_samples, -1)
|
|
92
93
|
|
|
94
|
+
if (obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i") and smoothing is True:
|
|
95
|
+
obs_vals, pp_vals = smooth_data(obs_vals, pp_vals)
|
|
96
|
+
|
|
93
97
|
if kind == "p_value":
|
|
94
98
|
tstat_pit = np.mean(pp_vals <= obs_vals, axis=-1)
|
|
95
99
|
x_s, tstat_pit_dens = kde(tstat_pit)
|
|
@@ -115,9 +119,6 @@ def plot_bpv(
|
|
|
115
119
|
)
|
|
116
120
|
|
|
117
121
|
elif kind == "u_value":
|
|
118
|
-
if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
|
|
119
|
-
obs_vals, pp_vals = smooth_data(obs_vals, pp_vals)
|
|
120
|
-
|
|
121
122
|
tstat_pit = np.mean(pp_vals <= obs_vals, axis=0)
|
|
122
123
|
x_s, tstat_pit_dens = kde(tstat_pit)
|
|
123
124
|
ax_i.line(x_s, tstat_pit_dens, color=color)
|
|
@@ -225,7 +225,11 @@ def _d_helper(
|
|
|
225
225
|
|
|
226
226
|
if point_estimate is not None:
|
|
227
227
|
est = calculate_point_estimate(point_estimate, vec, bw, circular)
|
|
228
|
-
plotted.append(
|
|
228
|
+
plotted.append(
|
|
229
|
+
ax.scatter(
|
|
230
|
+
est, 0, marker="circle", fill_color=color, line_color="black", size=markersize
|
|
231
|
+
)
|
|
232
|
+
)
|
|
229
233
|
|
|
230
234
|
_title = Title()
|
|
231
235
|
_title.text = vname
|
|
@@ -47,7 +47,10 @@ def plot_dot(
|
|
|
47
47
|
|
|
48
48
|
if plot_kwargs is None:
|
|
49
49
|
plot_kwargs = {}
|
|
50
|
-
|
|
50
|
+
else:
|
|
51
|
+
plot_kwargs = plot_kwargs.copy()
|
|
52
|
+
plot_kwargs.setdefault("color", dotcolor)
|
|
53
|
+
plot_kwargs.setdefault("marker", "circle")
|
|
51
54
|
|
|
52
55
|
if linewidth is None:
|
|
53
56
|
linewidth = auto_linewidth
|
|
@@ -95,7 +98,7 @@ def plot_dot(
|
|
|
95
98
|
stack_locs, stack_count = wilkinson_algorithm(values, binwidth)
|
|
96
99
|
x, y = layout_stacks(stack_locs, stack_count, binwidth, stackratio, rotated)
|
|
97
100
|
|
|
98
|
-
ax.
|
|
101
|
+
ax.scatter(x, y, radius=dotsize * (binwidth / 2), **plot_kwargs, radius_dimension="y")
|
|
99
102
|
if rotated:
|
|
100
103
|
ax.xaxis.major_tick_line_color = None
|
|
101
104
|
ax.xaxis.minor_tick_line_color = None
|
|
@@ -73,12 +73,14 @@ def plot_ess(
|
|
|
73
73
|
for (var_name, selection, isel, x), ax_ in zip(
|
|
74
74
|
plotters, (item for item in ax.flatten() if item is not None)
|
|
75
75
|
):
|
|
76
|
-
bulk_points = ax_.
|
|
76
|
+
bulk_points = ax_.scatter(np.asarray(xdata), np.asarray(x), marker="circle", size=6)
|
|
77
77
|
if kind == "evolution":
|
|
78
78
|
bulk_line = ax_.line(np.asarray(xdata), np.asarray(x))
|
|
79
79
|
ess_tail = ess_tail_dataset[var_name].sel(**selection)
|
|
80
80
|
tail_points = ax_.line(np.asarray(xdata), np.asarray(ess_tail), color="orange")
|
|
81
|
-
tail_line = ax_.
|
|
81
|
+
tail_line = ax_.scatter(
|
|
82
|
+
np.asarray(xdata), np.asarray(ess_tail), marker="circle", size=6, color="orange"
|
|
83
|
+
)
|
|
82
84
|
elif rug:
|
|
83
85
|
if rug_kwargs is None:
|
|
84
86
|
rug_kwargs = {}
|