arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/plots/compareplot.py
DELETED
|
@@ -1,177 +0,0 @@
|
|
|
1
|
-
"""Summary plot for model comparison."""
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
|
|
5
|
-
from ..labels import BaseLabeller
|
|
6
|
-
from ..rcparams import rcParams
|
|
7
|
-
from .plot_utils import get_plotting_function
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
def plot_compare(
    comp_df,
    insample_dev=False,
    plot_standard_error=True,
    plot_ic_diff=False,
    order_by_rank=True,
    legend=False,
    title=True,
    figsize=None,
    textsize=None,
    labeller=None,
    plot_kwargs=None,
    ax=None,
    backend=None,
    backend_kwargs=None,
    show=None,
):
    r"""Summary plot for model comparison.

    Models are compared based on their expected log pointwise predictive density (ELPD).
    This plot is in the style of the one used in [2]_ (chapter 6 in the first edition,
    chapter 7 in the second).

    Notes
    -----
    The ELPD is estimated either by Pareto smoothed importance sampling leave-one-out
    cross-validation (LOO) or using the widely applicable information criterion (WAIC).
    We recommend LOO in line with the work presented by [1]_.

    Parameters
    ----------
    comp_df : pandas.DataFrame
        Result of the :func:`arviz.compare` method. It must contain an ``elpd_loo`` or an
        ``elpd_waic`` column (matched case-insensitively) and, when `order_by_rank` is True,
        a ``rank`` column. The DataFrame passed in is not modified.
    insample_dev : bool, default False
        Plot in-sample ELPD, that is the value of the information criteria without the
        penalization given by the effective number of parameters (p_loo or p_waic).
    plot_standard_error : bool, default True
        Plot the standard error of the ELPD.
    plot_ic_diff : bool, default False
        Plot standard error of the difference in ELPD between each model
        and the top-ranked model.
    order_by_rank : bool, default True
        If True, sort the models by the ``rank`` column so the best model is drawn first
        and used as reference.
    legend : bool, default False
        Add legend to figure.
    figsize : (float, float), optional
        If `None`, size is (6, num of models) inches.
    title : bool, default True
        Show a title with a description of how to interpret the plot.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If `None` it will be autoscaled based
        on `figsize`.
    labeller : Labeller, optional
        Class providing the method ``model_name_to_str`` to generate the model labels in the plot.
        Read the :ref:`label_guide` for more details and usage examples.
    plot_kwargs : dict, optional
        Optional arguments for plot elements. Currently accepts 'color_ic',
        'marker_ic', 'color_insample_dev', 'marker_insample_dev', 'color_dse',
        'marker_dse', 'ls_min_ic', 'color_ls_min_ic', 'fontsize'.
    ax : matplotlib_axes or bokeh_figure, optional
        Matplotlib axes or bokeh figure.
    backend : {"matplotlib", "bokeh"}, default "matplotlib"
        Select plotting backend. If `None`, ``rcParams["plot.backend"]`` is used.
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`.
        For additional documentation check the plotting method of the backend.
    show : bool, optional
        Call backend show function.

    Returns
    -------
    axes : matplotlib_axes or bokeh_figure

    Raises
    ------
    ValueError
        If `comp_df` contains neither an ``elpd_loo`` nor an ``elpd_waic`` column.

    See Also
    --------
    plot_elpd : Plot pointwise elpd differences between two or more models.
    compare : Compare models based on PSIS-LOO loo or WAIC waic cross-validation.
    loo : Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).
    waic : Compute the widely applicable information criterion.

    References
    ----------
    .. [1] Vehtari et al. (2016). Practical Bayesian model evaluation using leave-one-out
       cross-validation and WAIC https://arxiv.org/abs/1507.04544

    .. [2] McElreath R. (2022). Statistical Rethinking A Bayesian Course with Examples in
       R and Stan, Second edition, CRC Press.

    Examples
    --------
    Show default compare plot

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> model_compare = az.compare({'Centered 8 schools': az.load_arviz_data('centered_eight'),
        ...                             'Non-centered 8 schools': az.load_arviz_data('non_centered_eight')})
        >>> az.plot_compare(model_compare)

    Include the in-sample ELPD

    .. plot::
        :context: close-figs

        >>> az.plot_compare(model_compare, insample_dev=True)

    """
    if plot_kwargs is None:
        plot_kwargs = {}

    if labeller is None:
        labeller = BaseLabeller()

    # Sort *before* deriving the tick labels from the index, so the labels follow the
    # same row order that is handed to the backend. ``sort_values`` without ``inplace``
    # returns a new frame and leaves the caller's DataFrame untouched.
    if order_by_rank:
        comp_df = comp_df.sort_values(by="rank")

    # 2n - 1 evenly spaced ticks in [0, -1]; odd-indexed ticks are shifted by half a step.
    yticks_pos, step = np.linspace(0, -1, (comp_df.shape[0] * 2) - 1, retstep=True)
    yticks_pos[1::2] = yticks_pos[1::2] + step / 2
    labels = [labeller.model_name_to_str(model_name) for model_name in comp_df.index]

    if plot_ic_diff:
        # Only the even-indexed ticks carry a model name; odd ones stay blank.
        yticks_labels = [""] * len(yticks_pos)
        yticks_labels[0] = labels[0]
        yticks_labels[2::2] = labels[1:]
    else:
        yticks_labels = labels

    # Pick the first supported ELPD column present (LOO preferred over WAIC).
    _information_criterion = ["elpd_loo", "elpd_waic"]
    column_index = [c.lower() for c in comp_df.columns]
    for information_criterion in _information_criterion:
        if information_criterion in column_index:
            break
    else:
        raise ValueError(
            "comp_df must contain one of the following "
            f"information criterion: {_information_criterion}"
        )

    compareplot_kwargs = dict(
        ax=ax,
        comp_df=comp_df,
        legend=legend,
        title=title,
        figsize=figsize,
        plot_ic_diff=plot_ic_diff,
        plot_standard_error=plot_standard_error,
        insample_dev=insample_dev,
        yticks_pos=yticks_pos,
        yticks_labels=yticks_labels,
        plot_kwargs=plot_kwargs,
        information_criterion=information_criterion,
        textsize=textsize,
        step=step,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # TODO: Add backend kwargs
    plot = get_plotting_function("plot_compare", "compareplot", backend)
    ax = plot(**compareplot_kwargs)

    return ax
|
arviz/plots/densityplot.py
DELETED
|
@@ -1,284 +0,0 @@
|
|
|
1
|
-
"""KDE and histogram plots for multiple variables."""
|
|
2
|
-
|
|
3
|
-
import warnings
|
|
4
|
-
|
|
5
|
-
from ..data import convert_to_dataset
|
|
6
|
-
from ..labels import BaseLabeller
|
|
7
|
-
from ..sel_utils import (
|
|
8
|
-
xarray_var_iter,
|
|
9
|
-
)
|
|
10
|
-
from ..rcparams import rcParams
|
|
11
|
-
from ..utils import _var_names
|
|
12
|
-
from .plot_utils import default_grid, get_plotting_function
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
# pylint:disable-msg=too-many-function-args
|
|
16
|
-
def plot_density(
    data,
    group="posterior",
    data_labels=None,
    var_names=None,
    filter_vars=None,
    combine_dims=None,
    transform=None,
    hdi_prob=None,
    point_estimate="auto",
    colors="cycle",
    outline=True,
    hdi_markers="",
    shade=0.0,
    bw="default",
    circular=False,
    grid=None,
    figsize=None,
    textsize=None,
    labeller=None,
    ax=None,
    backend=None,
    backend_kwargs=None,
    show=None,
):
    r"""Generate KDE plots for continuous variables and histograms for discrete ones.

    Plots are truncated at their ``100 * hdi_prob`` % highest density intervals. Plots are
    grouped per variable and colors assigned to models.

    Parameters
    ----------
    data : InferenceData or iterable of InferenceData
        Any object that can be converted to an :class:`arviz.InferenceData` object, or a list
        or tuple of such objects.
        Refer to documentation of :func:`arviz.convert_to_dataset` for details.
    group : {"posterior", "prior"}, default "posterior"
        Specifies which InferenceData group should be plotted.
    data_labels : list of str, default None
        List with names for the datasets passed as "data." Useful when plotting more than one
        dataset. Must have the same length as the number of datasets.
    var_names : list of str, optional
        List of variables to plot. If multiple datasets are supplied and `var_names` is not None,
        will print the same set of variables for each dataset. Defaults to None, which results in
        all the variables being plotted.
    filter_vars : {None, "like", "regex"}, default None
        If `None` (default), interpret `var_names` as the real variables names. If "like",
        interpret `var_names` as substrings of the real variables names. If "regex",
        interpret `var_names` as regular expressions on the real variables names. See
        :ref:`this section <common_filter_vars>` for usage examples.
    combine_dims : set_like of str, optional
        List of dimensions to reduce. Defaults to reducing only the "chain" and "draw" dimensions.
        See :ref:`this section <common_combine_dims>` for usage examples.
    transform : callable
        Function applied to each converted dataset before plotting (defaults to `None`,
        i.e. the identity function).
    hdi_prob : float, optional
        Probability for the highest density interval. Should be in the interval (0, 1].
        Defaults to ``rcParams["stats.ci_prob"]``.
        See :ref:`this section <common_hdi_prob>` for usage examples.
    point_estimate : str, optional
        Plot point estimate per variable. Values should be 'mean', 'median', 'mode' or None.
        Defaults to 'auto' i.e. it falls back to default set in ``rcParams``.
    colors : str or list of str, optional
        List with valid matplotlib colors, one color per model. Alternatively, a string can be
        passed. If the string is `cycle`, it will automatically choose a color per model from
        matplotlib's cycle. If a single color is passed, e.g. 'k', 'C2' or 'red' this color will
        be used for all models. Defaults to `cycle`.
    outline : bool, default True
        Use a line to draw KDEs and histograms.
    hdi_markers : str
        A valid `matplotlib.markers` like 'v', used to indicate the limits of the highest density
        interval. Defaults to empty string (no marker).
    shade : float, default 0
        Alpha blending value for the shaded area under the curve, between 0 (no shade) and 1
        (opaque).
    bw : float or str, optional
        If numeric, indicates the bandwidth and must be positive.
        If str, indicates the method to estimate the bandwidth and must be
        one of "scott", "silverman", "isj" or "experimental" when `circular` is False
        and "taylor" (for now) when `circular` is True.
        Defaults to "default" which means "experimental" when variable is not circular
        and "taylor" when it is.
    circular : bool, default False
        If True, it interprets the values passed are from a circular variable measured in radians
        and a circular KDE is used. Only valid for 1D KDE.
    grid : tuple, optional
        Number of rows and columns. Defaults to ``None``, the rows and columns are
        automatically inferred. See :ref:`this section <common_grid>` for usage examples.
    figsize : (float, float), optional
        Figure size. If `None` it will be defined automatically.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If `None` it will be autoscaled based
        on `figsize`.
    labeller : Labeller, optional
        Class providing the method ``make_label_vert`` to generate the labels in the plot titles.
        Read the :ref:`label_guide` for more details and usage examples.
    ax : 2D array-like of matplotlib_axes or bokeh_figure, optional
        A 2D array of locations into which to plot the densities. If not supplied, ArviZ will create
        its own array of plot areas (and return it).
    backend : {"matplotlib", "bokeh"}, default "matplotlib"
        Select plotting backend. If `None`, ``rcParams["plot.backend"]`` is used.
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`.
        For additional documentation check the plotting method of the backend.
    show : bool, optional
        Call backend show function.

    Returns
    -------
    axes : 2D ndarray of matplotlib_axes or bokeh_figure

    Raises
    ------
    ValueError
        If `data_labels` does not match the number of datasets, or if `hdi_prob` is outside
        the interval (0, 1].

    Warns
    -----
    UserWarning
        If there are more distinct subplots than ``rcParams["plot.max_subplots"]``; only the
        first ``max_subplots`` are drawn.

    See Also
    --------
    plot_dist : Plot distribution as histogram or kernel density estimates.
    plot_posterior : Plot Posterior densities in the style of John K. Kruschke's book.

    Examples
    --------
    Plot default density plot

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> centered = az.load_arviz_data('centered_eight')
        >>> non_centered = az.load_arviz_data('non_centered_eight')
        >>> az.plot_density([centered, non_centered])

    Plot variables in a 4x5 grid

    .. plot::
        :context: close-figs

        >>> az.plot_density([centered, non_centered], grid=(4, 5))

    Plot subset variables by specifying variable name exactly

    .. plot::
        :context: close-figs

        >>> az.plot_density([centered, non_centered], var_names=["mu"])

    Plot a specific `az.InferenceData` group

    .. plot::
        :context: close-figs

        >>> az.plot_density([centered, non_centered], var_names=["mu"], group="prior")

    Specify highest density interval

    .. plot::
        :context: close-figs

        >>> az.plot_density([centered, non_centered], var_names=["mu"], hdi_prob=.5)

    Shade plots and/or remove outlines

    .. plot::
        :context: close-figs

        >>> az.plot_density([centered, non_centered], var_names=["mu"], outline=False, shade=.8)

    Specify bandwidth for kernel density estimation

    .. plot::
        :context: close-figs

        >>> az.plot_density([centered, non_centered], var_names=["mu"], bw=.9)
    """
    if isinstance(data, (list, tuple)):
        datasets = [convert_to_dataset(datum, group=group) for datum in data]
    else:
        datasets = [convert_to_dataset(data, group=group)]

    if transform is not None:
        datasets = [transform(dataset) for dataset in datasets]

    if labeller is None:
        labeller = BaseLabeller()

    var_names = _var_names(var_names, datasets, filter_vars)

    n_data = len(datasets)

    if data_labels is None:
        data_labels = [f"{idx}" for idx in range(n_data)] if n_data > 1 else [""]
    elif len(data_labels) != n_data:
        raise ValueError(
            f"The number of names for the models ({len(data_labels)}) "
            f"does not match the number of models ({n_data})"
        )

    if hdi_prob is None:
        hdi_prob = rcParams["stats.ci_prob"]
    elif not 1 >= hdi_prob > 0:
        raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

    # One list per dataset of (var_name, selection, isel, values) entries.
    to_plot = [
        list(xarray_var_iter(dataset, var_names, combined=True, skip_dims=combine_dims))
        for dataset in datasets
    ]

    # Unique subplot labels across all datasets, in first-seen order
    # (dict keys preserve insertion order and give O(1) de-duplication).
    all_labels = list(
        dict.fromkeys(
            labeller.make_label_vert(var_name, selection, isel)
            for plotters in to_plot
            for var_name, selection, isel, _ in plotters
        )
    )
    length_plotters = len(all_labels)
    max_plots = rcParams["plot.max_subplots"]
    max_plots = length_plotters if max_plots is None else max_plots
    if length_plotters > max_plots:
        warnings.warn(
            f"rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number "
            f"of variables to plot ({length_plotters}) in plot_density, generating only "
            f"{max_plots} plots",
            UserWarning,
        )
        all_labels = all_labels[:max_plots]
        kept_labels = set(all_labels)
        # Keep the same 4-tuple layout (var_name, selection, isel, values) that the
        # untruncated path passes on unchanged, so the backend receives one entry
        # shape regardless of whether truncation happened.
        to_plot = [
            [
                (var_name, selection, isel, values)
                for var_name, selection, isel, values in plotters
                if labeller.make_label_vert(var_name, selection, isel) in kept_labels
            ]
            for plotters in to_plot
        ]
        length_plotters = max_plots
    rows, cols = default_grid(length_plotters, grid=grid, max_cols=3)

    if bw == "default":
        bw = "taylor" if circular else "experimental"

    plot_density_kwargs = dict(
        ax=ax,
        all_labels=all_labels,
        to_plot=to_plot,
        colors=colors,
        bw=bw,
        circular=circular,
        figsize=figsize,
        length_plotters=length_plotters,
        rows=rows,
        cols=cols,
        textsize=textsize,
        labeller=labeller,
        hdi_prob=hdi_prob,
        point_estimate=point_estimate,
        hdi_markers=hdi_markers,
        outline=outline,
        shade=shade,
        n_data=n_data,
        data_labels=data_labels,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # TODO: Add backend kwargs
    plot = get_plotting_function("plot_density", "densityplot", backend)
    ax = plot(**plot_density_kwargs)
    return ax
|
|
@@ -1,197 +0,0 @@
|
|
|
1
|
-
"""Density Comparison plot."""
|
|
2
|
-
|
|
3
|
-
import warnings
|
|
4
|
-
from ..labels import BaseLabeller
|
|
5
|
-
from ..rcparams import rcParams
|
|
6
|
-
from ..utils import _var_names, get_coords
|
|
7
|
-
from .plot_utils import get_plotting_function
|
|
8
|
-
from ..sel_utils import xarray_var_iter, xarray_sel_iter
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def plot_dist_comparison(
    data,
    kind="latent",
    figsize=None,
    textsize=None,
    var_names=None,
    coords=None,
    combine_dims=None,
    transform=None,
    legend=True,
    labeller=None,
    ax=None,
    prior_kwargs=None,
    posterior_kwargs=None,
    observed_kwargs=None,
    backend=None,
    backend_kwargs=None,
    show=None,
):
    r"""Plot to compare fitted and unfitted distributions.

    The resulting plots will show the compared distributions both on
    separate axes (particularly useful when one of them is substantially tighter
    than another), and plotted together, displaying a grid of three plots per
    distribution.

    Parameters
    ----------
    data : InferenceData
        Any object that can be converted to an :class:`arviz.InferenceData` object
        containing the posterior/prior data. Refer to documentation of
        :func:`arviz.convert_to_dataset` for details.
    kind : {"latent", "observed"}, default "latent"
        Kind of plot to display. The "latent" option includes {"prior", "posterior"},
        and the "observed" option includes
        {"observed_data", "prior_predictive", "posterior_predictive"}.
        Groups missing from ``data`` are skipped.
    figsize : (float, float), optional
        Figure size. If ``None`` it will be defined automatically.
    textsize : float
        Text size scaling factor for labels, titles and lines. If ``None`` it will be
        autoscaled based on `figsize`.
    var_names : str, list, list of lists, optional
        if str, plot the variable. if list, plot all the variables in list
        of all groups. if list of lists, plot the vars of groups in respective lists.
        See :ref:`this section <common_var_names>` for usage examples.
    coords : dict
        Dictionary mapping dimensions to selected coordinates to be plotted.
        Dimensions without a mapping specified will include all coordinates for
        that dimension. See :ref:`this section <common_coords>` for usage examples.
    combine_dims : set_like of str, optional
        List of dimensions to reduce. Defaults to reducing only the "chain" and "draw" dimensions.
        See :ref:`this section <common_combine_dims>` for usage examples.
    transform : callable
        Function to transform data (defaults to `None` i.e. the identity function).
    legend : bool
        Add legend to figure. By default True.
    labeller : Labeller, optional
        Class providing the method ``make_pp_label`` to generate the labels in the plot titles.
        Read the :ref:`label_guide` for more details and usage examples.
    ax : (nvars, 3) array-like of matplotlib_axes, optional
        Matplotlib axes: The ax argument should have shape (nvars, 3), where the
        last column is for the combined before/after plots and columns 0 and 1 are
        for the before and after plots, respectively.
    prior_kwargs : dict, optional
        Additional keywords passed to :func:`arviz.plot_dist` for prior/predictive groups.
    posterior_kwargs : dict, optional
        Additional keywords passed to :func:`arviz.plot_dist` for posterior/predictive groups.
    observed_kwargs : dict, optional
        Additional keywords passed to :func:`arviz.plot_dist` for observed_data group.
    backend : {"matplotlib", "bokeh"}, default "matplotlib"
        Select plotting backend.
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`.
        For additional documentation check the plotting method of the backend.
    show : bool, optional
        Call backend show function.

    Returns
    -------
    axes : 2D ndarray of matplotlib_axes
        Returned object will have shape (nvars, 3),
        where the last column is the combined plot and the first columns are the single plots.

    Raises
    ------
    ValueError
        If ``data`` contains none of the groups required by ``kind``.

    See Also
    --------
    plot_bpv : Plot Bayesian p-value for observed data and Posterior/Prior predictive.

    Examples
    --------
    Plot the prior/posterior plot for specified vars and coords.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('rugby')
        >>> az.plot_dist_comparison(data, var_names=["defs"], coords={"team" : ["Italy"]})

    """
    all_groups = ["prior", "posterior"]
    if kind == "observed":
        all_groups = ["observed_data", "prior_predictive", "posterior_predictive"]

    if coords is None:
        coords = {}

    if labeller is None:
        labeller = BaseLabeller()

    # Keep only the groups actually present in ``data``; absent groups are skipped
    # deliberately (best effort), but unrelated errors are no longer swallowed.
    datasets = []
    groups = []
    for group in all_groups:
        try:
            group_data = getattr(data, group)
        except (AttributeError, KeyError):
            continue
        datasets.append(group_data)
        groups.append(group)

    if not datasets:
        raise ValueError(
            f"None of the groups {all_groups} required for kind={kind!r} were found in data."
        )

    if var_names is None:
        var_names = list(datasets[0].data_vars)

    if isinstance(var_names, str):
        var_names = [var_names]

    # A flat list of names applies to every group; a list of lists is per group.
    if isinstance(var_names[0], str):
        var_names = [var_names for _ in datasets]

    var_names = [
        _var_names(group_vars, dataset) for group_vars, dataset in zip(var_names, datasets)
    ]

    if transform is not None:
        datasets = [transform(dataset) for dataset in datasets]

    datasets = get_coords(datasets, coords)

    # Each variable takes one subplot per group plus one combined subplot.
    plots_per_var = len(groups) + 1
    len_plots = rcParams["plot.max_subplots"] // plots_per_var
    len_plots = len_plots or 1
    dc_plotters = [
        list(
            xarray_var_iter(dataset, var_names=group_vars, combined=True, skip_dims=combine_dims)
        )[:len_plots]
        for dataset, group_vars in zip(datasets, var_names)
    ]

    # Count selections with the same ``skip_dims`` used for ``dc_plotters`` so the
    # warning below reflects what would actually be plotted.
    n_selections = sum(
        1
        for _ in xarray_sel_iter(
            datasets[0], var_names=var_names[0], combined=True, skip_dims=combine_dims
        )
    )
    total_plots = n_selections * plots_per_var
    maxplots = len(dc_plotters[0]) * plots_per_var

    if total_plots > rcParams["plot.max_subplots"]:
        warnings.warn(
            "rcParams['plot.max_subplots'] ({rcParam}) is smaller than the number "
            "of subplots to plot ({len_plotters}), generating only {max_plots} "
            "plots".format(
                rcParam=rcParams["plot.max_subplots"], len_plotters=total_plots, max_plots=maxplots
            ),
            UserWarning,
        )

    nvars = len(dc_plotters[0])
    ngroups = len(groups)

    distcomparisonplot_kwargs = dict(
        ax=ax,
        nvars=nvars,
        ngroups=ngroups,
        figsize=figsize,
        dc_plotters=dc_plotters,
        legend=legend,
        groups=groups,
        textsize=textsize,
        labeller=labeller,
        prior_kwargs=prior_kwargs,
        posterior_kwargs=posterior_kwargs,
        observed_kwargs=observed_kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # TODO: Add backend kwargs
    plot = get_plotting_function("plot_dist_comparison", "distcomparisonplot", backend)
    axes = plot(**distcomparisonplot_kwargs)
    return axes
|