arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -357
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.1.dist-info/METADATA +0 -263
- arviz-0.23.1.dist-info/RECORD +0 -183
- arviz-0.23.1.dist-info/top_level.txt +0 -1
arviz/plots/essplot.py
DELETED
|
@@ -1,319 +0,0 @@
|
|
|
1
|
-
"""Plot quantile or local effective sample sizes."""
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
import xarray as xr
|
|
5
|
-
|
|
6
|
-
from ..data import convert_to_dataset
|
|
7
|
-
from ..labels import BaseLabeller
|
|
8
|
-
from ..rcparams import rcParams
|
|
9
|
-
from ..sel_utils import xarray_var_iter
|
|
10
|
-
from ..stats import ess
|
|
11
|
-
from ..utils import _var_names, get_coords
|
|
12
|
-
from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
def plot_ess(
    idata,
    var_names=None,
    filter_vars=None,
    kind="local",
    relative=False,
    coords=None,
    figsize=None,
    grid=None,
    textsize=None,
    rug=False,
    rug_kind="diverging",
    n_points=20,
    extra_methods=False,
    min_ess=400,
    labeller=None,
    ax=None,
    extra_kwargs=None,
    text_kwargs=None,
    hline_kwargs=None,
    rug_kwargs=None,
    backend=None,
    backend_kwargs=None,
    show=None,
    **kwargs,
):
    r"""Generate quantile, local, or evolution ESS plots.

    The local and the quantile ESS plots are recommended for checking
    that there are enough samples for all the explored regions of the
    parameter space. Checking local and quantile ESS is particularly
    relevant when working with HDI intervals as opposed to ESS bulk,
    which is suitable for point estimates.

    Parameters
    ----------
    idata : InferenceData
        Any object that can be converted to an :class:`arviz.InferenceData` object
        Refer to documentation of :func:`arviz.convert_to_dataset` for details.
    var_names : list of str, optional
        Variables to be plotted. Prefix the variables by ``~`` when you want to exclude
        them from the plot. See :ref:`this section <common_var_names>` for usage examples.
    filter_vars : {None, "like", "regex"}, default None
        If `None` (default), interpret `var_names` as the real variables names. If "like",
        interpret `var_names` as substrings of the real variables names. If "regex",
        interpret `var_names` as regular expressions on the real variables names. See
        :ref:`this section <common_filter_vars>` for usage examples.
    kind : {"local", "quantile", "evolution"}, default "local"
        Specify the kind of plot:

        * The ``kind="local"`` argument generates the ESS' local efficiency
          for estimating small-interval probability of a desired posterior.
        * The ``kind="quantile"`` argument generates the ESS' local efficiency for
          estimating quantiles of a desired posterior.
        * The ``kind="evolution"`` argument generates the estimated ESS'
          with increasing number of iterations of a desired posterior.

    relative : bool, default False
        Show relative ess in plot ``ress = ess / N``.
    coords : dict, optional
        Coordinates of `var_names` to be plotted. Passed to :meth:`xarray.Dataset.sel`.
        ``chain`` and ``draw`` are not allowed as keys.
        See :ref:`this section <common_coords>` for usage examples.
    grid : tuple, optional
        Number of rows and columns. By default, the rows and columns are
        automatically inferred. See :ref:`this section <common_grid>` for usage examples.
    figsize : (float, float), optional
        Figure size. If ``None`` it will be defined automatically.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If ``None`` it will be autoscaled
        based on `figsize`.
    rug : bool, default False
        Add a `rug plot <https://en.wikipedia.org/wiki/Rug_plot>`_ for a specific subset of values.
    rug_kind : str, default "diverging"
        Variable in sample stats to use as rug mask. Must be a boolean variable.
    n_points : int, default 20
        Number of points for which to plot their quantile/local ess or number of subsets
        in the evolution plot.
    extra_methods : bool, default False
        Plot mean and sd ESS as horizontal lines. Not taken into account if ``kind = 'evolution'``.
    min_ess : int, default 400
        Minimum number of ESS desired. If ``relative=True`` the line is plotted at
        ``min_ess / n_samples`` for local and quantile kinds and as a curve following
        the ``min_ess / n`` dependency in evolution kind.
    labeller : Labeller, optional
        Class providing the method ``make_label_vert`` to generate the labels in the plot titles.
        Read the :ref:`label_guide` for more details and usage examples.
    ax : 2D array-like of matplotlib_axes or bokeh_figure, optional
        A 2D array of locations into which to plot the ESS. If not supplied, ArviZ will create
        its own array of plot areas (and return it).
    extra_kwargs : dict, optional
        If evolution plot, `extra_kwargs` is used to plot ess tail and differentiate it
        from ess bulk. Otherwise, passed to extra methods lines.
    text_kwargs : dict, optional
        Only taken into account when ``extra_methods=True``. kwargs passed to ax.annotate
        for extra methods lines labels. It accepts the additional
        key ``x`` to set ``xy=(text_kwargs["x"], mcse)``
    hline_kwargs : dict, optional
        kwargs passed to :func:`~matplotlib.axes.Axes.axhline` or to :class:`~bokeh.models.Span`
        depending on the backend for the horizontal minimum ESS line.
        For relative ess evolution plots the kwargs are passed to
        :func:`~matplotlib.axes.Axes.plot` or to :class:`~bokeh.plotting.figure.line`
    rug_kwargs : dict
        kwargs passed to rug plot.
    backend : {"matplotlib", "bokeh"}, default "matplotlib"
        Select plotting backend.
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`.
        For additional documentation check the plotting method of the backend.
    show : bool, optional
        Call backend show function.
    **kwargs
        Passed as-is to :meth:`mpl:matplotlib.axes.Axes.hist` or
        :meth:`mpl:matplotlib.axes.Axes.plot` function depending on the
        value of `kind`.

    Returns
    -------
    axes : matplotlib_axes or bokeh_figure

    Raises
    ------
    ValueError
        If `kind` is not one of the valid kinds, or if `coords` contains
        ``chain`` or ``draw``.

    See Also
    --------
    ess : Calculate estimate of the effective sample size.

    References
    ----------
    .. [1] Vehtari et al. (2021). Rank-normalization, folding, and
        localization: An improved Rhat for assessing convergence of
        MCMC. Bayesian analysis, 16(2):667-718.

    Examples
    --------
    Plot local ESS.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> idata = az.load_arviz_data("centered_eight")
        >>> coords = {"school": ["Choate", "Lawrenceville"]}
        >>> az.plot_ess(
        ...     idata, kind="local", var_names=["mu", "theta"], coords=coords
        ... )

    Plot ESS evolution as the number of samples increase. When the model is converging properly,
    both lines in this plot should be roughly linear.

    .. plot::
        :context: close-figs

        >>> az.plot_ess(
        ...     idata, kind="evolution", var_names=["mu", "theta"], coords=coords
        ... )

    Customize local ESS plot to look like reference paper.

    .. plot::
        :context: close-figs

        >>> az.plot_ess(
        ...     idata, kind="local", var_names=["mu"], drawstyle="steps-mid", color="k",
        ...     linestyle="-", marker=None, rug=True, rug_kwargs={"color": "r"}
        ... )

    Customize ESS evolution plot to look like reference paper.

    .. plot::
        :context: close-figs

        >>> extra_kwargs = {"color": "lightsteelblue"}
        >>> az.plot_ess(
        ...     idata, kind="evolution", var_names=["mu"],
        ...     color="royalblue", extra_kwargs=extra_kwargs
        ... )

    """
    valid_kinds = ("local", "quantile", "evolution")
    kind = kind.lower()
    if kind not in valid_kinds:
        raise ValueError(f"Invalid kind, kind must be one of {valid_kinds} not {kind}")

    if coords is None:
        coords = {}
    if "chain" in coords or "draw" in coords:
        raise ValueError("chain and draw are invalid coordinates for this kind of plot")
    if labeller is None:
        labeller = BaseLabeller()
    # The mean/sd reference lines are ignored for the evolution kind.
    extra_methods = False if kind == "evolution" else extra_methods

    data = get_coords(convert_to_dataset(idata, group="posterior"), coords)
    var_names = _var_names(var_names, data, filter_vars)
    n_draws = data.sizes["draw"]
    n_samples = n_draws * data.sizes["chain"]

    ess_tail_dataset = None
    mean_ess = None
    sd_ess = None

    if kind == "quantile":
        probs = np.linspace(1 / n_points, 1 - 1 / n_points, n_points)
        xdata = probs
        ylabel = "{} for quantiles"
        ess_dataset = xr.concat(
            [
                ess(data, var_names=var_names, relative=relative, method="quantile", prob=p)
                for p in probs
            ],
            dim="ess_dim",
        )
    elif kind == "local":
        probs = np.linspace(0, 1, n_points, endpoint=False)
        xdata = probs
        ylabel = "{} for small intervals"
        # Each point uses the probability interval [p, p + 1 / n_points].
        ess_dataset = xr.concat(
            [
                ess(
                    data,
                    var_names=var_names,
                    relative=relative,
                    method="local",
                    prob=[p, p + 1 / n_points],
                )
                for p in probs
            ],
            dim="ess_dim",
        )
    else:
        ylabel = "{}"
        xdata = np.linspace(n_samples / n_points, n_samples, n_points)
        # Number of leading draws (per chain) used at each evolution step. Clipped to
        # at least one draw so that an empty prefix is never requested when
        # n_draws < n_points.
        draw_divisions = np.maximum(
            np.linspace(n_draws // n_points, n_draws, n_points, dtype=int), 1
        )
        ess_dataset = _ess_evolution(data, var_names, relative, "bulk", draw_divisions)
        ess_tail_dataset = _ess_evolution(data, var_names, relative, "tail", draw_divisions)

    plotters = filter_plotters_list(
        list(xarray_var_iter(ess_dataset, var_names=var_names, skip_dims={"ess_dim"})), "plot_ess"
    )
    length_plotters = len(plotters)
    rows, cols = default_grid(length_plotters, grid=grid)

    if extra_methods:
        mean_ess = ess(data, var_names=var_names, method="mean", relative=relative)
        sd_ess = ess(data, var_names=var_names, method="sd", relative=relative)

    essplot_kwargs = dict(
        ax=ax,
        plotters=plotters,
        xdata=xdata,
        ess_tail_dataset=ess_tail_dataset,
        mean_ess=mean_ess,
        sd_ess=sd_ess,
        idata=idata,
        data=data,
        kind=kind,
        extra_methods=extra_methods,
        textsize=textsize,
        rows=rows,
        cols=cols,
        figsize=figsize,
        kwargs=kwargs,
        extra_kwargs=extra_kwargs,
        text_kwargs=text_kwargs,
        n_samples=n_samples,
        relative=relative,
        min_ess=min_ess,
        labeller=labeller,
        ylabel=ylabel,
        rug=rug,
        rug_kind=rug_kind,
        rug_kwargs=rug_kwargs,
        hline_kwargs=hline_kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # TODO: Add backend kwargs
    plot = get_plotting_function("plot_ess", "essplot", backend)
    ax = plot(**essplot_kwargs)
    return ax


def _ess_evolution(data, var_names, relative, method, draw_divisions):
    """Concatenate ESS values computed on growing prefixes of the draws along ``ess_dim``.

    Each prefix is selected positionally with ``isel``, so exactly ``draw_div`` draws per
    chain are used regardless of the values stored in the ``draw`` coordinate. Label-based
    ``sel`` slicing is end-inclusive and would use one draw too many. It would also assume
    contiguous integer draw labels.
    """
    return xr.concat(
        [
            ess(
                data.isel(draw=slice(None, draw_div)),
                var_names=var_names,
                relative=relative,
                method=method,
            )
            for draw_div in draw_divisions
        ],
        dim="ess_dim",
    )
|
arviz/plots/forestplot.py
DELETED
|
@@ -1,304 +0,0 @@
|
|
|
1
|
-
"""Forest plot."""
|
|
2
|
-
|
|
3
|
-
from ..data import convert_to_dataset
|
|
4
|
-
from ..labels import BaseLabeller, NoModelLabeller
|
|
5
|
-
from ..rcparams import rcParams
|
|
6
|
-
from ..utils import _var_names, get_coords
|
|
7
|
-
from .plot_utils import get_plotting_function
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
def plot_forest(
    data,
    kind="forestplot",
    model_names=None,
    var_names=None,
    filter_vars=None,
    transform=None,
    coords=None,
    combined=False,
    combine_dims=None,
    hdi_prob=None,
    rope=None,
    quartiles=True,
    ess=False,
    r_hat=False,
    colors="cycle",
    textsize=None,
    linewidth=None,
    markersize=None,
    legend=True,
    labeller=None,
    ridgeplot_alpha=None,
    ridgeplot_overlap=2,
    ridgeplot_kind="auto",
    ridgeplot_truncate=True,
    ridgeplot_quantiles=None,
    figsize=None,
    ax=None,
    backend=None,
    backend_config=None,
    backend_kwargs=None,
    show=None,
):
    r"""Forest plot to compare HDI intervals from a number of distributions.

    Generate forest or ridge plots to compare distributions from a model or list of models.
    Additionally, the function can display effective sample sizes (ess) and Rhats to visualize
    convergence diagnostics alongside the distributions.

    Parameters
    ----------
    data : InferenceData
        Any object that can be converted to an :class:`arviz.InferenceData` object
        Refer to documentation of :func:`arviz.convert_to_dataset` for details.
        A list or tuple of such objects plots several models together.
    kind : {"forestplot", "ridgeplot"}, default "forestplot"
        Specify the kind of plot:

        * The ``kind="forestplot"`` generates credible intervals, where the central points are the
          estimated posterior median, the thick lines are the central quartiles, and the thin lines
          represent the :math:`100\times(hdi\_prob)\%` highest density intervals.
        * The ``kind="ridgeplot"`` option generates density plots (kernel density estimate or
          histograms) in the same graph. Ridge plots can be configured to have different overlap,
          truncation bounds and quantile markers.

    model_names : list of str, optional
        List with names for the models in the list of data. Useful when plotting more than one
        dataset.
    var_names : list of str, optional
        Variables to be plotted. Prefix the variables by ``~`` when you want to exclude
        them from the plot. See :ref:`this section <common_var_names>` for usage examples.
    combine_dims : set_like of str, optional
        List of dimensions to reduce. Defaults to reducing only the "chain" and "draw" dimensions.
        See :ref:`this section <common_combine_dims>` for usage examples.
    filter_vars : {None, "like", "regex"}, default None
        If `None` (default), interpret `var_names` as the real variables names. If "like",
        interpret `var_names` as substrings of the real variables names. If "regex",
        interpret `var_names` as regular expressions on the real variables names. See
        :ref:`this section <common_filter_vars>` for usage examples.
    transform : callable or dict, optional
        Function to transform the data. Defaults to None, i.e., the identity function.
        A dict maps variable names to the function applied to that variable only; variables
        absent from a dataset are skipped.
    coords : dict, optional
        Coordinates of ``var_names`` to be plotted. Passed to :meth:`xarray.Dataset.sel`.
        See :ref:`this section <common_coords>` for usage examples.
    combined : bool, default False
        Flag for combining multiple chains into a single chain. If False, chains will
        be plotted separately. See :ref:`this section <common_combine>` for usage examples.
    hdi_prob : float, default 0.94
        Plots highest posterior density interval for chosen percentage of density. Must be in
        the interval (0, 1]. If ``None``, the value of ``rcParams["stats.ci_prob"]`` is used.
        See :ref:`this section <common_hdi_prob>` for usage examples.
    rope : list, tuple or dictionary of {str : tuples or lists}, optional
        A dictionary of tuples with the lower and upper values of the Region Of Practical
        Equivalence. See :ref:`this section <common_rope>` for usage examples.
    quartiles : bool, default True
        Flag for plotting the interquartile range, in addition to the ``hdi_prob`` intervals.
    r_hat : bool, default False
        Flag for plotting Split R-hat statistics. Requires 2 or more chains.
    ess : bool, default False
        Flag for plotting the effective sample size.
    colors : list or string, optional
        List with valid matplotlib colors, one color per model. Alternatively, a string can be
        passed. If the string is `cycle`, it will automatically choose a color per model from the
        matplotlib color cycle. If a single color is passed, e.g. 'k', 'C2', 'red', this color
        will be used for all models. Defaults to 'cycle'.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If `None` it will be autoscaled based
        on ``figsize``.
    linewidth : int, optional
        Line width throughout. If `None` it will be autoscaled based on ``figsize``.
    markersize : int, optional
        Markersize throughout. If `None` it will be autoscaled based on ``figsize``.
    legend : bool, optional
        Show a legend with the color encoded model information.
        Defaults to True, if there are multiple models. Always disabled for a single model.
    labeller : Labeller, optional
        Class providing the method ``make_label_vert`` to generate the labels in the plot titles.
        Read the :ref:`label_guide` for more details and usage examples.
    ridgeplot_alpha : float, optional
        Transparency for ridgeplot fill. If ``ridgeplot_alpha=0``, border is colored by model,
        otherwise a `black` outline is used.
    ridgeplot_overlap : float, default 2
        Overlap height for ridgeplots.
    ridgeplot_kind : string, optional
        By default ("auto") continuous variables are plotted using KDEs and discrete ones using
        histograms. To override this use "hist" to plot histograms and "density" for KDEs.
    ridgeplot_truncate : bool, default True
        Whether to truncate densities according to the value of ``hdi_prob``.
    ridgeplot_quantiles : list, optional
        Quantiles in ascending order used to segment the KDE. Use [.25, .5, .75] for quartiles.
    figsize : (float, float), optional
        Figure size. If `None`, it will be defined automatically.
    ax : axes, optional
        :class:`matplotlib.axes.Axes` or :class:`bokeh.plotting.Figure`.
    backend : {"matplotlib", "bokeh"}, default "matplotlib"
        Select plotting backend.
    backend_config : dict, optional
        Currently specifies the bounds to use for bokeh axes. Defaults to value set in ``rcParams``.
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`.
        For additional documentation check the plotting method of the backend.
    show : bool, optional
        Call backend show function.

    Returns
    -------
    1D ndarray of matplotlib_axes or bokeh_figures

    Raises
    ------
    ValueError
        If `hdi_prob` is outside (0, 1], or if `transform` is neither a callable nor a dict.

    See Also
    --------
    plot_posterior : Plot Posterior densities in the style of John K. Kruschke's book.
    plot_density : Generate KDE plots for continuous variables and histograms for discrete ones.
    summary : Create a data frame with summary statistics.

    Examples
    --------
    Forestplot

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> non_centered_data = az.load_arviz_data('non_centered_eight')
        >>> axes = az.plot_forest(non_centered_data,
        ...                       kind='forestplot',
        ...                       var_names=["^the"],
        ...                       filter_vars="regex",
        ...                       combined=True,
        ...                       figsize=(9, 7))
        >>> axes[0].set_title('Estimated theta for 8 schools model')

    Forestplot with multiple datasets

    .. plot::
        :context: close-figs

        >>> centered_data = az.load_arviz_data('centered_eight')
        >>> axes = az.plot_forest([non_centered_data, centered_data],
        ...                       model_names = ["non centered eight", "centered eight"],
        ...                       kind='forestplot',
        ...                       var_names=["^the"],
        ...                       filter_vars="regex",
        ...                       combined=True,
        ...                       figsize=(9, 7))
        >>> axes[0].set_title('Estimated theta for 8 schools models')

    Ridgeplot

    .. plot::
        :context: close-figs

        >>> axes = az.plot_forest(non_centered_data,
        ...                       kind='ridgeplot',
        ...                       var_names=['theta'],
        ...                       combined=True,
        ...                       ridgeplot_overlap=3,
        ...                       colors='white',
        ...                       figsize=(9, 7))
        >>> axes[0].set_title('Estimated theta for 8 schools model')

    Ridgeplot non-truncated and with quantiles

    .. plot::
        :context: close-figs

        >>> axes = az.plot_forest(non_centered_data,
        ...                       kind='ridgeplot',
        ...                       var_names=['theta'],
        ...                       combined=True,
        ...                       ridgeplot_truncate=False,
        ...                       ridgeplot_quantiles=[.25, .5, .75],
        ...                       ridgeplot_overlap=0.7,
        ...                       colors='white',
        ...                       figsize=(9, 7))
        >>> axes[0].set_title('Estimated theta for 8 schools model')
    """
    # Validate cheap scalar arguments before doing any (potentially expensive) data conversion.
    if hdi_prob is None:
        hdi_prob = rcParams["stats.ci_prob"]
    elif not 1 >= hdi_prob > 0:
        raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

    if not isinstance(data, (list, tuple)):
        data = [data]
    if len(data) == 1:
        # A legend is meaningless with a single model.
        legend = False

    if coords is None:
        coords = {}

    if combine_dims is None:
        combine_dims = set()

    if labeller is None:
        labeller = NoModelLabeller() if legend else BaseLabeller()

    # Models are processed in reverse order; per-model coords are reversed to match.
    datasets = [convert_to_dataset(datum) for datum in reversed(data)]
    if transform is not None:
        datasets = _transform_datasets(datasets, transform)
    datasets = get_coords(
        datasets, list(reversed(coords)) if isinstance(coords, (list, tuple)) else coords
    )

    var_names = _var_names(var_names, datasets, filter_vars)

    # The main panel is 3x as wide as each optional diagnostic (ess / r_hat) column.
    ncols, width_ratios = 1, [3]

    if ess:
        ncols += 1
        width_ratios.append(1)

    if r_hat:
        ncols += 1
        width_ratios.append(1)

    plot_forest_kwargs = dict(
        ax=ax,
        datasets=datasets,
        var_names=var_names,
        model_names=model_names,
        combined=combined,
        combine_dims=combine_dims,
        colors=colors,
        figsize=figsize,
        width_ratios=width_ratios,
        linewidth=linewidth,
        markersize=markersize,
        kind=kind,
        ncols=ncols,
        hdi_prob=hdi_prob,
        quartiles=quartiles,
        rope=rope,
        ridgeplot_overlap=ridgeplot_overlap,
        ridgeplot_alpha=ridgeplot_alpha,
        ridgeplot_kind=ridgeplot_kind,
        ridgeplot_truncate=ridgeplot_truncate,
        ridgeplot_quantiles=ridgeplot_quantiles,
        textsize=textsize,
        legend=legend,
        labeller=labeller,
        ess=ess,
        r_hat=r_hat,
        backend_kwargs=backend_kwargs,
        backend_config=backend_config,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # TODO: Add backend kwargs
    plot = get_plotting_function("plot_forest", "forestplot", backend)
    axes = plot(**plot_forest_kwargs)
    return axes


def _transform_datasets(datasets, transform):
    """Apply `transform` to every dataset and return the list of transformed datasets.

    `transform` is either a callable applied to each whole dataset, or a
    ``{var_name: callable}`` dict applied per variable to a copy of each dataset.
    Variables missing from a given dataset are skipped.

    Raises
    ------
    ValueError
        If `transform` is neither callable nor a dict.
    """
    if callable(transform):
        return [transform(dataset) for dataset in datasets]
    if isinstance(transform, dict):
        transformed_datasets = []
        for dataset in datasets:
            # Copy so the caller's (converted) dataset is not mutated in place.
            new_dataset = dataset.copy()
            for var_name, func in transform.items():
                if var_name in new_dataset:
                    new_dataset[var_name] = func(new_dataset[var_name])
            transformed_datasets.append(new_dataset)
        return transformed_datasets
    raise ValueError("transform must be either a callable or a dict {var_name: callable}")
|