arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/plots/loopitplot.py DELETED
@@ -1,224 +0,0 @@
1
- """Plot LOO-PIT predictive checks of inference data."""
2
-
3
- import numpy as np
4
- from scipy import stats
5
-
6
- from ..labels import BaseLabeller
7
- from ..rcparams import rcParams
8
- from ..stats import loo_pit as _loo_pit
9
- from ..stats.density_utils import kde
10
- from .plot_utils import get_plotting_function
11
-
12
-
13
def plot_loo_pit(
    idata=None,
    y=None,
    y_hat=None,
    log_weights=None,
    ecdf=False,
    ecdf_fill=True,
    n_unif=100,
    use_hdi=False,
    hdi_prob=None,
    figsize=None,
    textsize=None,
    labeller=None,
    color="C0",
    legend=True,
    ax=None,
    plot_kwargs=None,
    plot_unif_kwargs=None,
    hdi_kwargs=None,
    fill_kwargs=None,
    backend=None,
    backend_kwargs=None,
    show=None,
):
    """Plot Leave-One-Out (LOO) probability integral transformation (PIT) predictive checks.

    Parameters
    ----------
    idata : InferenceData
        :class:`arviz.InferenceData` object.
    y : array, DataArray or str
        Observed data. If str, ``idata`` must be present and contain the observed data group
    y_hat : array, DataArray or str
        Posterior predictive samples for ``y``. It must have the same shape as y plus an
        extra dimension at the end of size n_samples (chains and draws stacked). If str or
        None, ``idata`` must contain the posterior predictive group. If None, ``y_hat`` is taken
        equal to y, thus, y must be str too.
    log_weights : array or DataArray, optional
        Smoothed log_weights. It must have the same shape as ``y_hat``
    ecdf : bool, optional
        Plot the difference between the LOO-PIT Empirical Cumulative Distribution Function
        (ECDF) and the uniform CDF instead of LOO-PIT kde.
        In this case, instead of overlaying uniform distributions, the beta ``hdi_prob``
        around the theoretical uniform CDF is shown. This approximation only holds
        for a large number of LOO-PIT values and ECDF values not very close to 0 nor 1.
        For more information, see `Vehtari et al. (2021)`,
        `Appendix G <https://avehtari.github.io/rhat_ess/rhat_ess.html>`_.
    ecdf_fill : bool, optional
        Use :meth:`matplotlib.axes.Axes.fill_between` to mark the area
        inside the credible interval. Otherwise, plot the
        border lines.
    n_unif : int, optional
        Number of datasets to simulate and overlay from the uniform distribution.
    use_hdi : bool, optional
        Compute expected hdi values instead of overlaying the sampled uniform distributions.
        Incompatible with ``ecdf=True``.
    hdi_prob : float, optional
        Probability for the highest density interval. Works with ``use_hdi=True`` or ``ecdf=True``.
        Must be in the interval (0, 1]. Defaults to ``rcParams["stats.ci_prob"]``.
    figsize : (float, float), optional
        Figure size. If None it will be defined automatically.
    textsize : int, optional
        Text size for labels. If None it will be autoscaled based on ``figsize``.
    labeller : Labeller, optional
        Class providing the method ``make_pp_label`` to generate the labels in the plot titles.
        Read the :ref:`label_guide` for more details and usage examples.
    color : str or array_like, optional
        Color of the LOO-PIT estimated pdf plot. If ``plot_unif_kwargs`` has no "color" key,
        a slightly lighter color than this argument will be used for the uniform kde lines.
        This will ensure that LOO-PIT kde and uniform kde have different default colors.
    legend : bool, optional
        Show the legend of the figure.
    ax : axes, optional
        Matplotlib axes or bokeh figures.
    plot_kwargs : dict, optional
        Additional keywords passed to :meth:`matplotlib.axes.Axes.plot`
        for LOO-PIT line (kde or ECDF)
    plot_unif_kwargs : dict, optional
        Additional keywords passed to :meth:`matplotlib.axes.Axes.plot` for
        overlaid uniform distributions or for beta credible interval
        lines if ``ecdf=True``
    hdi_kwargs : dict, optional
        Additional keywords passed to :meth:`matplotlib.axes.Axes.axhspan`
    fill_kwargs : dict, optional
        Additional kwargs passed to :meth:`matplotlib.axes.Axes.fill_between`
    backend : str, optional
        Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib".
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or
        :func:`bokeh.plotting.figure`. For additional documentation
        check the plotting method of the backend.
    show : bool, optional
        Call backend show function.

    Returns
    -------
    axes : matplotlib_axes or bokeh_figures

    Raises
    ------
    ValueError
        If both ``ecdf`` and ``use_hdi`` are True, or if ``hdi_prob`` is outside (0, 1].

    See Also
    --------
    plot_bpv : Plot Bayesian p-value for observed data and Posterior/Prior predictive.
    loo_pit : Compute leave one out (PSIS-LOO) probability integral transform (PIT) values.

    References
    ----------
    * Gabry et al. (2017) see https://arxiv.org/abs/1709.01449
    * https://mc-stan.org/bayesplot/reference/PPC-loo.html
    * Gelman et al. BDA (2014) Section 6.3

    Examples
    --------
    Plot LOO-PIT predictive checks overlaying the KDE of the LOO-PIT values to several
    realizations of uniform variable sampling with the same number of observations.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> idata = az.load_arviz_data("radon")
        >>> az.plot_loo_pit(idata=idata, y="y")

    Fill the area containing the 94% highest density interval of the difference between uniform
    variables empirical CDF and the real uniform CDF. A LOO-PIT ECDF clearly outside of these
    theoretical boundaries indicates that the observations and the posterior predictive
    samples do not follow the same distribution.

    .. plot::
        :context: close-figs

        >>> az.plot_loo_pit(idata=idata, y="y", ecdf=True)

    """
    if ecdf and use_hdi:
        raise ValueError("use_hdi is incompatible with ecdf plot")

    # Validate hdi_prob before the (potentially expensive) LOO-PIT computation so that
    # invalid arguments fail fast.
    if hdi_prob is None:
        hdi_prob = rcParams["stats.ci_prob"]
    elif not 1 >= hdi_prob > 0:
        raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

    if labeller is None:
        labeller = BaseLabeller()

    loo_pit = _loo_pit(idata=idata, y=y, y_hat=y_hat, log_weights=log_weights)
    # ndarray.flatten() always returns a copy, so the in-place sort below never
    # mutates the caller's (or idata's) data.
    loo_pit = loo_pit.flatten() if isinstance(loo_pit, np.ndarray) else loo_pit.values.flatten()

    # Every plot-specific quantity defaults to None; only the ones relevant to the
    # selected plot kind are filled in and forwarded to the backend.
    loo_pit_ecdf = None
    unif_ecdf = None
    p975 = None
    p025 = None
    loo_pit_kde = None
    hdi_odds = None
    unif = None
    x_vals = None

    if ecdf:
        loo_pit.sort()
        n_data_points = loo_pit.size
        loo_pit_ecdf = np.arange(n_data_points) / n_data_points
        # ideal unnormalized ECDF of uniform distribution with n_data_points points
        # it is used indistinctively as x or p(u<x) because for u~U(0,1) they are equal
        unif_ecdf = np.arange(n_data_points + 1)
        # Upper/lower limits of the central ``hdi_prob`` beta band around the uniform ECDF.
        # The names are historical: they match the 97.5%/2.5% quantiles only for hdi_prob=0.95.
        p975 = stats.beta.ppf(0.5 + hdi_prob / 2, unif_ecdf + 1, n_data_points - unif_ecdf + 1)
        p025 = stats.beta.ppf(0.5 - hdi_prob / 2, unif_ecdf + 1, n_data_points - unif_ecdf + 1)
        unif_ecdf = unif_ecdf / n_data_points
    else:
        x_vals, loo_pit_kde = kde(loo_pit)

        # Reference samples from U(0, 1), one row per overlaid uniform dataset.
        unif = np.random.uniform(size=(n_unif, loo_pit.size))
        if use_hdi:
            n_obs = loo_pit.size
            hdi_ = stats.beta(n_obs / 2, n_obs / 2).ppf((1 - hdi_prob) / 2)
            hdi_odds = (hdi_ / (1 - hdi_), (1 - hdi_) / hdi_)

    loo_pit_kwargs = dict(
        ax=ax,
        figsize=figsize,
        ecdf=ecdf,
        loo_pit=loo_pit,
        loo_pit_ecdf=loo_pit_ecdf,
        unif_ecdf=unif_ecdf,
        p975=p975,
        p025=p025,
        fill_kwargs=fill_kwargs,
        ecdf_fill=ecdf_fill,
        use_hdi=use_hdi,
        x_vals=x_vals,
        hdi_kwargs=hdi_kwargs,
        hdi_odds=hdi_odds,
        n_unif=n_unif,
        unif=unif,
        plot_unif_kwargs=plot_unif_kwargs,
        loo_pit_kde=loo_pit_kde,
        textsize=textsize,
        labeller=labeller,
        color=color,
        legend=legend,
        y_hat=y_hat,
        y=y,
        hdi_prob=hdi_prob,
        plot_kwargs=plot_kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # TODO: Add backend kwargs
    plot = get_plotting_function("plot_loo_pit", "loopitplot", backend)
    axes = plot(**loo_pit_kwargs)

    return axes
arviz/plots/mcseplot.py DELETED
@@ -1,194 +0,0 @@
1
- """Plot quantile MC standard error."""
2
-
3
- import numpy as np
4
- import xarray as xr
5
-
6
- from ..data import convert_to_dataset
7
- from ..labels import BaseLabeller
8
- from ..sel_utils import xarray_var_iter
9
- from ..stats import mcse
10
- from ..rcparams import rcParams
11
- from ..utils import _var_names, get_coords
12
- from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
13
-
14
-
15
def plot_mcse(
    idata,
    var_names=None,
    filter_vars=None,
    coords=None,
    errorbar=False,
    grid=None,
    figsize=None,
    textsize=None,
    extra_methods=False,
    rug=False,
    rug_kind="diverging",
    n_points=20,
    labeller=None,
    ax=None,
    rug_kwargs=None,
    extra_kwargs=None,
    text_kwargs=None,
    backend=None,
    backend_kwargs=None,
    show=None,
    **kwargs
):
    """Plot quantile or local Monte Carlo Standard Error.

    Parameters
    ----------
    idata : obj
        Any object that can be converted to an :class:`arviz.InferenceData` object
        Refer to documentation of :func:`arviz.convert_to_dataset` for details
    var_names : list of variable names, optional
        Variables to be plotted. Prefix the variables by ``~`` when you want to exclude
        them from the plot.
    filter_vars : {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        `pandas.filter`.
    coords : dict, optional
        Coordinates of var_names to be plotted. Passed to :meth:`xarray.Dataset.sel`.
        ``chain`` and ``draw`` are not allowed as keys.
    errorbar : bool, optional
        Plot quantile value +/- mcse instead of plotting mcse.
    grid : tuple
        Number of rows and columns. Defaults to None, the rows and columns are
        automatically inferred.
    figsize : (float, float), optional
        Figure size. If None it will be defined automatically.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If None it will be autoscaled based
        on figsize.
    extra_methods : bool, optional
        Plot mean and sd MCSE as horizontal lines. Only taken into account when
        ``errorbar=False``.
    rug : bool
        Plot rug plot of values diverging or that reached the max tree depth.
    rug_kind : str
        Name of the variable in sample stats to use as rug mask. Must be a boolean variable.
    n_points : int
        Number of equally spaced probabilities, from ``1/n_points`` to ``1 - 1/n_points``,
        at which the quantile MCSE is computed and plotted.
    labeller : Labeller, optional
        Class providing the method `make_label_vert` to generate the labels in the plot titles.
        Read the :ref:`label_guide` for more details and usage examples.
    ax : 2D array-like of matplotlib_axes or bokeh_figures, optional
        A 2D array of locations into which to plot the densities. If not supplied, Arviz will create
        its own array of plot areas (and return it).
    rug_kwargs : dict
        kwargs passed to rug plot in
        :meth:`mpl:matplotlib.axes.Axes.plot` or :class:`bokeh:bokeh.models.glyphs.Scatter`.
    extra_kwargs : dict, optional
        kwargs passed as extra method lines in
        :meth:`mpl:matplotlib.axes.Axes.axhline` or :class:`bokeh:bokeh.models.Span`
    text_kwargs : dict, optional
        kwargs passed to :meth:`mpl:matplotlib.axes.Axes.annotate` for extra methods lines labels.
        It accepts the additional key ``x`` to set ``xy=(text_kwargs["x"], mcse)``.
        text_kwargs are ignored for the bokeh plotting backend.
    backend : str, optional
        Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib".
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being passed to
        :func:`matplotlib.pyplot.subplots` or :func:`bokeh.plotting.figure`.
    show: bool, optional
        Call backend show function.
    **kwargs
        Extra keyword arguments forwarded (as ``kwargs``) to the backend plotting function.

    Returns
    -------
    axes : matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If ``coords`` contains ``chain`` or ``draw``.

    See Also
    --------
    :func:`arviz.mcse`: Calculate Monte Carlo Standard Error statistic.

    References
    ----------
    .. [1] Vehtari et al. (2021). Rank-normalization, folding, and
        localization: An improved Rhat for assessing convergence of
        MCMC. Bayesian analysis, 16(2):667-718.

    Examples
    --------
    Plot quantile Monte Carlo Standard Error.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> idata = az.load_arviz_data("centered_eight")
        >>> coords = {"school": ["Deerfield", "Lawrenceville"]}
        >>> az.plot_mcse(
        ...     idata, var_names=["mu", "theta"], coords=coords
        ... )

    """
    if coords is None:
        coords = {}
    # MCSE is computed across chains and draws, so they cannot be subset away.
    if "chain" in coords or "draw" in coords:
        raise ValueError("chain and draw are invalid coordinates for this kind of plot")

    if labeller is None:
        labeller = BaseLabeller()

    data = get_coords(convert_to_dataset(idata, group="posterior"), coords)
    var_names = _var_names(var_names, data, filter_vars)

    # Quantile MCSE evaluated at n_points equally spaced probabilities, stacked along
    # a new "mcse_dim" dimension so each variable yields one curve.
    probs = np.linspace(1 / n_points, 1 - 1 / n_points, n_points)
    mcse_dataset = xr.concat(
        [mcse(data, var_names=var_names, method="quantile", prob=p) for p in probs], dim="mcse_dim"
    )

    plotters = filter_plotters_list(
        list(xarray_var_iter(mcse_dataset, var_names=var_names, skip_dims={"mcse_dim"})),
        "plot_mcse",
    )
    length_plotters = len(plotters)
    rows, cols = default_grid(length_plotters, grid=grid)

    # Mean and sd MCSE are only needed for the optional horizontal reference lines.
    if extra_methods:
        mean_mcse = mcse(data, var_names=var_names, method="mean")
        sd_mcse = mcse(data, var_names=var_names, method="sd")
    else:
        mean_mcse = None
        sd_mcse = None

    mcse_kwargs = dict(
        ax=ax,
        plotters=plotters,
        length_plotters=length_plotters,
        rows=rows,
        cols=cols,
        figsize=figsize,
        errorbar=errorbar,
        rug=rug,
        data=data,
        probs=probs,
        kwargs=kwargs,
        extra_methods=extra_methods,
        mean_mcse=mean_mcse,
        sd_mcse=sd_mcse,
        textsize=textsize,
        labeller=labeller,
        text_kwargs=text_kwargs,
        rug_kwargs=rug_kwargs,
        extra_kwargs=extra_kwargs,
        idata=idata,
        rug_kind=rug_kind,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # TODO: Add backend kwargs
    plot = get_plotting_function("plot_mcse", "mcseplot", backend)
    axes = plot(**mcse_kwargs)
    return axes
arviz/plots/pairplot.py DELETED
@@ -1,281 +0,0 @@
1
- """Plot a scatter, kde and/or hexbin of sampled parameters."""
2
-
3
- import warnings
4
- from typing import List, Optional, Union
5
-
6
- import numpy as np
7
-
8
- from ..data import convert_to_dataset
9
- from ..labels import BaseLabeller
10
- from ..sel_utils import xarray_to_ndarray, xarray_var_iter
11
- from .plot_utils import get_plotting_function
12
- from ..rcparams import rcParams
13
- from ..utils import _var_names, get_coords
14
-
15
-
16
def plot_pair(
    data,
    group="posterior",
    var_names: Optional[List[str]] = None,
    filter_vars: Optional[str] = None,
    combine_dims=None,
    coords=None,
    marginals=False,
    figsize=None,
    textsize=None,
    kind: Union[str, List[str]] = "scatter",
    gridsize="auto",
    divergences=False,
    colorbar=False,
    labeller=None,
    ax=None,
    divergences_kwargs=None,
    scatter_kwargs=None,
    kde_kwargs=None,
    hexbin_kwargs=None,
    backend=None,
    backend_kwargs=None,
    marginal_kwargs=None,
    point_estimate=None,
    point_estimate_kwargs=None,
    point_estimate_marker_kwargs=None,
    reference_values=None,
    reference_values_kwargs=None,
    show=None,
):
    """
    Plot a scatter, kde and/or hexbin matrix with (optional) marginals on the diagonal.

    Parameters
    ----------
    data: obj
        Any object that can be converted to an :class:`arviz.InferenceData` object.
        Refer to documentation of :func:`arviz.convert_to_dataset` for details
    group: str, optional
        Specifies which InferenceData group should be plotted. Defaults to 'posterior'.
    var_names: list of variable names, optional
        Variables to be plotted, if None all variable are plotted. Prefix the
        variables by ``~`` when you want to exclude them from the plot.
    filter_vars: {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        ``pandas.filter``.
    combine_dims : set_like of str, optional
        List of dimensions to reduce. Defaults to reducing only the "chain" and "draw" dimensions.
        See the :ref:`this section <common_combine_dims>` for usage examples.
    coords: mapping, optional
        Coordinates of var_names to be plotted. Passed to :meth:`xarray.Dataset.sel`.
    marginals: bool, optional
        If True pairplot will include marginal distributions for every variable
    figsize: figure size tuple
        If None, size is (8 + numvars, 8 + numvars)
    textsize: int
        Text size for labels. If None it will be autoscaled based on ``figsize``.
    kind : str or List[str]
        Type of plot to display (scatter, kde and/or hexbin)
    gridsize: int or (int, int), optional
        Only works for ``kind=hexbin``. The number of hexagons in the x-direction.
        The corresponding number of hexagons in the y-direction is chosen
        such that the hexagons are approximately regular. Alternatively, gridsize
        can be a tuple with two elements specifying the number of hexagons
        in the x-direction and the y-direction. If "auto" (default), it is set to
        ``int(n_draws ** 0.35)``.
    divergences: bool
        If True divergences will be plotted in a different color, only if group is either 'prior'
        or 'posterior'.
    colorbar: bool
        If True a colorbar will be included as part of the plot (Defaults to False).
        Only works when ``kind=hexbin``
    labeller : labeller instance, optional
        Class providing the method ``make_label_vert`` to generate the labels in the plot.
        Read the :ref:`label_guide` for more details and usage examples.
    ax: axes, optional
        Matplotlib axes or bokeh figures.
    divergences_kwargs: dicts, optional
        Additional keywords passed to :meth:`matplotlib.axes.Axes.scatter` for divergences
    scatter_kwargs: dict, optional
        Additional keywords passed to :meth:`matplotlib.axes.Axes.scatter` when using scatter kind
    kde_kwargs: dict, optional
        Additional keywords passed to :func:`arviz.plot_kde` when using kde kind
    hexbin_kwargs: dict, optional
        Additional keywords passed to :meth:`matplotlib.axes.Axes.hexbin` when
        using hexbin kind
    backend: str, optional
        Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib".
    backend_kwargs: dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or
        :func:`bokeh.plotting.figure`.
    marginal_kwargs: dict, optional
        Additional keywords passed to :func:`arviz.plot_dist`, modifying the
        marginal distributions plotted in the diagonal.
    point_estimate: str, optional
        Select point estimate from 'mean', 'mode' or 'median'. The point estimate will be
        plotted using a scatter marker and vertical/horizontal lines.
    point_estimate_kwargs: dict, optional
        Additional keywords passed to :meth:`matplotlib.axes.Axes.axvline`,
        :meth:`matplotlib.axes.Axes.axhline` (matplotlib) or
        :class:`bokeh:bokeh.models.Span` (bokeh)
    point_estimate_marker_kwargs: dict, optional
        Additional keywords passed to :meth:`matplotlib.axes.Axes.scatter`
        or :meth:`bokeh:bokeh.plotting.Figure.square` in point
        estimate plot. Not available in bokeh
    reference_values: dict, optional
        Reference values for the plotted variables. The Reference values will be plotted
        using a scatter marker
    reference_values_kwargs: dict, optional
        Additional keywords passed to :meth:`matplotlib.axes.Axes.plot` or
        :meth:`bokeh:bokeh.plotting.Figure.circle` in reference values plot
    show: bool, optional
        Call backend show function.

    Returns
    -------
    axes: matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If ``kind`` contains an unsupported plot type, or fewer than two variables
        are selected for plotting.

    Examples
    --------
    KDE Pair Plot

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> centered = az.load_arviz_data('centered_eight')
        >>> coords = {'school': ['Choate', 'Deerfield']}
        >>> az.plot_pair(centered,
        ...             var_names=['theta', 'mu', 'tau'],
        ...             kind='kde',
        ...             coords=coords,
        ...             divergences=True,
        ...             textsize=18)

    Hexbin pair plot

    .. plot::
        :context: close-figs

        >>> az.plot_pair(centered,
        ...             var_names=['theta', 'mu'],
        ...             coords=coords,
        ...             textsize=18,
        ...             kind='hexbin')

    Pair plot showing divergences and select variables with regular expressions

    .. plot::
        :context: close-figs

        >>> az.plot_pair(centered,
        ...             var_names=['^t', 'mu'],
        ...             filter_vars="regex",
        ...             coords=coords,
        ...             divergences=True,
        ...             textsize=18)
    """
    valid_kinds = ["scatter", "kde", "hexbin"]
    kind_boolean: Union[bool, List[bool]]
    if isinstance(kind, str):
        kind_boolean = kind in valid_kinds
    else:
        kind_boolean = [single_kind in valid_kinds for single_kind in kind]
    if not np.all(kind_boolean):
        raise ValueError(f"Plot type {kind} not recognized. Plot type must be in {valid_kinds}")

    if coords is None:
        coords = {}

    if labeller is None:
        labeller = BaseLabeller()

    # Get posterior draws and combine chains
    dataset = convert_to_dataset(data, group=group)
    var_names = _var_names(var_names, dataset, filter_vars)
    plotters = list(
        xarray_var_iter(
            get_coords(dataset, coords), var_names=var_names, skip_dims=combine_dims, combined=True
        )
    )
    # One entry per flattened (variable, coordinate) combination: its name, the index
    # used to look up reference values, and its axis label.
    flat_var_names = []
    flat_ref_slices = []
    flat_var_labels = []
    for var_name, sel, isel, _ in plotters:
        dims = [dim for dim in dataset[var_name].dims if dim not in ["chain", "draw"]]
        flat_var_names.append(var_name)
        flat_ref_slices.append(tuple(isel[dim] if dim in isel else slice(None) for dim in dims))
        flat_var_labels.append(labeller.make_label_vert(var_name, sel, isel))

    # Fail fast, before any divergence processing, if there is nothing to pair up.
    numvars = len(flat_var_names)
    if numvars < 2:
        raise ValueError("Number of variables to be plotted must be 2 or greater.")

    # Divergence information only exists for the posterior and prior groups; for any
    # other group divergences are silently disabled.
    divergent_group = {"posterior": "sample_stats", "prior": "sample_stats_prior"}.get(group)
    if divergent_group is None:
        divergences = False

    diverging_mask = None

    # Get diverging draws and combine chains
    if divergences:
        if hasattr(data, divergent_group) and hasattr(getattr(data, divergent_group), "diverging"):
            divergent_data = convert_to_dataset(data, group=divergent_group)
            _, diverging_mask = xarray_to_ndarray(
                divergent_data, var_names=("diverging",), combined=True
            )
            diverging_mask = np.squeeze(diverging_mask)
        else:
            divergences = False
            warnings.warn(
                "Divergences data not found, plotting without divergences. "
                "Make sure the sample method provides divergences data and "
                "that it is present in the `diverging` field of `sample_stats` "
                "or `sample_stats_prior` or set divergences=False",
                UserWarning,
            )

    if gridsize == "auto":
        gridsize = int(dataset.sizes["draw"] ** 0.35)

    pairplot_kwargs = dict(
        ax=ax,
        plotters=plotters,
        numvars=numvars,
        figsize=figsize,
        textsize=textsize,
        kind=kind,
        scatter_kwargs=scatter_kwargs,
        kde_kwargs=kde_kwargs,
        hexbin_kwargs=hexbin_kwargs,
        gridsize=gridsize,
        colorbar=colorbar,
        divergences=divergences,
        diverging_mask=diverging_mask,
        divergences_kwargs=divergences_kwargs,
        flat_var_names=flat_var_names,
        flat_ref_slices=flat_ref_slices,
        flat_var_labels=flat_var_labels,
        backend_kwargs=backend_kwargs,
        marginal_kwargs=marginal_kwargs,
        show=show,
        marginals=marginals,
        point_estimate=point_estimate,
        point_estimate_kwargs=point_estimate_kwargs,
        point_estimate_marker_kwargs=point_estimate_marker_kwargs,
        reference_values=reference_values,
        reference_values_kwargs=reference_values_kwargs,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # TODO: Add backend kwargs
    plot = get_plotting_function("plot_pair", "pairplot", backend)
    axes = plot(**pairplot_kwargs)
    return axes