arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
|
@@ -1,148 +0,0 @@
|
|
|
1
|
-
"""Matplotlib Violinplot."""
|
|
2
|
-
|
|
3
|
-
import matplotlib.pyplot as plt
|
|
4
|
-
import numpy as np
|
|
5
|
-
|
|
6
|
-
from ....stats import hdi
|
|
7
|
-
from ....stats.density_utils import get_bins, histogram, kde
|
|
8
|
-
from ...plot_utils import _scale_fig_size
|
|
9
|
-
from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def plot_violin(
    ax,
    plotters,
    figsize,
    rows,
    cols,
    sharex,
    sharey,
    shade_kwargs,
    shade,
    rug,
    rug_kwargs,
    side,
    bw,
    textsize,
    labeller,
    circular,
    hdi_prob,
    quartiles,
    backend_kwargs,
    show,
):
    """Matplotlib violin plot.

    Draws one violin per entry of ``plotters``. Values whose dtype kind is ``"i"``
    are drawn as a horizontal histogram (:func:`cat_hist`), everything else as a
    KDE outline (:func:`_violinplot`).

    Parameters
    ----------
    ax : array-like of matplotlib axes or None
        Axes to draw on. If None a new grid is created with ``create_axes_grid``
        and the horizontal spacing between subplots is removed.
    plotters : sequence of tuple
        Tuples ``(var_name, selection, isel, x)``; ``x`` is flattened before plotting.
        Plotters are paired with the flattened axes in order.
    figsize, textsize, rows, cols
        Forwarded to ``_scale_fig_size`` to derive figure size, label sizes and line width.
    sharex, sharey
        Used as defaults for the matching ``backend_kwargs`` entries.
    shade_kwargs : dict or None
        Extra keywords for the shaded area (``fill_betweenx`` or ``barh``).
    shade : float
        Passed as ``alpha`` of the shaded area.
    rug : bool
        If true, draw the individual values as jittered markers on the negative side.
    rug_kwargs : dict or None
        Extra keywords for the rug ``plot`` call. Defaults: ``alpha=0.1``,
        ``marker="."``, ``linestyle=""``.
    side : {"both", "left", "right"}
        Which half of the violin to draw; forwarded to the helper functions.
    bw
        Bandwidth specification forwarded to ``_violinplot``.
    labeller
        Object providing ``make_label_vert(var_name, selection, isel)`` for the titles.
    circular : bool
        Forwarded to ``_violinplot``.
    hdi_prob : float
        Probability passed to ``hdi`` for the thin interval line.
    quartiles : bool
        If true, overlay the interquartile range (thick line), the HDI (thin line)
        and the median (white dot).
    backend_kwargs : dict or None
        Overrides merged on top of ``backend_kwarg_defaults()``.
    show : bool or None
        Passed to ``backend_show``; when that returns a true value ``plt.show()`` is called.

    Returns
    -------
    numpy.ndarray
        The axes, wrapped with ``np.atleast_1d``.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    (figsize, ax_labelsize, _, xt_labelsize, linewidth, _) = _scale_fig_size(
        figsize, textsize, rows, cols
    )
    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs.setdefault("sharex", sharex)
    backend_kwargs.setdefault("sharey", sharey)
    backend_kwargs.setdefault("squeeze", True)

    shade_kwargs = matplotlib_kwarg_dealiaser(shade_kwargs, "hexbin")
    rug_kwargs = matplotlib_kwarg_dealiaser(rug_kwargs, "plot")
    rug_kwargs.setdefault("alpha", 0.1)
    rug_kwargs.setdefault("marker", ".")
    rug_kwargs.setdefault("linestyle", "")

    if ax is None:
        fig, ax = create_axes_grid(
            len(plotters),
            rows,
            cols,
            backend_kwargs=backend_kwargs,
        )
        # Disable the layout engine so the manual zero spacing below is honoured.
        fig.set_layout_engine("none")
        fig.subplots_adjust(wspace=0)

    ax = np.atleast_1d(ax)

    current_col = 0
    for (var_name, selection, isel, x), ax_ in zip(plotters, ax.flatten()):
        val = x.flatten()
        if val[0].dtype.kind == "i":
            dens = cat_hist(val, rug, side, shade, ax_, **shade_kwargs)
        else:
            dens = _violinplot(val, rug, side, shade, bw, circular, ax_, **shade_kwargs)

        if rug:
            # Jitter the markers over the negative half, scaled to the violin width.
            rug_x = -np.abs(np.random.normal(scale=max(dens) / 3.5, size=len(val)))
            ax_.plot(rug_x, val, **rug_kwargs)

        if quartiles:
            # The summaries are only needed for this overlay, so compute them lazily.
            per = np.nanpercentile(val, [25, 75, 50])
            hdi_probs = hdi(val, hdi_prob, multimodal=False, skipna=True)
            ax_.plot([0, 0], per[:2], lw=linewidth * 3, color="k", solid_capstyle="round")
            ax_.plot([0, 0], hdi_probs, lw=linewidth, color="k", solid_capstyle="round")
            ax_.plot(0, per[-1], "wo", ms=linewidth * 1.5)

        ax_.set_title(labeller.make_label_vert(var_name, selection, isel), fontsize=ax_labelsize)
        ax_.set_xticks([])
        ax_.tick_params(labelsize=xt_labelsize)
        ax_.grid(None, axis="x")
        # Only the first column of each row keeps its left spine and y ticks.
        if current_col != 0:
            ax_.spines["left"].set_visible(False)
            ax_.yaxis.set_ticks_position("none")
        current_col += 1
        if current_col == cols:
            current_col = 0

    if backend_show(show):
        plt.show()

    return ax
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
def _violinplot(val, rug, side, shade, bw, circular, ax, **shade_kwargs):
    """Auxiliary function to plot violinplots.

    Parameters
    ----------
    val : array-like
        Values whose density is estimated with ``kde``.
    rug : bool
        When true and ``side == "both"``, only the right half is drawn.
    side : {"both", "left", "right"}
        Which half of the violin to draw.
    shade : float
        Passed as ``alpha`` to ``fill_betweenx``.
    bw
        Bandwidth passed to ``kde``. ``"default"`` resolves to ``"taylor"`` for
        circular data and ``"experimental"`` otherwise.
    circular : bool
        Passed to ``kde``.
    ax : matplotlib axes
        Axes to draw on.
    **shade_kwargs
        Extra keywords for ``fill_betweenx``.

    Returns
    -------
    density
        The density values returned by ``kde`` (not mirrored or reversed).

    Raises
    ------
    ValueError
        If ``side`` is not one of the supported values.
    """
    # Fail early with a clear message; previously an unknown side left `dens`
    # unbound and surfaced as an UnboundLocalError at draw time.
    if side not in ("left", "right", "both"):
        raise ValueError(f"`side` must be one of 'left', 'right' or 'both', got {side!r}")

    if bw == "default":
        bw = "taylor" if circular else "experimental"
    x, density = kde(val, circular=circular, bw=bw)

    if rug and side == "both":
        side = "right"

    if side == "left":
        dens = -density
    elif side == "right":
        x = x[::-1]
        dens = density[::-1]
    else:
        # "both": trace the left outline, then come back along the right one.
        x = np.concatenate([x, x[::-1]])
        dens = np.concatenate([-density, density[::-1]])

    ax.fill_betweenx(x, dens, alpha=shade, lw=0, **shade_kwargs)
    return density
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
def cat_hist(val, rug, side, shade, ax, **shade_kwargs):
    """Auxiliary function to plot discrete-violinplots.

    Parameters
    ----------
    val : array-like
        Values to bin with ``get_bins`` and ``histogram``.
    rug : bool
        When true and ``side == "both"``, only the right half is drawn.
    side : {"both", "left", "right"}
        Where the bars are anchored: starting at zero ("right"), ending at zero
        ("left") or centred on zero ("both").
    shade : float
        Passed as ``alpha`` to ``barh``.
    ax : matplotlib axes
        Axes to draw on.
    **shade_kwargs
        Extra keywords for ``barh``.

    Returns
    -------
    binned_d
        The per-bin values returned by ``histogram``.

    Raises
    ------
    ValueError
        If ``side`` is not one of the supported values.
    """
    # Fail early with a clear message; previously an unknown side left `left`
    # unbound and surfaced as an UnboundLocalError at draw time.
    if side not in ("left", "right", "both"):
        raise ValueError(f"`side` must be one of 'left', 'right' or 'both', got {side!r}")

    bins = get_bins(val)
    _, binned_d, _ = histogram(val, bins=bins)

    bin_edges = np.linspace(np.min(val), np.max(val), len(bins))
    heights = np.diff(bin_edges)
    centers = bin_edges[:-1] + heights.mean() / 2

    if rug and side == "both":
        side = "right"

    if side == "right":
        left = None
    elif side == "left":
        left = -binned_d
    else:
        # "both": shift each bar by half its width so it is centred on zero.
        left = -0.5 * binned_d

    ax.barh(centers, binned_d, height=heights, left=left, alpha=shade, **shade_kwargs)
    return binned_d
|
arviz/plots/bfplot.py
DELETED
|
@@ -1,128 +0,0 @@
|
|
|
1
|
-
# Plotting and reporting Bayes Factor given idata, var name, prior distribution and reference value
|
|
2
|
-
# pylint: disable=unbalanced-tuple-unpacking
|
|
3
|
-
import logging
|
|
4
|
-
|
|
5
|
-
from ..data.utils import extract
|
|
6
|
-
from .plot_utils import get_plotting_function
|
|
7
|
-
from ..stats import bayes_factor
|
|
8
|
-
|
|
9
|
-
_log = logging.getLogger(__name__)
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def plot_bf(
    idata,
    var_name,
    prior=None,
    ref_val=0,
    colors=("C0", "C1"),
    figsize=None,
    textsize=None,
    hist_kwargs=None,
    plot_kwargs=None,
    ax=None,
    backend=None,
    backend_kwargs=None,
    show=None,
):
    r"""Approximated Bayes Factor for comparing hypothesis of two nested models.

    The Bayes factor is estimated by comparing a model (H1) against a model in which the
    parameter of interest has been restricted to be a point-null (H0). This computation
    assumes the models are nested and thus H0 is a special case of H1.

    Notes
    -----
    The Bayes factor is approximated with the Savage-Dickey density ratio
    algorithm presented in [1]_.

    Parameters
    ----------
    idata : InferenceData
        Any object that can be converted to an :class:`arviz.InferenceData` object.
        Refer to documentation of :func:`arviz.convert_to_dataset` for details.
    var_name : str
        Name of the variable we want to test.
    prior : numpy.array, optional
        Prior samples to use instead of the ones stored in ``idata``, for example for
        sensitivity analysis. If None, they are extracted from the "prior" group.
    ref_val : int, default 0
        Point-null for Bayes factor estimation.
    colors : tuple, default ('C0', 'C1')
        Tuple of valid Matplotlib colors. First element for the prior, second for the posterior.
    figsize : (float, float), optional
        Figure size. If `None` it will be defined automatically.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If `None` it will be auto
        scaled based on `figsize`.
    plot_kwargs : dict, optional
        Additional keywords passed to :func:`matplotlib.pyplot.plot`.
    hist_kwargs : dict, optional
        Additional keywords passed to :func:`arviz.plot_dist`. Only works for discrete variables.
    ax : axes, optional
        :class:`matplotlib.axes.Axes` or :class:`bokeh.plotting.Figure`.
    backend : {"matplotlib", "bokeh"}, default "matplotlib"
        Select plotting backend.
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`.
        For additional documentation check the plotting method of the backend.
    show : bool, optional
        Call backend show function.

    Returns
    -------
    dict : A dictionary with BF10 (H1/H0 ratio) and BF01 (H0/H1 ratio).
    axes : matplotlib_axes or bokeh_figure

    References
    ----------
    .. [1] Heck, D., 2019. A caveat on the Savage-Dickey density ratio:
       The case of computing Bayes factors for regression parameters.

    Examples
    --------
    Moderate evidence indicating that the parameter "a" is different from zero.

    .. plot::
        :context: close-figs

        >>> import numpy as np
        >>> import arviz as az
        >>> idata = az.from_dict(posterior={"a":np.random.normal(1, 0.5, 5000)},
        ...     prior={"a":np.random.normal(0, 1, 5000)})
        >>> az.plot_bf(idata, var_name="a", ref_val=0)

    """
    # Fall back to the prior samples stored in the inference data.
    if prior is None:
        prior = extract(idata, var_names=var_name, group="prior").values

    factors, density_at_ref = bayes_factor(
        idata, var_name, prior=prior, ref_val=ref_val, return_ref_vals=True
    )
    bf_10, bf_01 = factors["BF10"], factors["BF01"]

    draw = get_plotting_function("plot_bf", "bfplot", backend)
    axes = draw(
        **{
            "ax": ax,
            "bf_10": bf_10.item(),
            "bf_01": bf_01.item(),
            "prior": prior,
            "posterior": extract(idata, var_names=var_name),
            "ref_val": ref_val,
            "prior_at_ref_val": density_at_ref["prior"],
            "posterior_at_ref_val": density_at_ref["posterior"],
            "var_name": var_name,
            "colors": colors,
            "figsize": figsize,
            "textsize": textsize,
            "hist_kwargs": hist_kwargs,
            "plot_kwargs": plot_kwargs,
            "backend_kwargs": backend_kwargs,
            "show": show,
        }
    )
    return {"BF10": bf_10, "BF01": bf_01}, axes
|
arviz/plots/bpvplot.py
DELETED
|
@@ -1,308 +0,0 @@
|
|
|
1
|
-
"""Bayesian p-value Posterior/Prior predictive plot."""
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
|
|
5
|
-
from ..labels import BaseLabeller
|
|
6
|
-
from ..rcparams import rcParams
|
|
7
|
-
from ..utils import _var_names
|
|
8
|
-
from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
|
|
9
|
-
from ..sel_utils import xarray_var_iter
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
def plot_bpv(
    data,
    kind="u_value",
    t_stat="median",
    bpv=True,
    plot_mean=True,
    reference="analytical",
    smoothing=None,
    mse=False,
    n_ref=100,
    hdi_prob=0.94,
    color="C0",
    grid=None,
    figsize=None,
    textsize=None,
    labeller=None,
    data_pairs=None,
    var_names=None,
    filter_vars=None,
    coords=None,
    flatten=None,
    flatten_pp=None,
    ax=None,
    backend=None,
    plot_ref_kwargs=None,
    backend_kwargs=None,
    group="posterior",
    show=None,
):
    r"""Plot Bayesian p-value for observed data and Posterior/Prior predictive.

    Parameters
    ----------
    data : InferenceData
        :class:`arviz.InferenceData` object containing the observed and
        posterior/prior predictive data.
    kind : {"u_value", "p_value", "t_stat"}, default "u_value"
        Specify the kind of plot (case-insensitive):

        * The ``kind="p_value"`` computes :math:`p := p(y* \leq y | y)`.
          This is the probability of the data y being larger or equal than the predicted data y*.
          The ideal value is 0.5 (half the predictions below and half above the data).
        * The ``kind="u_value"`` argument computes :math:`p_i := p(y_i* \leq y_i | y)`.
          i.e. like a p_value but per observation :math:`y_i`. This is also known as marginal
          p_value. The ideal distribution is uniform. This is similar to the LOO-PIT
          calculation/plot, the difference is that in the LOO-PIT plot we compute
          :math:`p_i = p(y_i* \leq y_i | y_{-i} )`, where :math:`y_{-i}`,
          is all other data except :math:`y_i`.
        * The ``kind="t_stat"`` argument computes :math:`:= p(T(y)* \leq T(y) | y)`
          where T is any test statistic. See ``t_stat`` argument below for details
          of available options.

    t_stat : str, float, or callable, default "median"
        Test statistics to compute from the observations and predictive distributions.
        Allowed strings are "mean", "median" or "std". Alternatively a quantile can be passed
        as a float (or str) in the interval (0, 1). Finally a user defined function is also
        accepted, see examples section for details.
    bpv : bool, default True
        If True add the Bayesian p_value to the legend when ``kind = t_stat``.
    plot_mean : bool, default True
        Whether or not to plot the mean test statistic.
    reference : {"analytical", "samples", None}, default "analytical"
        How to compute the distributions used as reference for ``kind=u_values``
        or ``kind=p_values`` (case-insensitive). Use `None` to not plot any reference.
    smoothing : bool, optional
        If True and the data has integer dtype, smooth the data before computing the p-values,
        u-values or tstat. By default, True when `kind` is "u_value" and False otherwise.
    mse : bool, default False
        Show scaled mean square error between uniform distribution and marginal p_value
        distribution.
    n_ref : int, default 100
        Number of reference distributions to sample when ``reference=samples``.
    hdi_prob : float, default 0.94
        Probability for the highest density interval for the analytical reference distribution when
        ``kind=u_values``. Should be in the interval (0, 1]. If None, the rcParam
        ``stats.ci_prob`` is used. See :ref:`this section <common_hdi_prob>` for usage examples.
    color : str, optional
        Matplotlib color
    grid : tuple, optional
        Number of rows and columns. By default, the rows and columns are
        automatically inferred. See :ref:`this section <common_grid>` for usage examples.
    figsize : (float, float), optional
        Figure size. If None it will be defined automatically.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If None it will be autoscaled based
        on `figsize`.
    labeller : Labeller, optional
        Class providing the method ``make_pp_label`` to generate the labels in the plot titles.
        Read the :ref:`label_guide` for more details and usage examples.
    data_pairs : dict, optional
        Dictionary containing relations between observed data and posterior/prior predictive data.
        Dictionary structure:

        - key = data var_name
        - value = posterior/prior predictive var_name

        For example, ``data_pairs = {'y' : 'y_hat'}``
        If None, it will assume that the observed data and the posterior/prior
        predictive data have the same variable name.
    var_names : list of str, optional
        Variables to be plotted. If `None` all variable are plotted. Prefix the variables by ``~``
        when you want to exclude them from the plot.
        See :ref:`this section <common_var_names>` for usage examples.
    filter_vars : {None, "like", "regex"}, default None
        If `None` (default), interpret `var_names` as the real variables names. If "like",
        interpret `var_names` as substrings of the real variables names. If "regex",
        interpret `var_names` as regular expressions on the real variables names. See
        :ref:`this section <common_filter_vars>` for usage examples.
    coords : dict, optional
        Dictionary mapping dimensions to selected coordinates to be plotted.
        Dimensions without a mapping specified will include all coordinates for
        that dimension. Defaults to including all coordinates for all
        dimensions if None. The dictionary passed in is not modified.
        See :ref:`this section <common_coords>` for usage examples.
    flatten : list, optional
        List of dimensions to flatten in observed_data. Only flattens across the coordinates
        specified in the coords argument. Defaults to flattening all of the dimensions.
    flatten_pp : list, optional
        List of dimensions to flatten in posterior_predictive/prior_predictive. Only flattens
        across the coordinates specified in the coords argument. Defaults to flattening all
        of the dimensions. Dimensions should match flatten excluding dimensions for data_pairs
        parameters. If `flatten` is defined and `flatten_pp` is None, then ``flatten_pp=flatten``.
    ax : 2D array-like of matplotlib_axes or bokeh_figure, optional
        A 2D array of locations into which to plot the densities. If not supplied, ArviZ will create
        its own array of plot areas (and return it).
    backend : str, optional
        Select plotting backend {"matplotlib", "bokeh"}. Default "matplotlib".
    plot_ref_kwargs : dict, optional
        Extra keyword arguments to control how reference is represented.
        Passed to :meth:`matplotlib.axes.Axes.plot` or
        :meth:`matplotlib.axes.Axes.axhspan` (when ``kind=u_value``
        and ``reference=analytical``).
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`.
        For additional documentation check the plotting method of the backend.
    group : {"posterior", "prior"}, default "posterior"
        Specifies which InferenceData group should be plotted. If "posterior", then the values
        in `posterior_predictive` group are compared to the ones in `observed_data`, if "prior" then
        the same comparison happens, but with the values in `prior_predictive` group.
    show : bool, optional
        Call backend show function.

    Returns
    -------
    axes : 2D ndarray of matplotlib_axes or bokeh_figure

    Raises
    ------
    TypeError
        If `group`, `kind` or `reference` is not one of the accepted values, or if `data`
        lacks the ``observed_data`` or the selected predictive group.
    ValueError
        If `hdi_prob` is not in the interval (0, 1].

    See Also
    --------
    plot_ppc : Plot for posterior/prior predictive checks.
    plot_loo_pit : Plot Leave-One-Out probability integral transformation (PIT) predictive checks.
    plot_dist_comparison : Plot to compare fitted and unfitted distributions.

    References
    ----------
    * Gelman et al. (2013) see http://www.stat.columbia.edu/~gelman/book/ pages 151-153 for details

    Notes
    -----
    Discrete data is smoothed before computing either p-values or u-values using the
    function :func:`~arviz.smooth_data` if the data is integer type
    and the smoothing parameter is True.

    Examples
    --------
    Plot Bayesian p_values.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data("regression1d")
        >>> az.plot_bpv(data, kind="p_value")

    Plot custom test statistic comparison.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data("regression1d")
        >>> az.plot_bpv(data, kind="t_stat", t_stat=lambda x:np.percentile(x, q=50, axis=-1))
    """
    # NOTE: TypeError (rather than ValueError) is kept for these checks so that
    # existing callers catching it keep working.
    if group not in ("posterior", "prior"):
        raise TypeError("`group` argument must be either `posterior` or `prior`")

    for groups in (f"{group}_predictive", "observed_data"):
        if not hasattr(data, groups):
            raise TypeError(f'`data` argument must have the group "{groups}"')

    if kind.lower() not in ("t_stat", "u_value", "p_value"):
        raise TypeError("`kind` argument must be either `t_stat`, `u_value`, or `p_value`")
    # Validation is case-insensitive, so forward the normalised value as well;
    # presumably the backend compares against the lowercase names.
    kind = kind.lower()

    if reference is not None and reference.lower() not in ("analytical", "samples"):
        raise TypeError("`reference` argument must be either `analytical`, `samples`, or `None`")
    if reference is not None:
        reference = reference.lower()

    if hdi_prob is None:
        hdi_prob = rcParams["stats.ci_prob"]
    elif not 1 >= hdi_prob > 0:
        raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

    if smoothing is None:
        smoothing = kind == "u_value"

    if data_pairs is None:
        data_pairs = {}

    if labeller is None:
        labeller = BaseLabeller()

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    observed = data.observed_data

    if group == "posterior":
        predictive_dataset = data.posterior_predictive
    elif group == "prior":
        predictive_dataset = data.prior_predictive

    if var_names is None:
        var_names = list(observed.data_vars)
    var_names = _var_names(var_names, observed, filter_vars)
    pp_var_names = [data_pairs.get(var, var) for var in var_names]
    pp_var_names = _var_names(pp_var_names, predictive_dataset, filter_vars)

    if flatten_pp is None:
        if flatten is None:
            flatten_pp = list(predictive_dataset.dims)
        else:
            flatten_pp = flatten
    if flatten is None:
        flatten = list(observed.dims)

    if coords is None:
        coords = {}

    total_pp_samples = predictive_dataset.sizes["chain"] * predictive_dataset.sizes["draw"]

    # Translate coordinate values into positional indices for ``isel``. A new
    # mapping is built so the dictionary supplied by the caller is left untouched.
    coords = {
        key: np.where(np.isin(observed[key], values))[0] for key, values in coords.items()
    }

    obs_plotters = filter_plotters_list(
        list(
            xarray_var_iter(
                observed.isel(coords), skip_dims=set(flatten), var_names=var_names, combined=True
            )
        ),
        "plot_t_stats",
    )
    length_plotters = len(obs_plotters)
    # Take only as many predictive plotters as there are observed ones.
    pp_plotters = [
        tup
        for _, tup in zip(
            range(length_plotters),
            xarray_var_iter(
                predictive_dataset.isel(coords),
                var_names=pp_var_names,
                skip_dims=set(flatten_pp),
                combined=True,
            ),
        )
    ]
    rows, cols = default_grid(length_plotters, grid=grid)

    bpvplot_kwargs = dict(
        ax=ax,
        length_plotters=length_plotters,
        rows=rows,
        cols=cols,
        obs_plotters=obs_plotters,
        pp_plotters=pp_plotters,
        total_pp_samples=total_pp_samples,
        kind=kind,
        bpv=bpv,
        t_stat=t_stat,
        reference=reference,
        mse=mse,
        n_ref=n_ref,
        hdi_prob=hdi_prob,
        plot_mean=plot_mean,
        color=color,
        figsize=figsize,
        textsize=textsize,
        labeller=labeller,
        plot_ref_kwargs=plot_ref_kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
        smoothing=smoothing,
    )

    # TODO: Add backend kwargs
    plot = get_plotting_function("plot_bpv", "bpvplot", backend)
    axes = plot(**bpvplot_kwargs)
    return axes
|