arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -357
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.1.dist-info/METADATA +0 -263
- arviz-0.23.1.dist-info/RECORD +0 -183
- arviz-0.23.1.dist-info/top_level.txt +0 -1
arviz/plots/loopitplot.py
DELETED
|
@@ -1,224 +0,0 @@
|
|
|
1
|
-
"""Plot LOO-PIT predictive checks of inference data."""
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
from scipy import stats
|
|
5
|
-
|
|
6
|
-
from ..labels import BaseLabeller
|
|
7
|
-
from ..rcparams import rcParams
|
|
8
|
-
from ..stats import loo_pit as _loo_pit
|
|
9
|
-
from ..stats.density_utils import kde
|
|
10
|
-
from .plot_utils import get_plotting_function
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
def plot_loo_pit(
    idata=None,
    y=None,
    y_hat=None,
    log_weights=None,
    ecdf=False,
    ecdf_fill=True,
    n_unif=100,
    use_hdi=False,
    hdi_prob=None,
    figsize=None,
    textsize=None,
    labeller=None,
    color="C0",
    legend=True,
    ax=None,
    plot_kwargs=None,
    plot_unif_kwargs=None,
    hdi_kwargs=None,
    fill_kwargs=None,
    backend=None,
    backend_kwargs=None,
    show=None,
):
    """Plot Leave-One-Out (LOO) probability integral transformation (PIT) predictive checks.

    Parameters
    ----------
    idata : InferenceData
        :class:`arviz.InferenceData` object.
    y : array, DataArray or str
        Observed data. If str, ``idata`` must be present and contain the observed data group
    y_hat : array, DataArray or str
        Posterior predictive samples for ``y``. It must have the same shape as y plus an
        extra dimension at the end of size n_samples (chains and draws stacked). If str or
        None, ``idata`` must contain the posterior predictive group. If None, ``y_hat`` is taken
        equal to y, thus, y must be str too.
    log_weights : array or DataArray
        Smoothed log_weights. It must have the same shape as ``y_hat``
    ecdf : bool, optional
        Plot the difference between the LOO-PIT Empirical Cumulative Distribution Function
        (ECDF) and the uniform CDF instead of LOO-PIT kde.
        In this case, instead of overlaying uniform distributions, the beta ``hdi_prob``
        around the theoretical uniform CDF is shown. This approximation only holds
        for large S and ECDF values not very close to 0 nor 1. For more information, see
        `Vehtari et al. (2021)`, `Appendix G <https://avehtari.github.io/rhat_ess/rhat_ess.html>`_.
    ecdf_fill : bool, optional
        Use :meth:`matplotlib.axes.Axes.fill_between` to mark the area
        inside the credible interval. Otherwise, plot the
        border lines.
    n_unif : int, optional
        Number of datasets to simulate and overlay from the uniform distribution.
    use_hdi : bool, optional
        Compute expected hdi values instead of overlaying the sampled uniform distributions.
        Cannot be combined with ``ecdf=True``.
    hdi_prob : float, optional
        Probability for the highest density interval. Works with ``use_hdi=True`` or ``ecdf=True``.
        Must lie in the interval (0, 1]. Defaults to ``rcParams["stats.ci_prob"]``.
    figsize : (float, float), optional
        Figure size. If None, it is left to the plotting backend to choose.
    textsize : int, optional
        Text size for labels. If None it will be autoscaled based on ``figsize``.
    labeller : Labeller, optional
        Class providing the method ``make_pp_label`` to generate the labels in the plot titles.
        Read the :ref:`label_guide` for more details and usage examples.
    color : str or array_like, optional
        Color of the LOO-PIT estimated pdf plot. If ``plot_unif_kwargs`` has no "color" key,
        a slightly lighter color than this argument will be used for the uniform kde lines.
        This will ensure that LOO-PIT kde and uniform kde have different default colors.
    legend : bool, optional
        Show the legend of the figure.
    ax : axes, optional
        Matplotlib axes or bokeh figures.
    plot_kwargs : dict, optional
        Additional keywords passed to :meth:`matplotlib.axes.Axes.plot`
        for LOO-PIT line (kde or ECDF)
    plot_unif_kwargs : dict, optional
        Additional keywords passed to :meth:`matplotlib.axes.Axes.plot` for
        overlaid uniform distributions or for beta credible interval
        lines if ``ecdf=True``
    hdi_kwargs : dict, optional
        Additional keywords passed to :meth:`matplotlib.axes.Axes.axhspan`
    fill_kwargs : dict, optional
        Additional kwargs passed to :meth:`matplotlib.axes.Axes.fill_between`
    backend : str, optional
        Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib".
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or
        :func:`bokeh.plotting.figure`. For additional documentation
        check the plotting method of the backend.
    show : bool, optional
        Call backend show function.

    Returns
    -------
    axes : matplotlib_axes or bokeh_figures

    Raises
    ------
    ValueError
        If ``ecdf`` and ``use_hdi`` are both True, or if ``hdi_prob`` is not in (0, 1].

    See Also
    --------
    plot_bpv : Plot Bayesian p-value for observed data and Posterior/Prior predictive.
    loo_pit : Compute leave one out (PSIS-LOO) probability integral transform (PIT) values.

    References
    ----------
    * Gabry et al. (2017) see https://arxiv.org/abs/1709.01449
    * https://mc-stan.org/bayesplot/reference/PPC-loo.html
    * Gelman et al. BDA (2014) Section 6.3

    Examples
    --------
    Plot LOO-PIT predictive checks overlaying the KDE of the LOO-PIT values to several
    realizations of uniform variable sampling with the same number of observations.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> idata = az.load_arviz_data("radon")
        >>> az.plot_loo_pit(idata=idata, y="y")

    Fill the area containing the 94% highest density interval of the difference between uniform
    variables empirical CDF and the real uniform CDF. A LOO-PIT ECDF clearly outside of these
    theoretical boundaries indicates that the observations and the posterior predictive
    samples do not follow the same distribution.

    .. plot::
        :context: close-figs

        >>> az.plot_loo_pit(idata=idata, y="y", ecdf=True)

    """
    # Validate all arguments up front so bad input fails before the LOO-PIT computation runs.
    if ecdf and use_hdi:
        raise ValueError("use_hdi is incompatible with ecdf plot")

    if hdi_prob is None:
        hdi_prob = rcParams["stats.ci_prob"]
    elif not 1 >= hdi_prob > 0:
        raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

    if labeller is None:
        labeller = BaseLabeller()

    loo_pit = _loo_pit(idata=idata, y=y, y_hat=y_hat, log_weights=log_weights)
    # flatten() returns a copy, so the in-place sort below cannot alter the _loo_pit result.
    loo_pit = loo_pit.flatten() if isinstance(loo_pit, np.ndarray) else loo_pit.values.flatten()

    # Only the inputs relevant to the selected mode are filled in; the rest stay None.
    loo_pit_ecdf = None
    unif_ecdf = None
    p975 = None
    p025 = None
    loo_pit_kde = None
    hdi_odds = None
    unif = None
    x_vals = None

    if ecdf:
        loo_pit.sort()
        n_data_points = loo_pit.size
        loo_pit_ecdf = np.arange(n_data_points) / n_data_points
        # ideal unnormalized ECDF of uniform distribution with n_data_points points
        # it is used indistinctively as x or p(u<x) because for u~U(0,1) they are equal
        unif_ecdf = np.arange(n_data_points + 1)
        # Upper/lower quantiles (0.5 +/- hdi_prob / 2) of Beta(k + 1, n - k + 1), k = 0..n.
        p975 = stats.beta.ppf(0.5 + hdi_prob / 2, unif_ecdf + 1, n_data_points - unif_ecdf + 1)
        p025 = stats.beta.ppf(0.5 - hdi_prob / 2, unif_ecdf + 1, n_data_points - unif_ecdf + 1)
        unif_ecdf = unif_ecdf / n_data_points
    else:
        x_vals, loo_pit_kde = kde(loo_pit)

        # Reference draws from U(0, 1), one row per overlaid uniform dataset.
        unif = np.random.uniform(size=(n_unif, loo_pit.size))
        if use_hdi:
            n_obs = loo_pit.size
            # Lower (1 - hdi_prob) / 2 quantile of Beta(n/2, n/2), expressed as a pair of odds.
            hdi_ = stats.beta(n_obs / 2, n_obs / 2).ppf((1 - hdi_prob) / 2)
            hdi_odds = (hdi_ / (1 - hdi_), (1 - hdi_) / hdi_)

    loo_pit_kwargs = dict(
        ax=ax,
        figsize=figsize,
        ecdf=ecdf,
        loo_pit=loo_pit,
        loo_pit_ecdf=loo_pit_ecdf,
        unif_ecdf=unif_ecdf,
        p975=p975,
        p025=p025,
        fill_kwargs=fill_kwargs,
        ecdf_fill=ecdf_fill,
        use_hdi=use_hdi,
        x_vals=x_vals,
        hdi_kwargs=hdi_kwargs,
        hdi_odds=hdi_odds,
        n_unif=n_unif,
        unif=unif,
        plot_unif_kwargs=plot_unif_kwargs,
        loo_pit_kde=loo_pit_kde,
        textsize=textsize,
        labeller=labeller,
        color=color,
        legend=legend,
        y_hat=y_hat,
        y=y,
        hdi_prob=hdi_prob,
        plot_kwargs=plot_kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # Dispatch to the backend-specific implementation (e.g. matplotlib or bokeh).
    plot = get_plotting_function("plot_loo_pit", "loopitplot", backend)
    axes = plot(**loo_pit_kwargs)

    return axes
|
arviz/plots/mcseplot.py
DELETED
|
@@ -1,194 +0,0 @@
|
|
|
1
|
-
"""Plot quantile MC standard error."""
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
import xarray as xr
|
|
5
|
-
|
|
6
|
-
from ..data import convert_to_dataset
|
|
7
|
-
from ..labels import BaseLabeller
|
|
8
|
-
from ..sel_utils import xarray_var_iter
|
|
9
|
-
from ..stats import mcse
|
|
10
|
-
from ..rcparams import rcParams
|
|
11
|
-
from ..utils import _var_names, get_coords
|
|
12
|
-
from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
def plot_mcse(
    idata,
    var_names=None,
    filter_vars=None,
    coords=None,
    errorbar=False,
    grid=None,
    figsize=None,
    textsize=None,
    extra_methods=False,
    rug=False,
    rug_kind="diverging",
    n_points=20,
    labeller=None,
    ax=None,
    rug_kwargs=None,
    extra_kwargs=None,
    text_kwargs=None,
    backend=None,
    backend_kwargs=None,
    show=None,
    **kwargs
):
    """Plot quantile or local Monte Carlo Standard Error.

    Parameters
    ----------
    idata : obj
        Any object that can be converted to an :class:`arviz.InferenceData` object
        Refer to documentation of :func:`arviz.convert_to_dataset` for details
    var_names : list of variable names, optional
        Variables to be plotted. Prefix the variables by ``~`` when you want to exclude
        them from the plot.
    filter_vars : {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        `pandas.filter`.
    coords : dict, optional
        Coordinates of var_names to be plotted. Passed to :meth:`xarray.Dataset.sel`.
        ``chain`` and ``draw`` are not allowed as keys.
    errorbar : bool, optional
        Plot quantile value +/- mcse instead of plotting mcse.
    grid : tuple
        Number of rows and columns. Defaults to None, the rows and columns are
        automatically inferred.
    figsize : (float, float), optional
        Figure size. If None it will be defined automatically.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If None it will be autoscaled based
        on figsize.
    extra_methods : bool, optional
        Plot mean and sd MCSE as horizontal lines. Only taken into account when
        ``errorbar=False``.
    rug : bool
        Plot rug plot of values diverging or that reached the max tree depth.
    rug_kind : str
        Variable in sample stats to use as rug mask. Must be a boolean variable.
    n_points : int
        Number of equally spaced quantiles, from ``1 / n_points`` to ``1 - 1 / n_points``,
        at which the quantile MCSE is computed and plotted. Must be a positive integer.
    labeller : Labeller, optional
        Class providing the method `make_label_vert` to generate the labels in the plot titles.
        Read the :ref:`label_guide` for more details and usage examples.
    ax : 2D array-like of matplotlib_axes or bokeh_figures, optional
        A 2D array of locations into which to plot the MCSE. If not supplied, Arviz will create
        its own array of plot areas (and return it).
    rug_kwargs : dict
        kwargs passed to rug plot in
        :meth:`mpl:matplotlib.axes.Axes.plot` or :class:`bokeh:bokeh.models.glyphs.Scatter`.
    extra_kwargs : dict, optional
        kwargs passed as extra method lines in
        :meth:`mpl:matplotlib.axes.Axes.axhline` or :class:`bokeh:bokeh.models.Span`
    text_kwargs : dict, optional
        kwargs passed to :meth:`mpl:matplotlib.axes.Axes.annotate` for extra methods lines labels.
        It accepts the additional key ``x`` to set ``xy=(text_kwargs["x"], mcse)``.
        text_kwargs are ignored for the bokeh plotting backend.
    backend : str, optional
        Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib".
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being passed to
        :func:`matplotlib.pyplot.subplots` or :func:`bokeh.plotting.figure`.
    show: bool, optional
        Call backend show function.
    **kwargs
        Passed to the backend plotting function for the MCSE markers/lines
        (e.g. :meth:`mpl:matplotlib.axes.Axes.plot` in matplotlib).

    Returns
    -------
    axes : matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If ``n_points`` is smaller than 1, or if ``coords`` selects on ``chain`` or ``draw``.

    See Also
    --------
    :func:`arviz.mcse`: Calculate Monte Carlo Standard Error statistic.

    References
    ----------
    .. [1] Vehtari et al. (2021). Rank-normalization, folding, and
        localization: An improved Rhat for assessing convergence of
        MCMC. Bayesian analysis, 16(2):667-718.

    Examples
    --------
    Plot quantile Monte Carlo Standard Error.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> idata = az.load_arviz_data("centered_eight")
        >>> coords = {"school": ["Deerfield", "Lawrenceville"]}
        >>> az.plot_mcse(
        ...     idata, var_names=["mu", "theta"], coords=coords
        ... )

    """
    # n_points divides 1 below; reject values that would otherwise raise an obscure
    # ZeroDivisionError (0) or a numpy.linspace error (negative counts).
    if n_points < 1:
        raise ValueError("n_points must be a positive integer")

    mean_mcse = None
    sd_mcse = None

    if coords is None:
        coords = {}
    if "chain" in coords or "draw" in coords:
        raise ValueError("chain and draw are invalid coordinates for this kind of plot")

    if labeller is None:
        labeller = BaseLabeller()

    data = get_coords(convert_to_dataset(idata, group="posterior"), coords)
    var_names = _var_names(var_names, data, filter_vars)

    # Quantile MCSE evaluated on an evenly spaced probability grid, stacked along "mcse_dim".
    probs = np.linspace(1 / n_points, 1 - 1 / n_points, n_points)
    mcse_dataset = xr.concat(
        [mcse(data, var_names=var_names, method="quantile", prob=p) for p in probs], dim="mcse_dim"
    )

    plotters = filter_plotters_list(
        list(xarray_var_iter(mcse_dataset, var_names=var_names, skip_dims={"mcse_dim"})),
        "plot_mcse",
    )
    length_plotters = len(plotters)
    rows, cols = default_grid(length_plotters, grid=grid)

    if extra_methods:
        mean_mcse = mcse(data, var_names=var_names, method="mean")
        sd_mcse = mcse(data, var_names=var_names, method="sd")

    mcse_kwargs = dict(
        ax=ax,
        plotters=plotters,
        length_plotters=length_plotters,
        rows=rows,
        cols=cols,
        figsize=figsize,
        errorbar=errorbar,
        rug=rug,
        data=data,
        probs=probs,
        kwargs=kwargs,
        extra_methods=extra_methods,
        mean_mcse=mean_mcse,
        sd_mcse=sd_mcse,
        textsize=textsize,
        labeller=labeller,
        text_kwargs=text_kwargs,
        rug_kwargs=rug_kwargs,
        extra_kwargs=extra_kwargs,
        idata=idata,
        rug_kind=rug_kind,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # Dispatch to the backend-specific implementation (e.g. matplotlib or bokeh).
    plot = get_plotting_function("plot_mcse", "mcseplot", backend)
    ax = plot(**mcse_kwargs)
    return ax
|
arviz/plots/pairplot.py
DELETED
|
@@ -1,281 +0,0 @@
|
|
|
1
|
-
"""Plot a scatter, kde and/or hexbin of sampled parameters."""
|
|
2
|
-
|
|
3
|
-
import warnings
|
|
4
|
-
from typing import List, Optional, Union
|
|
5
|
-
|
|
6
|
-
import numpy as np
|
|
7
|
-
|
|
8
|
-
from ..data import convert_to_dataset
|
|
9
|
-
from ..labels import BaseLabeller
|
|
10
|
-
from ..sel_utils import xarray_to_ndarray, xarray_var_iter
|
|
11
|
-
from .plot_utils import get_plotting_function
|
|
12
|
-
from ..rcparams import rcParams
|
|
13
|
-
from ..utils import _var_names, get_coords
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def plot_pair(
|
|
17
|
-
data,
|
|
18
|
-
group="posterior",
|
|
19
|
-
var_names: Optional[List[str]] = None,
|
|
20
|
-
filter_vars: Optional[str] = None,
|
|
21
|
-
combine_dims=None,
|
|
22
|
-
coords=None,
|
|
23
|
-
marginals=False,
|
|
24
|
-
figsize=None,
|
|
25
|
-
textsize=None,
|
|
26
|
-
kind: Union[str, List[str]] = "scatter",
|
|
27
|
-
gridsize="auto",
|
|
28
|
-
divergences=False,
|
|
29
|
-
colorbar=False,
|
|
30
|
-
labeller=None,
|
|
31
|
-
ax=None,
|
|
32
|
-
divergences_kwargs=None,
|
|
33
|
-
scatter_kwargs=None,
|
|
34
|
-
kde_kwargs=None,
|
|
35
|
-
hexbin_kwargs=None,
|
|
36
|
-
backend=None,
|
|
37
|
-
backend_kwargs=None,
|
|
38
|
-
marginal_kwargs=None,
|
|
39
|
-
point_estimate=None,
|
|
40
|
-
point_estimate_kwargs=None,
|
|
41
|
-
point_estimate_marker_kwargs=None,
|
|
42
|
-
reference_values=None,
|
|
43
|
-
reference_values_kwargs=None,
|
|
44
|
-
show=None,
|
|
45
|
-
):
|
|
46
|
-
"""
|
|
47
|
-
Plot a scatter, kde and/or hexbin matrix with (optional) marginals on the diagonal.
|
|
48
|
-
|
|
49
|
-
Parameters
|
|
50
|
-
----------
|
|
51
|
-
data: obj
|
|
52
|
-
Any object that can be converted to an :class:`arviz.InferenceData` object.
|
|
53
|
-
Refer to documentation of :func:`arviz.convert_to_dataset` for details
|
|
54
|
-
group: str, optional
|
|
55
|
-
Specifies which InferenceData group should be plotted. Defaults to 'posterior'.
|
|
56
|
-
var_names: list of variable names, optional
|
|
57
|
-
Variables to be plotted, if None all variable are plotted. Prefix the
|
|
58
|
-
variables by ``~`` when you want to exclude them from the plot.
|
|
59
|
-
filter_vars: {None, "like", "regex"}, optional, default=None
|
|
60
|
-
If `None` (default), interpret var_names as the real variables names. If "like",
|
|
61
|
-
interpret var_names as substrings of the real variables names. If "regex",
|
|
62
|
-
interpret var_names as regular expressions on the real variables names. A la
|
|
63
|
-
``pandas.filter``.
|
|
64
|
-
combine_dims : set_like of str, optional
|
|
65
|
-
List of dimensions to reduce. Defaults to reducing only the "chain" and "draw" dimensions.
|
|
66
|
-
See the :ref:`this section <common_combine_dims>` for usage examples.
|
|
67
|
-
coords: mapping, optional
|
|
68
|
-
Coordinates of var_names to be plotted. Passed to :meth:`xarray.Dataset.sel`.
|
|
69
|
-
marginals: bool, optional
|
|
70
|
-
If True pairplot will include marginal distributions for every variable
|
|
71
|
-
figsize: figure size tuple
|
|
72
|
-
If None, size is (8 + numvars, 8 + numvars)
|
|
73
|
-
textsize: int
|
|
74
|
-
Text size for labels. If None it will be autoscaled based on ``figsize``.
|
|
75
|
-
kind : str or List[str]
|
|
76
|
-
Type of plot to display (scatter, kde and/or hexbin)
|
|
77
|
-
gridsize: int or (int, int), optional
|
|
78
|
-
Only works for ``kind=hexbin``. The number of hexagons in the x-direction.
|
|
79
|
-
The corresponding number of hexagons in the y-direction is chosen
|
|
80
|
-
such that the hexagons are approximately regular. Alternatively, gridsize
|
|
81
|
-
can be a tuple with two elements specifying the number of hexagons
|
|
82
|
-
in the x-direction and the y-direction.
|
|
83
|
-
divergences: Boolean
|
|
84
|
-
If True divergences will be plotted in a different color, only if group is either 'prior'
|
|
85
|
-
or 'posterior'.
|
|
86
|
-
colorbar: bool
|
|
87
|
-
If True a colorbar will be included as part of the plot (Defaults to False).
|
|
88
|
-
Only works when ``kind=hexbin``
|
|
89
|
-
labeller : labeller instance, optional
|
|
90
|
-
Class providing the method ``make_label_vert`` to generate the labels in the plot.
|
|
91
|
-
Read the :ref:`label_guide` for more details and usage examples.
|
|
92
|
-
ax: axes, optional
|
|
93
|
-
Matplotlib axes or bokeh figures.
|
|
94
|
-
divergences_kwargs: dicts, optional
|
|
95
|
-
Additional keywords passed to :meth:`matplotlib.axes.Axes.scatter` for divergences
|
|
96
|
-
scatter_kwargs:
|
|
97
|
-
Additional keywords passed to :meth:`matplotlib.axes.Axes.scatter` when using scatter kind
|
|
98
|
-
kde_kwargs: dict, optional
|
|
99
|
-
Additional keywords passed to :func:`arviz.plot_kde` when using kde kind
|
|
100
|
-
hexbin_kwargs: dict, optional
|
|
101
|
-
Additional keywords passed to :meth:`matplotlib.axes.Axes.hexbin` when
|
|
102
|
-
using hexbin kind
|
|
103
|
-
backend: str, optional
|
|
104
|
-
Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib".
|
|
105
|
-
backend_kwargs: bool, optional
|
|
106
|
-
These are kwargs specific to the backend being used, passed to
|
|
107
|
-
:func:`matplotlib.pyplot.subplots` or
|
|
108
|
-
:func:`bokeh.plotting.figure`.
|
|
109
|
-
marginal_kwargs: dict, optional
|
|
110
|
-
Additional keywords passed to :func:`arviz.plot_dist`, modifying the
|
|
111
|
-
marginal distributions plotted in the diagonal.
|
|
112
|
-
point_estimate: str, optional
|
|
113
|
-
Select point estimate from 'mean', 'mode' or 'median'. The point estimate will be
|
|
114
|
-
plotted using a scatter marker and vertical/horizontal lines.
|
|
115
|
-
point_estimate_kwargs: dict, optional
|
|
116
|
-
Additional keywords passed to :meth:`matplotlib.axes.Axes.axvline`,
|
|
117
|
-
:meth:`matplotlib.axes.Axes.axhline` (matplotlib) or
|
|
118
|
-
:class:`bokeh:bokeh.models.Span` (bokeh)
|
|
119
|
-
point_estimate_marker_kwargs: dict, optional
|
|
120
|
-
Additional keywords passed to :meth:`matplotlib.axes.Axes.scatter`
|
|
121
|
-
or :meth:`bokeh:bokeh.plotting.Figure.square` in point
|
|
122
|
-
estimate plot. Not available in bokeh
|
|
123
|
-
reference_values: dict, optional
|
|
124
|
-
Reference values for the plotted variables. The Reference values will be plotted
|
|
125
|
-
using a scatter marker
|
|
126
|
-
reference_values_kwargs: dict, optional
|
|
127
|
-
Additional keywords passed to :meth:`matplotlib.axes.Axes.plot` or
|
|
128
|
-
:meth:`bokeh:bokeh.plotting.Figure.circle` in reference values plot
|
|
129
|
-
show: bool, optional
|
|
130
|
-
Call backend show function.
|
|
131
|
-
|
|
132
|
-
Returns
|
|
133
|
-
-------
|
|
134
|
-
axes: matplotlib axes or bokeh figures
|
|
135
|
-
|
|
136
|
-
Examples
|
|
137
|
-
--------
|
|
138
|
-
KDE Pair Plot
|
|
139
|
-
|
|
140
|
-
.. plot::
|
|
141
|
-
:context: close-figs
|
|
142
|
-
|
|
143
|
-
>>> import arviz as az
|
|
144
|
-
>>> centered = az.load_arviz_data('centered_eight')
|
|
145
|
-
>>> coords = {'school': ['Choate', 'Deerfield']}
|
|
146
|
-
>>> az.plot_pair(centered,
|
|
147
|
-
>>> var_names=['theta', 'mu', 'tau'],
|
|
148
|
-
>>> kind='kde',
|
|
149
|
-
>>> coords=coords,
|
|
150
|
-
>>> divergences=True,
|
|
151
|
-
>>> textsize=18)
|
|
152
|
-
|
|
153
|
-
Hexbin pair plot
|
|
154
|
-
|
|
155
|
-
.. plot::
|
|
156
|
-
:context: close-figs
|
|
157
|
-
|
|
158
|
-
>>> az.plot_pair(centered,
|
|
159
|
-
>>> var_names=['theta', 'mu'],
|
|
160
|
-
>>> coords=coords,
|
|
161
|
-
>>> textsize=18,
|
|
162
|
-
>>> kind='hexbin')
|
|
163
|
-
|
|
164
|
-
Pair plot showing divergences and select variables with regular expressions
|
|
165
|
-
|
|
166
|
-
.. plot::
|
|
167
|
-
:context: close-figs
|
|
168
|
-
|
|
169
|
-
>>> az.plot_pair(centered,
|
|
170
|
-
... var_names=['^t', 'mu'],
|
|
171
|
-
... filter_vars="regex",
|
|
172
|
-
... coords=coords,
|
|
173
|
-
... divergences=True,
|
|
174
|
-
... textsize=18)
|
|
175
|
-
"""
|
|
176
|
-
valid_kinds = ["scatter", "kde", "hexbin"]
|
|
177
|
-
kind_boolean: Union[bool, List[bool]]
|
|
178
|
-
if isinstance(kind, str):
|
|
179
|
-
kind_boolean = kind in valid_kinds
|
|
180
|
-
else:
|
|
181
|
-
kind_boolean = [kind[i] in valid_kinds for i in range(len(kind))]
|
|
182
|
-
if not np.all(kind_boolean):
|
|
183
|
-
raise ValueError(f"Plot type {kind} not recognized. Plot type must be in {valid_kinds}")
|
|
184
|
-
|
|
185
|
-
if coords is None:
|
|
186
|
-
coords = {}
|
|
187
|
-
|
|
188
|
-
if labeller is None:
|
|
189
|
-
labeller = BaseLabeller()
|
|
190
|
-
|
|
191
|
-
# Get posterior draws and combine chains
|
|
192
|
-
dataset = convert_to_dataset(data, group=group)
|
|
193
|
-
var_names = _var_names(var_names, dataset, filter_vars)
|
|
194
|
-
plotters = list(
|
|
195
|
-
xarray_var_iter(
|
|
196
|
-
get_coords(dataset, coords), var_names=var_names, skip_dims=combine_dims, combined=True
|
|
197
|
-
)
|
|
198
|
-
)
|
|
199
|
-
flat_var_names = []
|
|
200
|
-
flat_ref_slices = []
|
|
201
|
-
flat_var_labels = []
|
|
202
|
-
for var_name, sel, isel, _ in plotters:
|
|
203
|
-
dims = [dim for dim in dataset[var_name].dims if dim not in ["chain", "draw"]]
|
|
204
|
-
flat_var_names.append(var_name)
|
|
205
|
-
flat_ref_slices.append(tuple(isel[dim] if dim in isel else slice(None) for dim in dims))
|
|
206
|
-
flat_var_labels.append(labeller.make_label_vert(var_name, sel, isel))
|
|
207
|
-
|
|
208
|
-
divergent_data = None
|
|
209
|
-
diverging_mask = None
|
|
210
|
-
|
|
211
|
-
# Assigning divergence group based on group param
|
|
212
|
-
if group == "posterior":
|
|
213
|
-
divergent_group = "sample_stats"
|
|
214
|
-
elif group == "prior":
|
|
215
|
-
divergent_group = "sample_stats_prior"
|
|
216
|
-
else:
|
|
217
|
-
divergences = False
|
|
218
|
-
|
|
219
|
-
# Get diverging draws and combine chains
|
|
220
|
-
if divergences:
|
|
221
|
-
if hasattr(data, divergent_group) and hasattr(getattr(data, divergent_group), "diverging"):
|
|
222
|
-
divergent_data = convert_to_dataset(data, group=divergent_group)
|
|
223
|
-
_, diverging_mask = xarray_to_ndarray(
|
|
224
|
-
divergent_data, var_names=("diverging",), combined=True
|
|
225
|
-
)
|
|
226
|
-
diverging_mask = np.squeeze(diverging_mask)
|
|
227
|
-
else:
|
|
228
|
-
divergences = False
|
|
229
|
-
warnings.warn(
|
|
230
|
-
"Divergences data not found, plotting without divergences. "
|
|
231
|
-
"Make sure the sample method provides divergences data and "
|
|
232
|
-
"that it is present in the `diverging` field of `sample_stats` "
|
|
233
|
-
"or `sample_stats_prior` or set divergences=False",
|
|
234
|
-
UserWarning,
|
|
235
|
-
)
|
|
236
|
-
|
|
237
|
-
if gridsize == "auto":
|
|
238
|
-
gridsize = int(dataset.sizes["draw"] ** 0.35)
|
|
239
|
-
|
|
240
|
-
numvars = len(flat_var_names)
|
|
241
|
-
|
|
242
|
-
if numvars < 2:
|
|
243
|
-
raise ValueError("Number of variables to be plotted must be 2 or greater.")
|
|
244
|
-
|
|
245
|
-
pairplot_kwargs = dict(
|
|
246
|
-
ax=ax,
|
|
247
|
-
plotters=plotters,
|
|
248
|
-
numvars=numvars,
|
|
249
|
-
figsize=figsize,
|
|
250
|
-
textsize=textsize,
|
|
251
|
-
kind=kind,
|
|
252
|
-
scatter_kwargs=scatter_kwargs,
|
|
253
|
-
kde_kwargs=kde_kwargs,
|
|
254
|
-
hexbin_kwargs=hexbin_kwargs,
|
|
255
|
-
gridsize=gridsize,
|
|
256
|
-
colorbar=colorbar,
|
|
257
|
-
divergences=divergences,
|
|
258
|
-
diverging_mask=diverging_mask,
|
|
259
|
-
divergences_kwargs=divergences_kwargs,
|
|
260
|
-
flat_var_names=flat_var_names,
|
|
261
|
-
flat_ref_slices=flat_ref_slices,
|
|
262
|
-
flat_var_labels=flat_var_labels,
|
|
263
|
-
backend_kwargs=backend_kwargs,
|
|
264
|
-
marginal_kwargs=marginal_kwargs,
|
|
265
|
-
show=show,
|
|
266
|
-
marginals=marginals,
|
|
267
|
-
point_estimate=point_estimate,
|
|
268
|
-
point_estimate_kwargs=point_estimate_kwargs,
|
|
269
|
-
point_estimate_marker_kwargs=point_estimate_marker_kwargs,
|
|
270
|
-
reference_values=reference_values,
|
|
271
|
-
reference_values_kwargs=reference_values_kwargs,
|
|
272
|
-
)
|
|
273
|
-
|
|
274
|
-
if backend is None:
|
|
275
|
-
backend = rcParams["plot.backend"]
|
|
276
|
-
backend = backend.lower()
|
|
277
|
-
|
|
278
|
-
# TODO: Add backend kwargs
|
|
279
|
-
plot = get_plotting_function("plot_pair", "pairplot", backend)
|
|
280
|
-
ax = plot(**pairplot_kwargs)
|
|
281
|
-
return ax
|