arviz 0.21.0__tar.gz → 0.22.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. {arviz-0.21.0 → arviz-0.22.0}/CHANGELOG.md +18 -1
  2. {arviz-0.21.0 → arviz-0.22.0}/CONTRIBUTING.md +2 -1
  3. {arviz-0.21.0 → arviz-0.22.0}/PKG-INFO +1 -1
  4. {arviz-0.21.0 → arviz-0.22.0}/arviz/__init__.py +8 -3
  5. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/inference_data.py +37 -19
  6. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/io_datatree.py +2 -2
  7. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/io_numpyro.py +112 -4
  8. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/autocorrplot.py +12 -2
  9. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/hdiplot.py +7 -6
  10. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/lmplot.py +19 -3
  11. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/pairplot.py +18 -48
  12. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/khatplot.py +8 -1
  13. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/lmplot.py +13 -7
  14. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/pairplot.py +14 -22
  15. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/lmplot.py +41 -14
  16. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/pairplot.py +10 -3
  17. {arviz-0.21.0 → arviz-0.22.0}/arviz/stats/density_utils.py +1 -1
  18. {arviz-0.21.0 → arviz-0.22.0}/arviz/stats/stats.py +19 -7
  19. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/base_tests/test_data.py +0 -4
  20. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/base_tests/test_plots_bokeh.py +60 -2
  21. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/base_tests/test_plots_matplotlib.py +77 -1
  22. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/base_tests/test_stats.py +42 -1
  23. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/external_tests/test_data_numpyro.py +130 -3
  24. {arviz-0.21.0 → arviz-0.22.0}/arviz/wrappers/base.py +1 -1
  25. {arviz-0.21.0 → arviz-0.22.0}/arviz/wrappers/wrap_stan.py +1 -1
  26. {arviz-0.21.0 → arviz-0.22.0}/arviz.egg-info/PKG-INFO +1 -1
  27. {arviz-0.21.0 → arviz-0.22.0}/arviz.egg-info/requires.txt +6 -6
  28. {arviz-0.21.0 → arviz-0.22.0}/requirements-dev.txt +1 -0
  29. {arviz-0.21.0 → arviz-0.22.0}/requirements-optional.txt +1 -1
  30. {arviz-0.21.0 → arviz-0.22.0}/requirements.txt +5 -5
  31. {arviz-0.21.0 → arviz-0.22.0}/CODE_OF_CONDUCT.md +0 -0
  32. {arviz-0.21.0 → arviz-0.22.0}/GOVERNANCE.md +0 -0
  33. {arviz-0.21.0 → arviz-0.22.0}/LICENSE +0 -0
  34. {arviz-0.21.0 → arviz-0.22.0}/MANIFEST.in +0 -0
  35. {arviz-0.21.0 → arviz-0.22.0}/README.md +0 -0
  36. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/__init__.py +0 -0
  37. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/base.py +0 -0
  38. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/converters.py +0 -0
  39. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/datasets.py +0 -0
  40. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/example_data/code/radon/radon.json +0 -0
  41. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/example_data/data/centered_eight.nc +0 -0
  42. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/example_data/data/non_centered_eight.nc +0 -0
  43. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/example_data/data_local.json +0 -0
  44. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/example_data/data_remote.json +0 -0
  45. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/io_beanmachine.py +0 -0
  46. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/io_cmdstan.py +0 -0
  47. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/io_cmdstanpy.py +0 -0
  48. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/io_dict.py +0 -0
  49. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/io_emcee.py +0 -0
  50. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/io_json.py +0 -0
  51. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/io_netcdf.py +0 -0
  52. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/io_pyjags.py +0 -0
  53. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/io_pyro.py +0 -0
  54. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/io_pystan.py +0 -0
  55. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/io_zarr.py +0 -0
  56. {arviz-0.21.0 → arviz-0.22.0}/arviz/data/utils.py +0 -0
  57. {arviz-0.21.0 → arviz-0.22.0}/arviz/labels.py +0 -0
  58. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/__init__.py +0 -0
  59. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/__init__.py +0 -0
  60. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/__init__.py +0 -0
  61. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/autocorrplot.py +0 -0
  62. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/bfplot.py +0 -0
  63. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/bpvplot.py +0 -0
  64. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/compareplot.py +0 -0
  65. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/densityplot.py +0 -0
  66. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/distcomparisonplot.py +0 -0
  67. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/distplot.py +0 -0
  68. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/dotplot.py +0 -0
  69. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/ecdfplot.py +0 -0
  70. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/elpdplot.py +0 -0
  71. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/energyplot.py +0 -0
  72. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/essplot.py +0 -0
  73. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/forestplot.py +0 -0
  74. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/kdeplot.py +0 -0
  75. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/khatplot.py +0 -0
  76. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/loopitplot.py +0 -0
  77. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/mcseplot.py +0 -0
  78. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/parallelplot.py +0 -0
  79. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/posteriorplot.py +0 -0
  80. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/ppcplot.py +0 -0
  81. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/rankplot.py +0 -0
  82. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/separationplot.py +0 -0
  83. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/traceplot.py +0 -0
  84. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/bokeh/violinplot.py +0 -0
  85. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/__init__.py +0 -0
  86. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/autocorrplot.py +0 -0
  87. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/bfplot.py +0 -0
  88. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/bpvplot.py +0 -0
  89. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/compareplot.py +0 -0
  90. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/densityplot.py +0 -0
  91. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -0
  92. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/distplot.py +0 -0
  93. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/dotplot.py +0 -0
  94. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/ecdfplot.py +0 -0
  95. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/elpdplot.py +0 -0
  96. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/energyplot.py +0 -0
  97. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/essplot.py +0 -0
  98. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/forestplot.py +0 -0
  99. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/hdiplot.py +0 -0
  100. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/kdeplot.py +0 -0
  101. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/loopitplot.py +0 -0
  102. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/mcseplot.py +0 -0
  103. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/parallelplot.py +0 -0
  104. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/posteriorplot.py +0 -0
  105. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/ppcplot.py +0 -0
  106. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/rankplot.py +0 -0
  107. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/separationplot.py +0 -0
  108. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/traceplot.py +0 -0
  109. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/tsplot.py +0 -0
  110. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/backends/matplotlib/violinplot.py +0 -0
  111. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/bfplot.py +0 -0
  112. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/bpvplot.py +0 -0
  113. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/compareplot.py +0 -0
  114. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/densityplot.py +0 -0
  115. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/distcomparisonplot.py +0 -0
  116. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/distplot.py +0 -0
  117. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/dotplot.py +0 -0
  118. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/ecdfplot.py +0 -0
  119. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/elpdplot.py +0 -0
  120. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/energyplot.py +0 -0
  121. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/essplot.py +0 -0
  122. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/forestplot.py +0 -0
  123. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/hdiplot.py +0 -0
  124. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/kdeplot.py +4 -4
  125. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/khatplot.py +0 -0
  126. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/loopitplot.py +0 -0
  127. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/mcseplot.py +0 -0
  128. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/parallelplot.py +0 -0
  129. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/plot_utils.py +0 -0
  130. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/posteriorplot.py +0 -0
  131. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/ppcplot.py +0 -0
  132. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/rankplot.py +0 -0
  133. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/separationplot.py +0 -0
  134. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/styles/arviz-bluish.mplstyle +0 -0
  135. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/styles/arviz-brownish.mplstyle +0 -0
  136. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/styles/arviz-colors.mplstyle +0 -0
  137. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/styles/arviz-cyanish.mplstyle +0 -0
  138. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/styles/arviz-darkgrid.mplstyle +0 -0
  139. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/styles/arviz-doc.mplstyle +0 -0
  140. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/styles/arviz-docgrid.mplstyle +0 -0
  141. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/styles/arviz-grayscale.mplstyle +0 -0
  142. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/styles/arviz-greenish.mplstyle +0 -0
  143. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/styles/arviz-orangish.mplstyle +0 -0
  144. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/styles/arviz-plasmish.mplstyle +0 -0
  145. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/styles/arviz-purplish.mplstyle +0 -0
  146. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/styles/arviz-redish.mplstyle +0 -0
  147. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/styles/arviz-royish.mplstyle +0 -0
  148. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/styles/arviz-viridish.mplstyle +0 -0
  149. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/styles/arviz-white.mplstyle +0 -0
  150. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/styles/arviz-whitegrid.mplstyle +0 -0
  151. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/traceplot.py +0 -0
  152. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/tsplot.py +0 -0
  153. {arviz-0.21.0 → arviz-0.22.0}/arviz/plots/violinplot.py +0 -0
  154. {arviz-0.21.0 → arviz-0.22.0}/arviz/preview.py +0 -0
  155. {arviz-0.21.0 → arviz-0.22.0}/arviz/py.typed +0 -0
  156. {arviz-0.21.0 → arviz-0.22.0}/arviz/rcparams.py +0 -0
  157. {arviz-0.21.0 → arviz-0.22.0}/arviz/sel_utils.py +0 -0
  158. {arviz-0.21.0 → arviz-0.22.0}/arviz/static/css/style.css +0 -0
  159. {arviz-0.21.0 → arviz-0.22.0}/arviz/static/html/icons-svg-inline.html +0 -0
  160. {arviz-0.21.0 → arviz-0.22.0}/arviz/stats/__init__.py +0 -0
  161. {arviz-0.21.0 → arviz-0.22.0}/arviz/stats/diagnostics.py +0 -0
  162. {arviz-0.21.0 → arviz-0.22.0}/arviz/stats/ecdf_utils.py +0 -0
  163. {arviz-0.21.0 → arviz-0.22.0}/arviz/stats/stats_refitting.py +0 -0
  164. {arviz-0.21.0 → arviz-0.22.0}/arviz/stats/stats_utils.py +0 -0
  165. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/__init__.py +0 -0
  166. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/base_tests/__init__.py +0 -0
  167. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/base_tests/test_data_zarr.py +0 -0
  168. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/base_tests/test_diagnostics.py +0 -0
  169. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/base_tests/test_diagnostics_numba.py +0 -0
  170. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/base_tests/test_helpers.py +0 -0
  171. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/base_tests/test_labels.py +0 -0
  172. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/base_tests/test_plot_utils.py +0 -0
  173. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/base_tests/test_rcparams.py +0 -0
  174. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -0
  175. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/base_tests/test_stats_numba.py +0 -0
  176. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/base_tests/test_stats_utils.py +0 -0
  177. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/base_tests/test_utils.py +0 -0
  178. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/base_tests/test_utils_numba.py +0 -0
  179. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/conftest.py +0 -0
  180. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/external_tests/__init__.py +0 -0
  181. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/external_tests/test_data_beanmachine.py +0 -0
  182. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/external_tests/test_data_cmdstan.py +0 -0
  183. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/external_tests/test_data_cmdstanpy.py +0 -0
  184. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/external_tests/test_data_emcee.py +0 -0
  185. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/external_tests/test_data_pyjags.py +0 -0
  186. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/external_tests/test_data_pyro.py +0 -0
  187. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/external_tests/test_data_pystan.py +0 -0
  188. {arviz-0.21.0 → arviz-0.22.0}/arviz/tests/helpers.py +0 -0
  189. {arviz-0.21.0 → arviz-0.22.0}/arviz/utils.py +0 -0
  190. {arviz-0.21.0 → arviz-0.22.0}/arviz/wrappers/__init__.py +0 -0
  191. {arviz-0.21.0 → arviz-0.22.0}/arviz/wrappers/wrap_pymc.py +0 -0
  192. {arviz-0.21.0 → arviz-0.22.0}/arviz.egg-info/SOURCES.txt +0 -0
  193. {arviz-0.21.0 → arviz-0.22.0}/arviz.egg-info/dependency_links.txt +0 -0
  194. {arviz-0.21.0 → arviz-0.22.0}/arviz.egg-info/top_level.txt +0 -0
  195. {arviz-0.21.0 → arviz-0.22.0}/pyproject.toml +0 -0
  196. {arviz-0.21.0 → arviz-0.22.0}/requirements-docs.txt +0 -0
  197. {arviz-0.21.0 → arviz-0.22.0}/requirements-external.txt +0 -0
  198. {arviz-0.21.0 → arviz-0.22.0}/requirements-test.txt +0 -0
  199. {arviz-0.21.0 → arviz-0.22.0}/setup.cfg +0 -0
  200. {arviz-0.21.0 → arviz-0.22.0}/setup.py +0 -0
@@ -1,5 +1,22 @@
1
1
  # Change Log
2
2
 
3
+ ## v0.22.0 (2025 Jul 9)
4
+
5
+ ### New features
6
+ - `plot_pair` now has more flexible support for `reference_values` ([2438](https://github.com/arviz-devs/arviz/pull/2438))
7
+ - Make `arviz.from_numpyro(..., dims=None)` automatically infer dims from the numpyro model based on its numpyro.plate structure
8
+
9
+ ### Maintenance and fixes
10
+ - `reference_values` and `labeller` now work together in `plot_pair` ([2437](https://github.com/arviz-devs/arviz/issues/2437))
11
+ - Fix `plot_lm` for multidimensional data ([2408](https://github.com/arviz-devs/arviz/issues/2408))
12
+ - Add [`scipy-stubs`](https://github.com/scipy/scipy-stubs) as a development dependency ([2445](https://github.com/arviz-devs/arviz/pull/2445))
13
+ - Test compare dataframe stays consistent independently of input order ([2407](https://github.com/arviz-devs/arviz/pull/2407))
14
+ - Fix hdi_probs behaviour in 2d `plot_kde` ([2460](https://github.com/arviz-devs/arviz/pull/2460))
15
+
16
+ ### Documentation
17
+ - Added documentation for `reference_values` ([2438](https://github.com/arviz-devs/arviz/pull/2438))
18
+ - Add migration guide page to help switch over to the new `arviz-xyz` libraries ([2459](https://github.com/arviz-devs/arviz/pull/2459))
19
+
3
20
  ## v0.21.0 (2025 Mar 06)
4
21
 
5
22
  ### New features
@@ -8,7 +25,7 @@
8
25
  - Make `arviz.data.generate_dims_coords` handle `dims` and `default_dims` consistently ([2395](https://github.com/arviz-devs/arviz/pull/2395))
9
26
  - Only emit a warning for custom groups in `InferenceData` when explicitly requested ([2401](https://github.com/arviz-devs/arviz/pull/2401))
10
27
  - Splits Bayes Factor computation out from `az.plot_bf` into `az.bayes_factor` ([2402](https://github.com/arviz-devs/arviz/issues/2402))
11
- - Update `method="sd"` of `mcse` to not use normality assumption ([2167](https://github.com/arviz-devs/arviz/pull/2167))
28
+ - Update `method="sd"` of `mcse` to not use normality assumption ([2167](https://github.com/arviz-devs/arviz/pull/2167))
12
29
  - Add exception in `az.plot_hdi` for `x` of type `str` ([2413](https://github.com/arviz-devs/arviz/pull/2413))
13
30
 
14
31
  ### Documentation
@@ -1,8 +1,9 @@
1
1
  # Contributing to ArviZ
2
2
  This document outlines only the most common contributions.
3
3
  Please see the [Contributing guide](https://python.arviz.org/en/latest/contributing/index.html)
4
- on our documentation for a better view of how can you contribute to ArviZ.
4
+ on our documentation for a better view of how you can contribute to ArviZ.
5
5
  We welcome a wide range of contributions, not only code!
6
+ Even improving documentation or fixing typos is a valuable contribution to ArviZ.
6
7
 
7
8
  ## Reporting issues
8
9
  If you encounter any bug or incorrect behaviour while using ArviZ,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: arviz
3
- Version: 0.21.0
3
+ Version: 0.22.0
4
4
  Summary: Exploratory analysis of Bayesian models
5
5
  Home-page: http://github.com/arviz-devs/arviz
6
6
  Author: ArviZ Developers
@@ -1,6 +1,6 @@
1
1
  # pylint: disable=wildcard-import,invalid-name,wrong-import-position
2
2
  """ArviZ is a library for exploratory analysis of Bayesian models."""
3
- __version__ = "0.21.0"
3
+ __version__ = "0.22.0"
4
4
 
5
5
  import logging
6
6
  import os
@@ -8,6 +8,7 @@ import os
8
8
  from matplotlib.colors import LinearSegmentedColormap
9
9
  from matplotlib.pyplot import style
10
10
  import matplotlib as mpl
11
+ from packaging import version
11
12
 
12
13
 
13
14
  class Logger(logging.Logger):
@@ -41,8 +42,12 @@ from . import preview
41
42
 
42
43
  # add ArviZ's styles to matplotlib's styles
43
44
  _arviz_style_path = os.path.join(os.path.dirname(__file__), "plots", "styles")
44
- style.core.USER_LIBRARY_PATHS.append(_arviz_style_path)
45
- style.core.reload_library()
45
+ if version.parse(mpl.__version__) >= version.parse("3.11.0.dev0"):
46
+ style.USER_LIBRARY_PATHS.append(_arviz_style_path)
47
+ style.reload_library()
48
+ else:
49
+ style.core.USER_LIBRARY_PATHS.append(_arviz_style_path)
50
+ style.core.reload_library()
46
51
 
47
52
 
48
53
  if not logging.root.handlers:
@@ -532,24 +532,27 @@ class InferenceData(Mapping[str, xr.Dataset]):
532
532
  return filename
533
533
 
534
534
  def to_datatree(self):
535
- """Convert InferenceData object to a :class:`~datatree.DataTree`."""
535
+ """Convert InferenceData object to a :class:`~xarray.DataTree`."""
536
536
  try:
537
- from datatree import DataTree
538
- except ModuleNotFoundError as err:
539
- raise ModuleNotFoundError(
540
- "datatree must be installed in order to use InferenceData.to_datatree"
537
+ from xarray import DataTree
538
+ except ImportError as err:
539
+ raise ImportError(
540
+ "xarray must have DataTree in order to use InferenceData.to_datatree. "
541
+ "Update to xarray>=2024.11.0"
541
542
  ) from err
542
543
  return DataTree.from_dict({group: ds for group, ds in self.items()})
543
544
 
544
545
  @staticmethod
545
546
  def from_datatree(datatree):
546
- """Create an InferenceData object from a :class:`~datatree.DataTree`.
547
+ """Create an InferenceData object from a :class:`~xarray.DataTree`.
547
548
 
548
549
  Parameters
549
550
  ----------
550
551
  datatree : DataTree
551
552
  """
552
- return InferenceData(**{group: sub_dt.to_dataset() for group, sub_dt in datatree.items()})
553
+ return InferenceData(
554
+ **{group: child.to_dataset() for group, child in datatree.children.items()}
555
+ )
553
556
 
554
557
  def to_dict(self, groups=None, filter_groups=None):
555
558
  """Convert InferenceData to a dictionary following xarray naming conventions.
@@ -797,12 +800,20 @@ class InferenceData(Mapping[str, xr.Dataset]):
797
800
  ----------
798
801
  https://zarr.readthedocs.io/
799
802
  """
800
- try: # Check zarr
803
+ try:
801
804
  import zarr
802
-
803
- assert version.parse(zarr.__version__) >= version.parse("2.5.0")
804
- except (ImportError, AssertionError) as err:
805
- raise ImportError("'to_zarr' method needs Zarr (2.5.0+) installed.") from err
805
+ except ImportError as err:
806
+ raise ImportError("'to_zarr' method needs Zarr (>=2.5.0,<3) installed.") from err
807
+ if version.parse(zarr.__version__) < version.parse("2.5.0"):
808
+ raise ImportError(
809
+ "Found zarr<2.5.0, please upgrade to a zarr (>=2.5.0,<3) to use 'to_zarr'"
810
+ )
811
+ if version.parse(zarr.__version__) >= version.parse("3.0.0.dev0"):
812
+ raise ImportError(
813
+ "Found zarr>=3, which is not supported by ArviZ. Instead, you can use "
814
+ "'dt = InferenceData.to_datatree' followed by 'dt.to_zarr()' "
815
+ "(needs xarray>=2024.11.0)"
816
+ )
806
817
 
807
818
  # Check store type and create store if necessary
808
819
  if store is None:
@@ -851,10 +862,18 @@ class InferenceData(Mapping[str, xr.Dataset]):
851
862
  """
852
863
  try:
853
864
  import zarr
854
-
855
- assert version.parse(zarr.__version__) >= version.parse("2.5.0")
856
- except (ImportError, AssertionError) as err:
857
- raise ImportError("'to_zarr' method needs Zarr (2.5.0+) installed.") from err
865
+ except ImportError as err:
866
+ raise ImportError("'from_zarr' method needs Zarr (>=2.5.0,<3) installed.") from err
867
+ if version.parse(zarr.__version__) < version.parse("2.5.0"):
868
+ raise ImportError(
869
+ "Found zarr<2.5.0, please upgrade to a zarr (>=2.5.0,<3) to use 'from_zarr'"
870
+ )
871
+ if version.parse(zarr.__version__) >= version.parse("3.0.0.dev0"):
872
+ raise ImportError(
873
+ "Found zarr>=3, which is not supported by ArviZ. Instead, you can use "
874
+ "'xarray.open_datatree' followed by 'arviz.InferenceData.from_datatree' "
875
+ "(needs xarray>=2024.11.0)"
876
+ )
858
877
 
859
878
  # Check store type and create store if necessary
860
879
  if isinstance(store, str):
@@ -1531,9 +1550,8 @@ class InferenceData(Mapping[str, xr.Dataset]):
1531
1550
  import xarray as xr
1532
1551
  from xarray_einstats.stats import XrDiscreteRV
1533
1552
  from scipy.stats import poisson
1534
- dist = XrDiscreteRV(poisson)
1535
- log_lik = xr.Dataset()
1536
- log_lik["home_points"] = dist.logpmf(obs["home_points"], np.exp(post["atts"]))
1553
+ dist = XrDiscreteRV(poisson, np.exp(post["atts"]))
1554
+ log_lik = dist.logpmf(obs["home_points"]).to_dataset(name="home_points")
1537
1555
  idata2.add_groups({"log_likelihood": log_lik})
1538
1556
  idata2
1539
1557
 
@@ -4,7 +4,7 @@ from .inference_data import InferenceData
4
4
 
5
5
 
6
6
  def to_datatree(data):
7
- """Convert InferenceData object to a :class:`~datatree.DataTree`.
7
+ """Convert InferenceData object to a :class:`~xarray.DataTree`.
8
8
 
9
9
  Parameters
10
10
  ----------
@@ -14,7 +14,7 @@ def to_datatree(data):
14
14
 
15
15
 
16
16
  def from_datatree(datatree):
17
- """Create an InferenceData object from a :class:`~datatree.DataTree`.
17
+ """Create an InferenceData object from a :class:`~xarray.DataTree`.
18
18
 
19
19
  Parameters
20
20
  ----------
@@ -1,7 +1,8 @@
1
1
  """NumPyro-specific conversion code."""
2
2
 
3
+ from collections import defaultdict
3
4
  import logging
4
- from typing import Callable, Optional
5
+ from typing import Any, Callable, Optional, Dict, List, Tuple
5
6
 
6
7
  import numpy as np
7
8
 
@@ -13,6 +14,70 @@ from .inference_data import InferenceData
13
14
  _log = logging.getLogger(__name__)
14
15
 
15
16
 
17
+ def _add_dims(dims_a: Dict[str, List[str]], dims_b: Dict[str, List[str]]) -> Dict[str, List[str]]:
18
+ merged = defaultdict(list)
19
+
20
+ for k, v in dims_a.items():
21
+ merged[k].extend(v)
22
+
23
+ for k, v in dims_b.items():
24
+ merged[k].extend(v)
25
+
26
+ # Convert back to a regular dict
27
+ return dict(merged)
28
+
29
+
30
+ def infer_dims(
31
+ model: Callable,
32
+ model_args: Optional[Tuple[Any, ...]] = None,
33
+ model_kwargs: Optional[Dict[str, Any]] = None,
34
+ ) -> Dict[str, List[str]]:
35
+
36
+ from numpyro import handlers, distributions as dist
37
+ from numpyro.ops.pytree import PytreeTrace
38
+ from numpyro.infer.initialization import init_to_sample
39
+ import jax
40
+
41
+ model_args = tuple() if model_args is None else model_args
42
+ model_kwargs = dict() if model_kwargs is None else model_kwargs
43
+
44
+ def _get_dist_name(fn):
45
+ if isinstance(fn, (dist.Independent, dist.ExpandedDistribution, dist.MaskedDistribution)):
46
+ return _get_dist_name(fn.base_dist)
47
+ return type(fn).__name__
48
+
49
+ def get_trace():
50
+ # We use `init_to_sample` to get around ImproperUniform distribution,
51
+ # which does not have `sample` method.
52
+ subs_model = handlers.substitute(
53
+ handlers.seed(model, 0),
54
+ substitute_fn=init_to_sample,
55
+ )
56
+ trace = handlers.trace(subs_model).get_trace(*model_args, **model_kwargs)
57
+ # Work around an issue where jax.eval_shape does not work
58
+ # for distribution output (e.g. the function `lambda: dist.Normal(0, 1)`)
59
+ # Here we will remove `fn` and store its name in the trace.
60
+ for _, site in trace.items():
61
+ if site["type"] == "sample":
62
+ site["fn_name"] = _get_dist_name(site.pop("fn"))
63
+ elif site["type"] == "deterministic":
64
+ site["fn_name"] = "Deterministic"
65
+ return PytreeTrace(trace)
66
+
67
+ # We use eval_shape to avoid any array computation.
68
+ trace = jax.eval_shape(get_trace).trace
69
+
70
+ named_dims = {}
71
+
72
+ for name, site in trace.items():
73
+ batch_dims = [frame.name for frame in sorted(site["cond_indep_stack"], key=lambda x: x.dim)]
74
+ event_dims = list(site.get("infer", {}).get("event_dims", []))
75
+ if site["type"] in ["sample", "deterministic"] and (batch_dims or event_dims):
76
+ named_dims[name] = batch_dims + event_dims
77
+
78
+ return named_dims
79
+
80
+
16
81
  class NumPyroConverter:
17
82
  """Encapsulate NumPyro specific logic."""
18
83
 
@@ -36,6 +101,7 @@ class NumPyroConverter:
36
101
  coords=None,
37
102
  dims=None,
38
103
  pred_dims=None,
104
+ extra_event_dims=None,
39
105
  num_chains=1,
40
106
  ):
41
107
  """Convert NumPyro data into an InferenceData object.
@@ -58,9 +124,12 @@ class NumPyroConverter:
58
124
  coords : dict[str] -> list[str]
59
125
  Map of dimensions to coordinates
60
126
  dims : dict[str] -> list[str]
61
- Map variable names to their coordinates
127
+ Map variable names to their coordinates. Will be inferred if they are not provided.
62
128
  pred_dims: dict
63
129
  Dims for predictions data. Map variable names to their coordinates.
130
+ extra_event_dims: dict
131
+ Extra event dims for deterministic sites. Maps event dims that couldn't be inferred to
132
+ their coordinates.
64
133
  num_chains: int
65
134
  Number of chains used for sampling. Ignored if posterior is present.
66
135
  """
@@ -80,6 +149,7 @@ class NumPyroConverter:
80
149
  self.coords = coords
81
150
  self.dims = dims
82
151
  self.pred_dims = pred_dims
152
+ self.extra_event_dims = extra_event_dims
83
153
  self.numpyro = numpyro
84
154
 
85
155
  def arbitrary_element(dct):
@@ -107,6 +177,10 @@ class NumPyroConverter:
107
177
  # model arguments and keyword arguments
108
178
  self._args = self.posterior._args # pylint: disable=protected-access
109
179
  self._kwargs = self.posterior._kwargs # pylint: disable=protected-access
180
+ self.dims = self.dims if self.dims is not None else self.infer_dims()
181
+ self.pred_dims = (
182
+ self.pred_dims if self.pred_dims is not None else self.infer_pred_dims()
183
+ )
110
184
  else:
111
185
  self.nchains = num_chains
112
186
  get_from = None
@@ -325,6 +399,23 @@ class NumPyroConverter:
325
399
  }
326
400
  )
327
401
 
402
+ @requires("posterior")
403
+ @requires("model")
404
+ def infer_dims(self) -> Dict[str, List[str]]:
405
+ dims = infer_dims(self.model, self._args, self._kwargs)
406
+ if self.extra_event_dims:
407
+ dims = _add_dims(dims, self.extra_event_dims)
408
+ return dims
409
+
410
+ @requires("posterior")
411
+ @requires("model")
412
+ @requires("predictions")
413
+ def infer_pred_dims(self) -> Dict[str, List[str]]:
414
+ dims = infer_dims(self.model, self._args, self._kwargs)
415
+ if self.extra_event_dims:
416
+ dims = _add_dims(dims, self.extra_event_dims)
417
+ return dims
418
+
328
419
 
329
420
  def from_numpyro(
330
421
  posterior=None,
@@ -339,10 +430,25 @@ def from_numpyro(
339
430
  coords=None,
340
431
  dims=None,
341
432
  pred_dims=None,
433
+ extra_event_dims=None,
342
434
  num_chains=1,
343
435
  ):
344
436
  """Convert NumPyro data into an InferenceData object.
345
437
 
438
+ If no dims are provided, this will infer batch dim names from NumPyro model plates.
439
+ For event dim names, such as with the ZeroSumNormal, `infer={"event_dims":dim_names}`
440
+ can be provided in numpyro.sample, i.e.::
441
+
442
+ # equivalent to dims entry, {"gamma": ["groups"]}
443
+ gamma = numpyro.sample(
444
+ "gamma",
445
+ dist.ZeroSumNormal(1, event_shape=(n_groups,)),
446
+ infer={"event_dims":["groups"]}
447
+ )
448
+
449
+ There is also an additional `extra_event_dims` input to cover any edge cases, for instance
450
+ deterministic sites with event dims (which don't have an `infer` argument to provide metadata).
451
+
346
452
  For a usage example read the
347
453
  :ref:`Creating InferenceData section on from_numpyro <creating_InferenceData>`
348
454
 
@@ -364,9 +470,10 @@ def from_numpyro(
364
470
  coords : dict[str] -> list[str]
365
471
  Map of dimensions to coordinates
366
472
  dims : dict[str] -> list[str]
367
- Map variable names to their coordinates
473
+ Map variable names to their coordinates. Will be inferred if they are not provided.
368
474
  pred_dims: dict
369
- Dims for predictions data. Map variable names to their coordinates.
475
+ Dims for predictions data. Map variable names to their coordinates. Default behavior is to
476
+ infer dims if this is not provided
370
477
  num_chains: int
371
478
  Number of chains used for sampling. Ignored if posterior is present.
372
479
  """
@@ -382,5 +489,6 @@ def from_numpyro(
382
489
  coords=coords,
383
490
  dims=dims,
384
491
  pred_dims=pred_dims,
492
+ extra_event_dims=extra_event_dims,
385
493
  num_chains=num_chains,
386
494
  ).to_inference_data()
@@ -4,7 +4,7 @@ from ..data import convert_to_dataset
4
4
  from ..labels import BaseLabeller
5
5
  from ..sel_utils import xarray_var_iter
6
6
  from ..rcparams import rcParams
7
- from ..utils import _var_names
7
+ from ..utils import _var_names, get_coords
8
8
  from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
9
9
 
10
10
 
@@ -14,6 +14,7 @@ def plot_autocorr(
14
14
  filter_vars=None,
15
15
  max_lag=None,
16
16
  combined=False,
17
+ coords=None,
17
18
  grid=None,
18
19
  figsize=None,
19
20
  textsize=None,
@@ -42,6 +43,8 @@ def plot_autocorr(
42
43
  interpret `var_names` as substrings of the real variables names. If "regex",
43
44
  interpret `var_names` as regular expressions on the real variables names. See
44
45
  :ref:`this section <common_filter_vars>` for usage examples.
46
+ coords : mapping, optional
47
+ Coordinates of var_names to be plotted. Passed to :meth:`xarray.Dataset.sel`.
45
48
  max_lag : int, optional
46
49
  Maximum lag to calculate autocorrelation. By default, the plot displays the
47
50
  first 100 lags or the total number of draws, whichever is smaller.
@@ -124,11 +127,18 @@ def plot_autocorr(
124
127
  if max_lag is None:
125
128
  max_lag = min(100, data["draw"].shape[0])
126
129
 
130
+ if coords is None:
131
+ coords = {}
132
+
127
133
  if labeller is None:
128
134
  labeller = BaseLabeller()
129
135
 
130
136
  plotters = filter_plotters_list(
131
- list(xarray_var_iter(data, var_names, combined, dim_order=["chain", "draw"])),
137
+ list(
138
+ xarray_var_iter(
139
+ get_coords(data, coords), var_names, combined, dim_order=["chain", "draw"]
140
+ )
141
+ ),
132
142
  "plot_autocorr",
133
143
  )
134
144
  rows, cols = default_grid(len(plotters), grid=grid)
@@ -21,9 +21,13 @@ def plot_hdi(ax, x_data, y_data, color, figsize, plot_kwargs, fill_kwargs, backe
21
21
  plot_kwargs["color"] = vectorized_to_hex(plot_kwargs.get("color", color))
22
22
  plot_kwargs.setdefault("alpha", 0)
23
23
 
24
- fill_kwargs = {} if fill_kwargs is None else fill_kwargs
25
- fill_kwargs["color"] = vectorized_to_hex(fill_kwargs.get("color", color))
26
- fill_kwargs.setdefault("alpha", 0.5)
24
+ fill_kwargs = {} if fill_kwargs is None else fill_kwargs.copy()
25
+ # Convert matplotlib color to bokeh fill_color if needed
26
+ if "color" in fill_kwargs and "fill_color" not in fill_kwargs:
27
+ fill_kwargs["fill_color"] = vectorized_to_hex(fill_kwargs.pop("color"))
28
+ else:
29
+ fill_kwargs["fill_color"] = vectorized_to_hex(fill_kwargs.get("fill_color", color))
30
+ fill_kwargs.setdefault("fill_alpha", fill_kwargs.pop("alpha", 0.5))
27
31
 
28
32
  figsize, *_ = _scale_fig_size(figsize, None)
29
33
 
@@ -38,9 +42,6 @@ def plot_hdi(ax, x_data, y_data, color, figsize, plot_kwargs, fill_kwargs, backe
38
42
  plot_kwargs.setdefault("line_color", plot_kwargs.pop("color"))
39
43
  plot_kwargs.setdefault("line_alpha", plot_kwargs.pop("alpha", 0))
40
44
 
41
- fill_kwargs.setdefault("fill_color", fill_kwargs.pop("color"))
42
- fill_kwargs.setdefault("fill_alpha", fill_kwargs.pop("alpha", 0))
43
-
44
45
  ax.patch(
45
46
  np.concatenate((x_data, x_data[::-1])),
46
47
  np.concatenate((y_data[:, 0], y_data[:, 1][::-1])),
@@ -68,7 +68,13 @@ def plot_lm(
68
68
 
69
69
  if y_hat_fill_kwargs is None:
70
70
  y_hat_fill_kwargs = {}
71
- y_hat_fill_kwargs.setdefault("color", "orange")
71
+ else:
72
+ y_hat_fill_kwargs = y_hat_fill_kwargs.copy()
73
+ # Convert matplotlib color to bokeh fill_color if needed
74
+ if "color" in y_hat_fill_kwargs and "fill_color" not in y_hat_fill_kwargs:
75
+ y_hat_fill_kwargs["fill_color"] = y_hat_fill_kwargs.pop("color")
76
+ y_hat_fill_kwargs.setdefault("fill_color", "orange")
77
+ y_hat_fill_kwargs.setdefault("fill_alpha", 0.5)
72
78
 
73
79
  if y_model_plot_kwargs is None:
74
80
  y_model_plot_kwargs = {}
@@ -78,8 +84,13 @@ def plot_lm(
78
84
 
79
85
  if y_model_fill_kwargs is None:
80
86
  y_model_fill_kwargs = {}
81
- y_model_fill_kwargs.setdefault("color", "black")
82
- y_model_fill_kwargs.setdefault("alpha", 0.5)
87
+ else:
88
+ y_model_fill_kwargs = y_model_fill_kwargs.copy()
89
+ # Convert matplotlib color to bokeh fill_color if needed
90
+ if "color" in y_model_fill_kwargs and "fill_color" not in y_model_fill_kwargs:
91
+ y_model_fill_kwargs["fill_color"] = y_model_fill_kwargs.pop("color")
92
+ y_model_fill_kwargs.setdefault("fill_color", "black")
93
+ y_model_fill_kwargs.setdefault("fill_alpha", 0.5)
83
94
 
84
95
  if y_model_mean_kwargs is None:
85
96
  y_model_mean_kwargs = {}
@@ -149,6 +160,11 @@ def plot_lm(
149
160
  )
150
161
 
151
162
  y_model_mean = np.mean(y_model_plotters, axis=(0, 1))
163
+ # Plot mean line across all x values instead of just edges
164
+ mean_legend = ax_i.line(x_plotters, y_model_mean, **y_model_mean_kwargs)
165
+ legend_it.append(("Mean", [mean_legend]))
166
+ continue # Skip the edge plotting since we plotted full line
167
+
152
168
  x_plotters_edge = [min(x_plotters), max(x_plotters)]
153
169
  y_model_mean_edge = [min(y_model_mean), max(y_model_mean)]
154
170
  mean_legend = ax_i.line(x_plotters_edge, y_model_mean_edge, **y_model_mean_kwargs)
@@ -37,6 +37,8 @@ def plot_pair(
37
37
  diverging_mask,
38
38
  divergences_kwargs,
39
39
  flat_var_names,
40
+ flat_ref_slices,
41
+ flat_var_labels,
40
42
  backend_kwargs,
41
43
  marginal_kwargs,
42
44
  show,
@@ -72,50 +74,12 @@ def plot_pair(
72
74
  kde_kwargs["contour_kwargs"].setdefault("line_alpha", 1)
73
75
 
74
76
  if reference_values:
75
- reference_values_copy = {}
76
- label = []
77
- for variable in list(reference_values.keys()):
78
- if " " in variable:
79
- variable_copy = variable.replace(" ", "\n", 1)
80
- else:
81
- variable_copy = variable
82
-
83
- label.append(variable_copy)
84
- reference_values_copy[variable_copy] = reference_values[variable]
85
-
86
- difference = set(flat_var_names).difference(set(label))
87
-
88
- if difference:
89
- warn = [diff.replace("\n", " ", 1) for diff in difference]
90
- warnings.warn(
91
- "Argument reference_values does not include reference value for: {}".format(
92
- ", ".join(warn)
93
- ),
94
- UserWarning,
95
- )
96
-
97
- if reference_values:
98
- reference_values_copy = {}
99
- label = []
100
- for variable in list(reference_values.keys()):
101
- if " " in variable:
102
- variable_copy = variable.replace(" ", "\n", 1)
103
- else:
104
- variable_copy = variable
105
-
106
- label.append(variable_copy)
107
- reference_values_copy[variable_copy] = reference_values[variable]
108
-
109
- difference = set(flat_var_names).difference(set(label))
110
-
111
- for dif in difference:
112
- reference_values_copy[dif] = None
77
+ difference = set(flat_var_names).difference(set(reference_values.keys()))
113
78
 
114
79
  if difference:
115
- warn = [dif.replace("\n", " ", 1) for dif in difference]
116
80
  warnings.warn(
117
81
  "Argument reference_values does not include reference value for: {}".format(
118
- ", ".join(warn)
82
+ ", ".join(difference)
119
83
  ),
120
84
  UserWarning,
121
85
  )
@@ -262,8 +226,8 @@ def plot_pair(
262
226
  **marginal_kwargs,
263
227
  )
264
228
 
265
- ax[j, i].xaxis.axis_label = flat_var_names[i]
266
- ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset]
229
+ ax[j, i].xaxis.axis_label = flat_var_labels[i]
230
+ ax[j, i].yaxis.axis_label = flat_var_labels[j + marginals_offset]
267
231
 
268
232
  elif j + marginals_offset > i:
269
233
  if "scatter" in kind:
@@ -346,12 +310,18 @@ def plot_pair(
346
310
  ax[-1, -1].add_layout(ax_pe_hline)
347
311
 
348
312
  if reference_values:
349
- x = reference_values_copy[flat_var_names[j + marginals_offset]]
350
- y = reference_values_copy[flat_var_names[i]]
351
- if x and y:
352
- ax[j, i].scatter(y, x, **reference_values_kwargs)
353
- ax[j, i].xaxis.axis_label = flat_var_names[i]
354
- ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset]
313
+ x_name = flat_var_names[j + marginals_offset]
314
+ y_name = flat_var_names[i]
315
+ if (x_name not in difference) and (y_name not in difference):
316
+ ax[j, i].scatter(
317
+ np.array(reference_values[y_name])[flat_ref_slices[i]],
318
+ np.array(reference_values[x_name])[
319
+ flat_ref_slices[j + marginals_offset]
320
+ ],
321
+ **reference_values_kwargs,
322
+ )
323
+ ax[j, i].xaxis.axis_label = flat_var_labels[i]
324
+ ax[j, i].yaxis.axis_label = flat_var_labels[j + marginals_offset]
355
325
 
356
326
  show_layout(ax, show)
357
327
 
@@ -7,6 +7,7 @@ from matplotlib import cm
7
7
  import matplotlib.pyplot as plt
8
8
  import numpy as np
9
9
  from matplotlib.colors import to_rgba_array
10
+ from packaging import version
10
11
 
11
12
  from ....stats.density_utils import histogram
12
13
  from ...plot_utils import _scale_fig_size, color_from_dim, set_xticklabels, vectorized_to_hex
@@ -39,7 +40,13 @@ def plot_khat(
39
40
  show,
40
41
  ):
41
42
  """Matplotlib khat plot."""
42
- if hover_label and mpl.get_backend() not in mpl.rcsetup.interactive_bk:
43
+ if version.parse(mpl.__version__) >= version.parse("3.9.0.dev0"):
44
+ interactive_backends = mpl.backends.backend_registry.list_builtin(
45
+ mpl.backends.BackendFilter.INTERACTIVE
46
+ )
47
+ else:
48
+ interactive_backends = mpl.rcsetup.interactive_bk
49
+ if hover_label and mpl.get_backend() not in interactive_backends:
43
50
  hover_label = False
44
51
  warnings.warn(
45
52
  "hover labels are only available with interactive backends. To switch to an "
@@ -115,12 +115,18 @@ def plot_lm(
115
115
 
116
116
  if y_model is not None:
117
117
  _, _, _, y_model_plotters = y_model[i]
118
+
118
119
  if kind_model == "lines":
119
- for j in range(num_samples):
120
- ax_i.plot(x_plotters, y_model_plotters[..., j], **y_model_plot_kwargs)
121
- ax_i.plot([], **y_model_plot_kwargs, label="Uncertainty in mean")
120
+ # y_model_plotters should be (points, samples)
121
+ y_points = y_model_plotters.shape[0]
122
+ if x_plotters.shape[0] == y_points:
123
+ for j in range(num_samples):
124
+ ax_i.plot(x_plotters, y_model_plotters[:, j], **y_model_plot_kwargs)
125
+
126
+ ax_i.plot([], **y_model_plot_kwargs, label="Uncertainty in mean")
127
+ y_model_mean = np.mean(y_model_plotters, axis=1)
128
+ ax_i.plot(x_plotters, y_model_mean, **y_model_mean_kwargs, label="Mean")
122
129
 
123
- y_model_mean = np.mean(y_model_plotters, axis=1)
124
130
  else:
125
131
  plot_hdi(
126
132
  x_plotters,
@@ -128,10 +134,10 @@ def plot_lm(
128
134
  fill_kwargs=y_model_fill_kwargs,
129
135
  ax=ax_i,
130
136
  )
131
- ax_i.plot([], color=y_model_fill_kwargs["color"], label="Uncertainty in mean")
132
137
 
133
- y_model_mean = np.mean(y_model_plotters, axis=(0, 1))
134
- ax_i.plot(x_plotters, y_model_mean, **y_model_mean_kwargs, label="Mean")
138
+ ax_i.plot([], color=y_model_fill_kwargs["color"], label="Uncertainty in mean")
139
+ y_model_mean = np.mean(y_model_plotters, axis=0)
140
+ ax_i.plot(x_plotters, y_model_mean, **y_model_mean_kwargs, label="Mean")
135
141
 
136
142
  if legend:
137
143
  ax_i.legend(fontsize=xt_labelsize, loc="upper left")