arviz 0.21.0__tar.gz → 0.23.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (201) hide show
  1. {arviz-0.21.0 → arviz-0.23.0}/CHANGELOG.md +26 -1
  2. {arviz-0.21.0 → arviz-0.23.0}/CONTRIBUTING.md +2 -1
  3. {arviz-0.21.0 → arviz-0.23.0}/PKG-INFO +36 -3
  4. {arviz-0.21.0 → arviz-0.23.0}/arviz/__init__.py +49 -4
  5. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/converters.py +11 -0
  6. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/inference_data.py +46 -24
  7. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_datatree.py +2 -2
  8. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_numpyro.py +116 -5
  9. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_pyjags.py +1 -1
  10. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/autocorrplot.py +12 -2
  11. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/hdiplot.py +7 -6
  12. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/lmplot.py +19 -3
  13. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/pairplot.py +18 -48
  14. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/khatplot.py +8 -1
  15. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/lmplot.py +13 -7
  16. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/pairplot.py +14 -22
  17. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/bpvplot.py +1 -1
  18. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/dotplot.py +2 -0
  19. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/forestplot.py +16 -4
  20. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/lmplot.py +41 -14
  21. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/pairplot.py +10 -3
  22. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/ppcplot.py +1 -1
  23. arviz-0.23.0/arviz/preview.py +58 -0
  24. {arviz-0.21.0 → arviz-0.23.0}/arviz/rcparams.py +2 -2
  25. {arviz-0.21.0 → arviz-0.23.0}/arviz/stats/density_utils.py +1 -1
  26. {arviz-0.21.0 → arviz-0.23.0}/arviz/stats/stats.py +31 -34
  27. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_data.py +25 -4
  28. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_plots_bokeh.py +60 -2
  29. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_plots_matplotlib.py +94 -1
  30. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_stats.py +42 -1
  31. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_stats_ecdf_utils.py +2 -2
  32. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/external_tests/test_data_numpyro.py +154 -4
  33. {arviz-0.21.0 → arviz-0.23.0}/arviz/wrappers/base.py +1 -1
  34. {arviz-0.21.0 → arviz-0.23.0}/arviz/wrappers/wrap_stan.py +1 -1
  35. {arviz-0.21.0 → arviz-0.23.0}/arviz.egg-info/PKG-INFO +36 -3
  36. {arviz-0.21.0 → arviz-0.23.0}/arviz.egg-info/requires.txt +6 -6
  37. {arviz-0.21.0 → arviz-0.23.0}/requirements-dev.txt +1 -0
  38. {arviz-0.21.0 → arviz-0.23.0}/requirements-optional.txt +1 -1
  39. {arviz-0.21.0 → arviz-0.23.0}/requirements.txt +5 -5
  40. arviz-0.21.0/arviz/preview.py +0 -48
  41. {arviz-0.21.0 → arviz-0.23.0}/CODE_OF_CONDUCT.md +0 -0
  42. {arviz-0.21.0 → arviz-0.23.0}/GOVERNANCE.md +0 -0
  43. {arviz-0.21.0 → arviz-0.23.0}/LICENSE +0 -0
  44. {arviz-0.21.0 → arviz-0.23.0}/MANIFEST.in +0 -0
  45. {arviz-0.21.0 → arviz-0.23.0}/README.md +0 -0
  46. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/__init__.py +0 -0
  47. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/base.py +0 -0
  48. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/datasets.py +0 -0
  49. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/example_data/code/radon/radon.json +0 -0
  50. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/example_data/data/centered_eight.nc +0 -0
  51. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/example_data/data/non_centered_eight.nc +0 -0
  52. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/example_data/data_local.json +0 -0
  53. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/example_data/data_remote.json +0 -0
  54. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_beanmachine.py +0 -0
  55. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_cmdstan.py +0 -0
  56. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_cmdstanpy.py +0 -0
  57. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_dict.py +0 -0
  58. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_emcee.py +0 -0
  59. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_json.py +0 -0
  60. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_netcdf.py +0 -0
  61. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_pyro.py +0 -0
  62. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_pystan.py +0 -0
  63. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/io_zarr.py +0 -0
  64. {arviz-0.21.0 → arviz-0.23.0}/arviz/data/utils.py +0 -0
  65. {arviz-0.21.0 → arviz-0.23.0}/arviz/labels.py +0 -0
  66. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/__init__.py +0 -0
  67. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/__init__.py +0 -0
  68. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/__init__.py +0 -0
  69. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/autocorrplot.py +0 -0
  70. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/bfplot.py +0 -0
  71. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/bpvplot.py +0 -0
  72. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/compareplot.py +0 -0
  73. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/densityplot.py +0 -0
  74. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/distcomparisonplot.py +0 -0
  75. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/distplot.py +0 -0
  76. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/dotplot.py +0 -0
  77. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/ecdfplot.py +0 -0
  78. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/elpdplot.py +0 -0
  79. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/energyplot.py +0 -0
  80. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/essplot.py +0 -0
  81. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/forestplot.py +0 -0
  82. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/kdeplot.py +0 -0
  83. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/khatplot.py +0 -0
  84. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/loopitplot.py +0 -0
  85. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/mcseplot.py +0 -0
  86. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/parallelplot.py +0 -0
  87. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/posteriorplot.py +0 -0
  88. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/ppcplot.py +0 -0
  89. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/rankplot.py +0 -0
  90. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/separationplot.py +0 -0
  91. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/traceplot.py +0 -0
  92. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/bokeh/violinplot.py +0 -0
  93. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/__init__.py +0 -0
  94. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/autocorrplot.py +0 -0
  95. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/bfplot.py +0 -0
  96. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/bpvplot.py +0 -0
  97. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/compareplot.py +0 -0
  98. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/densityplot.py +0 -0
  99. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -0
  100. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/distplot.py +0 -0
  101. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/dotplot.py +0 -0
  102. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/ecdfplot.py +0 -0
  103. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/elpdplot.py +0 -0
  104. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/energyplot.py +0 -0
  105. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/essplot.py +0 -0
  106. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/forestplot.py +0 -0
  107. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/hdiplot.py +0 -0
  108. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/kdeplot.py +0 -0
  109. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/loopitplot.py +0 -0
  110. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/mcseplot.py +0 -0
  111. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/parallelplot.py +0 -0
  112. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/posteriorplot.py +0 -0
  113. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/ppcplot.py +0 -0
  114. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/rankplot.py +0 -0
  115. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/separationplot.py +0 -0
  116. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/traceplot.py +0 -0
  117. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/tsplot.py +0 -0
  118. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/backends/matplotlib/violinplot.py +0 -0
  119. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/bfplot.py +0 -0
  120. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/compareplot.py +0 -0
  121. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/densityplot.py +0 -0
  122. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/distcomparisonplot.py +0 -0
  123. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/distplot.py +0 -0
  124. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/ecdfplot.py +0 -0
  125. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/elpdplot.py +0 -0
  126. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/energyplot.py +0 -0
  127. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/essplot.py +0 -0
  128. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/hdiplot.py +0 -0
  129. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/kdeplot.py +4 -4
  130. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/khatplot.py +0 -0
  131. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/loopitplot.py +0 -0
  132. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/mcseplot.py +0 -0
  133. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/parallelplot.py +0 -0
  134. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/plot_utils.py +0 -0
  135. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/posteriorplot.py +0 -0
  136. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/rankplot.py +0 -0
  137. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/separationplot.py +0 -0
  138. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-bluish.mplstyle +0 -0
  139. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-brownish.mplstyle +0 -0
  140. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-colors.mplstyle +0 -0
  141. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-cyanish.mplstyle +0 -0
  142. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-darkgrid.mplstyle +0 -0
  143. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-doc.mplstyle +0 -0
  144. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-docgrid.mplstyle +0 -0
  145. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-grayscale.mplstyle +0 -0
  146. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-greenish.mplstyle +0 -0
  147. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-orangish.mplstyle +0 -0
  148. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-plasmish.mplstyle +0 -0
  149. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-purplish.mplstyle +0 -0
  150. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-redish.mplstyle +0 -0
  151. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-royish.mplstyle +0 -0
  152. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-viridish.mplstyle +0 -0
  153. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-white.mplstyle +0 -0
  154. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/styles/arviz-whitegrid.mplstyle +0 -0
  155. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/traceplot.py +0 -0
  156. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/tsplot.py +0 -0
  157. {arviz-0.21.0 → arviz-0.23.0}/arviz/plots/violinplot.py +0 -0
  158. {arviz-0.21.0 → arviz-0.23.0}/arviz/py.typed +0 -0
  159. {arviz-0.21.0 → arviz-0.23.0}/arviz/sel_utils.py +0 -0
  160. {arviz-0.21.0 → arviz-0.23.0}/arviz/static/css/style.css +0 -0
  161. {arviz-0.21.0 → arviz-0.23.0}/arviz/static/html/icons-svg-inline.html +0 -0
  162. {arviz-0.21.0 → arviz-0.23.0}/arviz/stats/__init__.py +0 -0
  163. {arviz-0.21.0 → arviz-0.23.0}/arviz/stats/diagnostics.py +0 -0
  164. {arviz-0.21.0 → arviz-0.23.0}/arviz/stats/ecdf_utils.py +0 -0
  165. {arviz-0.21.0 → arviz-0.23.0}/arviz/stats/stats_refitting.py +0 -0
  166. {arviz-0.21.0 → arviz-0.23.0}/arviz/stats/stats_utils.py +0 -0
  167. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/__init__.py +0 -0
  168. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/__init__.py +0 -0
  169. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_data_zarr.py +0 -0
  170. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_diagnostics.py +0 -0
  171. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_diagnostics_numba.py +0 -0
  172. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_helpers.py +0 -0
  173. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_labels.py +0 -0
  174. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_plot_utils.py +0 -0
  175. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_rcparams.py +0 -0
  176. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_stats_numba.py +0 -0
  177. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_stats_utils.py +0 -0
  178. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_utils.py +0 -0
  179. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/base_tests/test_utils_numba.py +0 -0
  180. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/conftest.py +0 -0
  181. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/external_tests/__init__.py +0 -0
  182. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/external_tests/test_data_beanmachine.py +0 -0
  183. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/external_tests/test_data_cmdstan.py +0 -0
  184. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/external_tests/test_data_cmdstanpy.py +0 -0
  185. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/external_tests/test_data_emcee.py +0 -0
  186. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/external_tests/test_data_pyjags.py +0 -0
  187. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/external_tests/test_data_pyro.py +0 -0
  188. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/external_tests/test_data_pystan.py +0 -0
  189. {arviz-0.21.0 → arviz-0.23.0}/arviz/tests/helpers.py +0 -0
  190. {arviz-0.21.0 → arviz-0.23.0}/arviz/utils.py +0 -0
  191. {arviz-0.21.0 → arviz-0.23.0}/arviz/wrappers/__init__.py +0 -0
  192. {arviz-0.21.0 → arviz-0.23.0}/arviz/wrappers/wrap_pymc.py +0 -0
  193. {arviz-0.21.0 → arviz-0.23.0}/arviz.egg-info/SOURCES.txt +0 -0
  194. {arviz-0.21.0 → arviz-0.23.0}/arviz.egg-info/dependency_links.txt +0 -0
  195. {arviz-0.21.0 → arviz-0.23.0}/arviz.egg-info/top_level.txt +0 -0
  196. {arviz-0.21.0 → arviz-0.23.0}/pyproject.toml +0 -0
  197. {arviz-0.21.0 → arviz-0.23.0}/requirements-docs.txt +0 -0
  198. {arviz-0.21.0 → arviz-0.23.0}/requirements-external.txt +0 -0
  199. {arviz-0.21.0 → arviz-0.23.0}/requirements-test.txt +0 -0
  200. {arviz-0.21.0 → arviz-0.23.0}/setup.cfg +0 -0
  201. {arviz-0.21.0 → arviz-0.23.0}/setup.py +0 -0
@@ -1,5 +1,30 @@
1
1
  # Change Log
2
2
 
3
+ ## v0.23.0 (2025 Dec 9)
4
+
5
+ ### Maintenance and fixes
6
+ - Fix numpyro jax incompatibility. ([2465](https://github.com/arviz-devs/arviz/pull/2465))
7
+ - Avoid closing unloaded files in `from_netcdf()` ([2463](https://github.com/arviz-devs/arviz/issues/2463))
8
+ - Fix sign error in lp parsed in from_numpyro ([2468](https://github.com/arviz-devs/arviz/issues/2468))
9
+ - Fix attrs persistence in idata-datatree conversions ([2476](https://github.com/arviz-devs/arviz/issues/2476))
10
+
11
+ ## v0.22.0 (2025 Jul 9)
12
+
13
+ ### New features
14
+ - `plot_pair` now has more flexible support for `reference_values` ([2438](https://github.com/arviz-devs/arviz/pull/2438))
15
+ - Make `arviz.from_numpyro(..., dims=None)` automatically infer dims from the numpyro model based on its numpyro.plate structure
16
+
17
+ ### Maintenance and fixes
18
+ - `reference_values` and `labeller` now work together in `plot_pair` ([2437](https://github.com/arviz-devs/arviz/issues/2437))
19
+ - Fix `plot_lm` for multidimensional data ([2408](https://github.com/arviz-devs/arviz/issues/2408))
20
+ - Add [`scipy-stubs`](https://github.com/scipy/scipy-stubs) as a development dependency ([2445](https://github.com/arviz-devs/arviz/pull/2445))
21
+ - Test compare dataframe stays consistent independently of input order ([2407](https://github.com/arviz-devs/arviz/pull/2407))
22
+ - Fix hdi_probs behaviour in 2d `plot_kde` ([2460](https://github.com/arviz-devs/arviz/pull/2460))
23
+
24
+ ### Documentation
25
+ - Added documentation for `reference_values` ([2438](https://github.com/arviz-devs/arviz/pull/2438))
26
+ - Add migration guide page to help switch over to the new `arviz-xyz` libraries ([2459](https://github.com/arviz-devs/arviz/pull/2459))
27
+
3
28
  ## v0.21.0 (2025 Mar 06)
4
29
 
5
30
  ### New features
@@ -8,7 +33,7 @@
8
33
  - Make `arviz.data.generate_dims_coords` handle `dims` and `default_dims` consistently ([2395](https://github.com/arviz-devs/arviz/pull/2395))
9
34
  - Only emit a warning for custom groups in `InferenceData` when explicitly requested ([2401](https://github.com/arviz-devs/arviz/pull/2401))
10
35
  - Splits Bayes Factor computation out from `az.plot_bf` into `az.bayes_factor` ([2402](https://github.com/arviz-devs/arviz/issues/2402))
11
- - Update `method="sd"` of `mcse` to not use normality assumption ([2167](https://github.com/arviz-devs/arviz/pull/2167))
36
+ - Update `method="sd"` of `mcse` to not use normality assumption ([2167](https://github.com/arviz-devs/arviz/pull/2167))
12
37
  - Add exception in `az.plot_hdi` for `x` of type `str` ([2413](https://github.com/arviz-devs/arviz/pull/2413))
13
38
 
14
39
  ### Documentation
@@ -1,8 +1,9 @@
1
1
  # Contributing to ArviZ
2
2
  This document outlines only the most common contributions.
3
3
  Please see the [Contributing guide](https://python.arviz.org/en/latest/contributing/index.html)
4
- on our documentation for a better view of how can you contribute to ArviZ.
4
+ on our documentation for a better view of how you can contribute to ArviZ.
5
5
  We welcome a wide range of contributions, not only code!
6
+ Even improving documentation or fixing typos is a valuable contribution to ArviZ.
6
7
 
7
8
  ## Reporting issues
8
9
  If you encounter any bug or incorrect behaviour while using ArviZ,
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: arviz
3
- Version: 0.21.0
3
+ Version: 0.23.0
4
4
  Summary: Exploratory analysis of Bayesian models
5
5
  Home-page: http://github.com/arviz-devs/arviz
6
6
  Author: ArviZ Developers
@@ -20,9 +20,42 @@ Classifier: Topic :: Scientific/Engineering :: Visualization
20
20
  Classifier: Topic :: Scientific/Engineering :: Mathematics
21
21
  Requires-Python: >=3.10
22
22
  Description-Content-Type: text/markdown
23
+ License-File: LICENSE
24
+ Requires-Dist: setuptools>=60.0.0
25
+ Requires-Dist: matplotlib>=3.8
26
+ Requires-Dist: numpy>=1.26.0
27
+ Requires-Dist: scipy>=1.11.0
28
+ Requires-Dist: packaging
29
+ Requires-Dist: pandas>=2.1.0
30
+ Requires-Dist: xarray>=2023.7.0
31
+ Requires-Dist: h5netcdf>=1.0.2
32
+ Requires-Dist: typing_extensions>=4.1.0
33
+ Requires-Dist: xarray-einstats>=0.3
23
34
  Provides-Extra: all
35
+ Requires-Dist: numba; extra == "all"
36
+ Requires-Dist: netcdf4; extra == "all"
37
+ Requires-Dist: bokeh>=3; extra == "all"
38
+ Requires-Dist: contourpy; extra == "all"
39
+ Requires-Dist: ujson; extra == "all"
40
+ Requires-Dist: dask[distributed]; extra == "all"
41
+ Requires-Dist: zarr<3,>=2.5.0; extra == "all"
42
+ Requires-Dist: xarray>=2024.11.0; extra == "all"
43
+ Requires-Dist: dm-tree>=0.1.8; extra == "all"
24
44
  Provides-Extra: preview
25
- License-File: LICENSE
45
+ Requires-Dist: arviz-base[h5netcdf]; extra == "preview"
46
+ Requires-Dist: arviz-stats[xarray]; extra == "preview"
47
+ Requires-Dist: arviz-plots; extra == "preview"
48
+ Dynamic: author
49
+ Dynamic: classifier
50
+ Dynamic: description
51
+ Dynamic: description-content-type
52
+ Dynamic: home-page
53
+ Dynamic: license
54
+ Dynamic: license-file
55
+ Dynamic: provides-extra
56
+ Dynamic: requires-dist
57
+ Dynamic: requires-python
58
+ Dynamic: summary
26
59
 
27
60
  <img src="https://raw.githubusercontent.com/arviz-devs/arviz-project/main/arviz_logos/ArviZ.png#gh-light-mode-only" width=200></img>
28
61
  <img src="https://raw.githubusercontent.com/arviz-devs/arviz-project/main/arviz_logos/ArviZ_white.png#gh-dark-mode-only" width=200></img>
@@ -1,13 +1,54 @@
1
1
  # pylint: disable=wildcard-import,invalid-name,wrong-import-position
2
2
  """ArviZ is a library for exploratory analysis of Bayesian models."""
3
- __version__ = "0.21.0"
3
+ __version__ = "0.23.0"
4
4
 
5
5
  import logging
6
6
  import os
7
+ import re
7
8
 
8
9
  from matplotlib.colors import LinearSegmentedColormap
9
10
  from matplotlib.pyplot import style
10
11
  import matplotlib as mpl
12
+ from packaging import version
13
+
14
+
15
+ def _warn_once_per_day():
16
+ from .preview import info
17
+
18
+ # skip warning if all 3 arviz subpackages are already installed
19
+ pat = re.compile(r"arviz_(base|stats|plots) available")
20
+ if len(pat.findall(info)) == 3:
21
+ return
22
+
23
+ import datetime
24
+ from warnings import warn
25
+ from pathlib import Path
26
+
27
+ warning_dir = Path.home() / "arviz_data"
28
+ warning_dir.mkdir(exist_ok=True)
29
+
30
+ stamp_file = warning_dir / "daily_warning"
31
+ today = datetime.date.today()
32
+
33
+ if stamp_file.exists():
34
+ last_date = datetime.date.fromisoformat(stamp_file.read_text().strip())
35
+ else:
36
+ last_date = None
37
+
38
+ if last_date != today:
39
+ warn(
40
+ "\nArviZ is undergoing a major refactor to improve flexibility and extensibility "
41
+ "while maintaining a user-friendly interface."
42
+ "\nSome upcoming changes may be backward incompatible."
43
+ "\nFor details and migration guidance, visit: "
44
+ "https://python.arviz.org/en/latest/user_guide/migration_guide.html",
45
+ FutureWarning,
46
+ )
47
+
48
+ stamp_file.write_text(today.isoformat())
49
+
50
+
51
+ _warn_once_per_day()
11
52
 
12
53
 
13
54
  class Logger(logging.Logger):
@@ -41,8 +82,12 @@ from . import preview
41
82
 
42
83
  # add ArviZ's styles to matplotlib's styles
43
84
  _arviz_style_path = os.path.join(os.path.dirname(__file__), "plots", "styles")
44
- style.core.USER_LIBRARY_PATHS.append(_arviz_style_path)
45
- style.core.reload_library()
85
+ if version.parse(mpl.__version__) >= version.parse("3.11.0.dev0"):
86
+ style.USER_LIBRARY_PATHS.append(_arviz_style_path)
87
+ style.reload_library()
88
+ else:
89
+ style.core.USER_LIBRARY_PATHS.append(_arviz_style_path)
90
+ style.core.reload_library()
46
91
 
47
92
 
48
93
  if not logging.root.handlers:
@@ -328,4 +373,4 @@ except ModuleNotFoundError:
328
373
 
329
374
 
330
375
  # clean namespace
331
- del os, logging, LinearSegmentedColormap, Logger, mpl
376
+ del os, re, logging, version, LinearSegmentedColormap, Logger, mpl
@@ -2,6 +2,7 @@
2
2
 
3
3
  import numpy as np
4
4
  import xarray as xr
5
+ import pandas as pd
5
6
 
6
7
  try:
7
8
  from tree import is_nested
@@ -44,6 +45,8 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
44
45
  | dict: creates an xarray dataset as the only group
45
46
  | numpy array: creates an xarray dataset as the only group, gives the
46
47
  array an arbitrary name
48
+ | object with __array__: converts to numpy array, then creates an xarray dataset as
49
+ the only group, gives the array an arbitrary name
47
50
  group : str
48
51
  If `obj` is a dict or numpy array, assigns the resulting xarray
49
52
  dataset to this group. Default: "posterior".
@@ -115,6 +118,13 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
115
118
  dataset = dict_to_dataset(obj, coords=coords, dims=dims)
116
119
  elif isinstance(obj, np.ndarray):
117
120
  dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
121
+ elif (
122
+ hasattr(obj, "__array__")
123
+ and callable(getattr(obj, "__array__"))
124
+ and (not isinstance(obj, pd.DataFrame))
125
+ ):
126
+ obj = obj.__array__()
127
+ dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
118
128
  elif isinstance(obj, (list, tuple)) and isinstance(obj[0], str) and obj[0].endswith(".csv"):
119
129
  if group == "sample_stats":
120
130
  kwargs["posterior"] = kwargs.pop(group)
@@ -129,6 +139,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
129
139
  "pytree (if 'dm-tree' is installed)",
130
140
  "netcdf filename",
131
141
  "numpy array",
142
+ "object with __array__",
132
143
  "pystan fit",
133
144
  "emcee fit",
134
145
  "pyro mcmc fit",
@@ -430,11 +430,12 @@ class InferenceData(Mapping[str, xr.Dataset]):
430
430
  if re.search(key, group):
431
431
  group_kws = kws
432
432
  group_kws.setdefault("engine", engine)
433
- with xr.open_dataset(filename, group=f"{base_group}/{group}", **group_kws) as data:
434
- if rcParams["data.load"] == "eager":
433
+ data = xr.open_dataset(filename, group=f"{base_group}/{group}", **group_kws)
434
+ if rcParams["data.load"] == "eager":
435
+ with data:
435
436
  groups[group] = data.load()
436
- else:
437
- groups[group] = data
437
+ else:
438
+ groups[group] = data
438
439
 
439
440
  with xr.open_dataset(filename, engine=engine, group=base_group) as data:
440
441
  attrs.update(data.load().attrs)
@@ -532,24 +533,30 @@ class InferenceData(Mapping[str, xr.Dataset]):
532
533
  return filename
533
534
 
534
535
  def to_datatree(self):
535
- """Convert InferenceData object to a :class:`~datatree.DataTree`."""
536
+ """Convert InferenceData object to a :class:`~xarray.DataTree`."""
536
537
  try:
537
- from datatree import DataTree
538
- except ModuleNotFoundError as err:
539
- raise ModuleNotFoundError(
540
- "datatree must be installed in order to use InferenceData.to_datatree"
538
+ from xarray import DataTree
539
+ except ImportError as err:
540
+ raise ImportError(
541
+ "xarray must have DataTree in order to use InferenceData.to_datatree. "
542
+ "Update to xarray>=2024.11.0"
541
543
  ) from err
542
- return DataTree.from_dict({group: ds for group, ds in self.items()})
544
+ dt = DataTree.from_dict({group: ds for group, ds in self.items()})
545
+ dt.attrs = self.attrs
546
+ return dt
543
547
 
544
548
  @staticmethod
545
549
  def from_datatree(datatree):
546
- """Create an InferenceData object from a :class:`~datatree.DataTree`.
550
+ """Create an InferenceData object from a :class:`~xarray.DataTree`.
547
551
 
548
552
  Parameters
549
553
  ----------
550
554
  datatree : DataTree
551
555
  """
552
- return InferenceData(**{group: sub_dt.to_dataset() for group, sub_dt in datatree.items()})
556
+ return InferenceData(
557
+ attrs=datatree.attrs,
558
+ **{group: child.to_dataset() for group, child in datatree.children.items()},
559
+ )
553
560
 
554
561
  def to_dict(self, groups=None, filter_groups=None):
555
562
  """Convert InferenceData to a dictionary following xarray naming conventions.
@@ -797,12 +804,20 @@ class InferenceData(Mapping[str, xr.Dataset]):
797
804
  ----------
798
805
  https://zarr.readthedocs.io/
799
806
  """
800
- try: # Check zarr
807
+ try:
801
808
  import zarr
802
-
803
- assert version.parse(zarr.__version__) >= version.parse("2.5.0")
804
- except (ImportError, AssertionError) as err:
805
- raise ImportError("'to_zarr' method needs Zarr (2.5.0+) installed.") from err
809
+ except ImportError as err:
810
+ raise ImportError("'to_zarr' method needs Zarr (>=2.5.0,<3) installed.") from err
811
+ if version.parse(zarr.__version__) < version.parse("2.5.0"):
812
+ raise ImportError(
813
+ "Found zarr<2.5.0, please upgrade to a zarr (>=2.5.0,<3) to use 'to_zarr'"
814
+ )
815
+ if version.parse(zarr.__version__) >= version.parse("3.0.0.dev0"):
816
+ raise ImportError(
817
+ "Found zarr>=3, which is not supported by ArviZ. Instead, you can use "
818
+ "'dt = InferenceData.to_datatree' followed by 'dt.to_zarr()' "
819
+ "(needs xarray>=2024.11.0)"
820
+ )
806
821
 
807
822
  # Check store type and create store if necessary
808
823
  if store is None:
@@ -851,10 +866,18 @@ class InferenceData(Mapping[str, xr.Dataset]):
851
866
  """
852
867
  try:
853
868
  import zarr
854
-
855
- assert version.parse(zarr.__version__) >= version.parse("2.5.0")
856
- except (ImportError, AssertionError) as err:
857
- raise ImportError("'to_zarr' method needs Zarr (2.5.0+) installed.") from err
869
+ except ImportError as err:
870
+ raise ImportError("'from_zarr' method needs Zarr (>=2.5.0,<3) installed.") from err
871
+ if version.parse(zarr.__version__) < version.parse("2.5.0"):
872
+ raise ImportError(
873
+ "Found zarr<2.5.0, please upgrade to a zarr (>=2.5.0,<3) to use 'from_zarr'"
874
+ )
875
+ if version.parse(zarr.__version__) >= version.parse("3.0.0.dev0"):
876
+ raise ImportError(
877
+ "Found zarr>=3, which is not supported by ArviZ. Instead, you can use "
878
+ "'xarray.open_datatree' followed by 'arviz.InferenceData.from_datatree' "
879
+ "(needs xarray>=2024.11.0)"
880
+ )
858
881
 
859
882
  # Check store type and create store if necessary
860
883
  if isinstance(store, str):
@@ -1531,9 +1554,8 @@ class InferenceData(Mapping[str, xr.Dataset]):
1531
1554
  import xarray as xr
1532
1555
  from xarray_einstats.stats import XrDiscreteRV
1533
1556
  from scipy.stats import poisson
1534
- dist = XrDiscreteRV(poisson)
1535
- log_lik = xr.Dataset()
1536
- log_lik["home_points"] = dist.logpmf(obs["home_points"], np.exp(post["atts"]))
1557
+ dist = XrDiscreteRV(poisson, np.exp(post["atts"]))
1558
+ log_lik = dist.logpmf(obs["home_points"]).to_dataset(name="home_points")
1537
1559
  idata2.add_groups({"log_likelihood": log_lik})
1538
1560
  idata2
1539
1561
 
@@ -4,7 +4,7 @@ from .inference_data import InferenceData
4
4
 
5
5
 
6
6
  def to_datatree(data):
7
- """Convert InferenceData object to a :class:`~datatree.DataTree`.
7
+ """Convert InferenceData object to a :class:`~xarray.DataTree`.
8
8
 
9
9
  Parameters
10
10
  ----------
@@ -14,7 +14,7 @@ def to_datatree(data):
14
14
 
15
15
 
16
16
  def from_datatree(datatree):
17
- """Create an InferenceData object from a :class:`~datatree.DataTree`.
17
+ """Create an InferenceData object from a :class:`~xarray.DataTree`.
18
18
 
19
19
  Parameters
20
20
  ----------
@@ -1,7 +1,8 @@
1
1
  """NumPyro-specific conversion code."""
2
2
 
3
+ from collections import defaultdict
3
4
  import logging
4
- from typing import Callable, Optional
5
+ from typing import Any, Callable, Optional, Dict, List, Tuple
5
6
 
6
7
  import numpy as np
7
8
 
@@ -13,6 +14,70 @@ from .inference_data import InferenceData
13
14
  _log = logging.getLogger(__name__)
14
15
 
15
16
 
17
+ def _add_dims(dims_a: Dict[str, List[str]], dims_b: Dict[str, List[str]]) -> Dict[str, List[str]]:
18
+ merged = defaultdict(list)
19
+
20
+ for k, v in dims_a.items():
21
+ merged[k].extend(v)
22
+
23
+ for k, v in dims_b.items():
24
+ merged[k].extend(v)
25
+
26
+ # Convert back to a regular dict
27
+ return dict(merged)
28
+
29
+
30
+ def infer_dims(
31
+ model: Callable,
32
+ model_args: Optional[Tuple[Any, ...]] = None,
33
+ model_kwargs: Optional[Dict[str, Any]] = None,
34
+ ) -> Dict[str, List[str]]:
35
+
36
+ from numpyro import handlers, distributions as dist
37
+ from numpyro.ops.pytree import PytreeTrace
38
+ from numpyro.infer.initialization import init_to_sample
39
+ import jax
40
+
41
+ model_args = tuple() if model_args is None else model_args
42
+ model_kwargs = dict() if model_args is None else model_kwargs
43
+
44
+ def _get_dist_name(fn):
45
+ if isinstance(fn, (dist.Independent, dist.ExpandedDistribution, dist.MaskedDistribution)):
46
+ return _get_dist_name(fn.base_dist)
47
+ return type(fn).__name__
48
+
49
+ def get_trace():
50
+ # We use `init_to_sample` to get around ImproperUniform distribution,
51
+ # which does not have `sample` method.
52
+ subs_model = handlers.substitute(
53
+ handlers.seed(model, 0),
54
+ substitute_fn=init_to_sample,
55
+ )
56
+ trace = handlers.trace(subs_model).get_trace(*model_args, **model_kwargs)
57
+ # Work around an issue where jax.eval_shape does not work
58
+ # for distribution output (e.g. the function `lambda: dist.Normal(0, 1)`)
59
+ # Here we will remove `fn` and store its name in the trace.
60
+ for _, site in trace.items():
61
+ if site["type"] == "sample":
62
+ site["fn_name"] = _get_dist_name(site.pop("fn"))
63
+ elif site["type"] == "deterministic":
64
+ site["fn_name"] = "Deterministic"
65
+ return PytreeTrace(trace)
66
+
67
+ # We use eval_shape to avoid any array computation.
68
+ trace = jax.eval_shape(get_trace).trace
69
+
70
+ named_dims = {}
71
+
72
+ for name, site in trace.items():
73
+ batch_dims = [frame.name for frame in sorted(site["cond_indep_stack"], key=lambda x: x.dim)]
74
+ event_dims = list(site.get("infer", {}).get("event_dims", []))
75
+ if site["type"] in ["sample", "deterministic"] and (batch_dims or event_dims):
76
+ named_dims[name] = batch_dims + event_dims
77
+
78
+ return named_dims
79
+
80
+
16
81
  class NumPyroConverter:
17
82
  """Encapsulate NumPyro specific logic."""
18
83
 
@@ -36,6 +101,7 @@ class NumPyroConverter:
36
101
  coords=None,
37
102
  dims=None,
38
103
  pred_dims=None,
104
+ extra_event_dims=None,
39
105
  num_chains=1,
40
106
  ):
41
107
  """Convert NumPyro data into an InferenceData object.
@@ -58,9 +124,12 @@ class NumPyroConverter:
58
124
  coords : dict[str] -> list[str]
59
125
  Map of dimensions to coordinates
60
126
  dims : dict[str] -> list[str]
61
- Map variable names to their coordinates
127
+ Map variable names to their coordinates. Will be inferred if they are not provided.
62
128
  pred_dims: dict
63
129
  Dims for predictions data. Map variable names to their coordinates.
130
+ extra_event_dims: dict
131
+ Extra event dims for deterministic sites. Maps event dims that couldnt be inferred to
132
+ their coordinates.
64
133
  num_chains: int
65
134
  Number of chains used for sampling. Ignored if posterior is present.
66
135
  """
@@ -80,6 +149,7 @@ class NumPyroConverter:
80
149
  self.coords = coords
81
150
  self.dims = dims
82
151
  self.pred_dims = pred_dims
152
+ self.extra_event_dims = extra_event_dims
83
153
  self.numpyro = numpyro
84
154
 
85
155
  def arbitrary_element(dct):
@@ -107,6 +177,10 @@ class NumPyroConverter:
107
177
  # model arguments and keyword arguments
108
178
  self._args = self.posterior._args # pylint: disable=protected-access
109
179
  self._kwargs = self.posterior._kwargs # pylint: disable=protected-access
180
+ self.dims = self.dims if self.dims is not None else self.infer_dims()
181
+ self.pred_dims = (
182
+ self.pred_dims if self.pred_dims is not None else self.infer_pred_dims()
183
+ )
110
184
  else:
111
185
  self.nchains = num_chains
112
186
  get_from = None
@@ -167,7 +241,10 @@ class NumPyroConverter:
167
241
  continue
168
242
  name = rename_key.get(stat, stat)
169
243
  value = value.copy()
170
- data[name] = value
244
+ if stat == "potential_energy":
245
+ data[name] = -value
246
+ else:
247
+ data[name] = value
171
248
  if stat == "num_steps":
172
249
  data["tree_depth"] = np.log2(value).astype(int) + 1
173
250
  return dict_to_dataset(
@@ -325,6 +402,23 @@ class NumPyroConverter:
325
402
  }
326
403
  )
327
404
 
405
+ @requires("posterior")
406
+ @requires("model")
407
+ def infer_dims(self) -> Dict[str, List[str]]:
408
+ dims = infer_dims(self.model, self._args, self._kwargs)
409
+ if self.extra_event_dims:
410
+ dims = _add_dims(dims, self.extra_event_dims)
411
+ return dims
412
+
413
+ @requires("posterior")
414
+ @requires("model")
415
+ @requires("predictions")
416
+ def infer_pred_dims(self) -> Dict[str, List[str]]:
417
+ dims = infer_dims(self.model, self._args, self._kwargs)
418
+ if self.extra_event_dims:
419
+ dims = _add_dims(dims, self.extra_event_dims)
420
+ return dims
421
+
328
422
 
329
423
  def from_numpyro(
330
424
  posterior=None,
@@ -339,10 +433,25 @@ def from_numpyro(
339
433
  coords=None,
340
434
  dims=None,
341
435
  pred_dims=None,
436
+ extra_event_dims=None,
342
437
  num_chains=1,
343
438
  ):
344
439
  """Convert NumPyro data into an InferenceData object.
345
440
 
441
+ If no dims are provided, this will infer batch dim names from NumPyro model plates.
442
+ For event dim names, such as with the ZeroSumNormal, `infer={"event_dims":dim_names}`
443
+ can be provided in numpyro.sample, i.e.::
444
+
445
+ # equivalent to dims entry, {"gamma": ["groups"]}
446
+ gamma = numpyro.sample(
447
+ "gamma",
448
+ dist.ZeroSumNormal(1, event_shape=(n_groups,)),
449
+ infer={"event_dims":["groups"]}
450
+ )
451
+
452
+ There is also an additional `extra_event_dims` input to cover any edge cases, for instance
453
+ deterministic sites with event dims (which don't have an `infer` argument to provide metadata).
454
+
346
455
  For a usage example read the
347
456
  :ref:`Creating InferenceData section on from_numpyro <creating_InferenceData>`
348
457
 
@@ -364,9 +473,10 @@ def from_numpyro(
364
473
  coords : dict[str] -> list[str]
365
474
  Map of dimensions to coordinates
366
475
  dims : dict[str] -> list[str]
367
- Map variable names to their coordinates
476
+ Map variable names to their coordinates. Will be inferred if they are not provided.
368
477
  pred_dims: dict
369
- Dims for predictions data. Map variable names to their coordinates.
478
+ Dims for predictions data. Map variable names to their coordinates. Default behavior is to
479
+ infer dims if this is not provided.
370
480
  num_chains: int
371
481
  Number of chains used for sampling. Ignored if posterior is present.
372
482
  """
@@ -382,5 +492,6 @@ def from_numpyro(
382
492
  coords=coords,
383
493
  dims=dims,
384
494
  pred_dims=pred_dims,
495
+ extra_event_dims=extra_event_dims,
385
496
  num_chains=num_chains,
386
497
  ).to_inference_data()
@@ -277,7 +277,7 @@ def _extract_arviz_dict_from_inference_data(
277
277
 
278
278
 
279
279
  def _convert_arviz_dict_to_pyjags_dict(
280
- samples: tp.Mapping[str, np.ndarray]
280
+ samples: tp.Mapping[str, np.ndarray],
281
281
  ) -> tp.Mapping[str, np.ndarray]:
282
282
  """
283
283
  Convert and ArviZ dictionary to a PyJAGS dictionary.
@@ -4,7 +4,7 @@ from ..data import convert_to_dataset
4
4
  from ..labels import BaseLabeller
5
5
  from ..sel_utils import xarray_var_iter
6
6
  from ..rcparams import rcParams
7
- from ..utils import _var_names
7
+ from ..utils import _var_names, get_coords
8
8
  from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
9
9
 
10
10
 
@@ -14,6 +14,7 @@ def plot_autocorr(
14
14
  filter_vars=None,
15
15
  max_lag=None,
16
16
  combined=False,
17
+ coords=None,
17
18
  grid=None,
18
19
  figsize=None,
19
20
  textsize=None,
@@ -42,6 +43,8 @@ def plot_autocorr(
42
43
  interpret `var_names` as substrings of the real variables names. If "regex",
43
44
  interpret `var_names` as regular expressions on the real variables names. See
44
45
  :ref:`this section <common_filter_vars>` for usage examples.
46
+ coords: mapping, optional
47
+ Coordinates of var_names to be plotted. Passed to :meth:`xarray.Dataset.sel`
45
48
  max_lag : int, optional
46
49
  Maximum lag to calculate autocorrelation. By Default, the plot displays the
47
50
  first 100 lag or the total number of draws, whichever is smaller.
@@ -124,11 +127,18 @@ def plot_autocorr(
124
127
  if max_lag is None:
125
128
  max_lag = min(100, data["draw"].shape[0])
126
129
 
130
+ if coords is None:
131
+ coords = {}
132
+
127
133
  if labeller is None:
128
134
  labeller = BaseLabeller()
129
135
 
130
136
  plotters = filter_plotters_list(
131
- list(xarray_var_iter(data, var_names, combined, dim_order=["chain", "draw"])),
137
+ list(
138
+ xarray_var_iter(
139
+ get_coords(data, coords), var_names, combined, dim_order=["chain", "draw"]
140
+ )
141
+ ),
132
142
  "plot_autocorr",
133
143
  )
134
144
  rows, cols = default_grid(len(plotters), grid=grid)
@@ -21,9 +21,13 @@ def plot_hdi(ax, x_data, y_data, color, figsize, plot_kwargs, fill_kwargs, backe
21
21
  plot_kwargs["color"] = vectorized_to_hex(plot_kwargs.get("color", color))
22
22
  plot_kwargs.setdefault("alpha", 0)
23
23
 
24
- fill_kwargs = {} if fill_kwargs is None else fill_kwargs
25
- fill_kwargs["color"] = vectorized_to_hex(fill_kwargs.get("color", color))
26
- fill_kwargs.setdefault("alpha", 0.5)
24
+ fill_kwargs = {} if fill_kwargs is None else fill_kwargs.copy()
25
+ # Convert matplotlib color to bokeh fill_color if needed
26
+ if "color" in fill_kwargs and "fill_color" not in fill_kwargs:
27
+ fill_kwargs["fill_color"] = vectorized_to_hex(fill_kwargs.pop("color"))
28
+ else:
29
+ fill_kwargs["fill_color"] = vectorized_to_hex(fill_kwargs.get("fill_color", color))
30
+ fill_kwargs.setdefault("fill_alpha", fill_kwargs.pop("alpha", 0.5))
27
31
 
28
32
  figsize, *_ = _scale_fig_size(figsize, None)
29
33
 
@@ -38,9 +42,6 @@ def plot_hdi(ax, x_data, y_data, color, figsize, plot_kwargs, fill_kwargs, backe
38
42
  plot_kwargs.setdefault("line_color", plot_kwargs.pop("color"))
39
43
  plot_kwargs.setdefault("line_alpha", plot_kwargs.pop("alpha", 0))
40
44
 
41
- fill_kwargs.setdefault("fill_color", fill_kwargs.pop("color"))
42
- fill_kwargs.setdefault("fill_alpha", fill_kwargs.pop("alpha", 0))
43
-
44
45
  ax.patch(
45
46
  np.concatenate((x_data, x_data[::-1])),
46
47
  np.concatenate((y_data[:, 0], y_data[:, 1][::-1])),