arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/plots/rankplot.py
DELETED
|
@@ -1,232 +0,0 @@
|
|
|
1
|
-
"""Histograms of ranked posterior draws, plotted for each chain."""
|
|
2
|
-
|
|
3
|
-
from itertools import cycle
|
|
4
|
-
|
|
5
|
-
import matplotlib.pyplot as plt
|
|
6
|
-
|
|
7
|
-
from ..data import convert_to_dataset
|
|
8
|
-
from ..labels import BaseLabeller
|
|
9
|
-
from ..sel_utils import xarray_var_iter
|
|
10
|
-
from ..rcparams import rcParams
|
|
11
|
-
from ..stats.density_utils import _sturges_formula
|
|
12
|
-
from ..utils import _var_names
|
|
13
|
-
from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def plot_rank(
    data,
    var_names=None,
    filter_vars=None,
    transform=None,
    coords=None,
    bins=None,
    kind="bars",
    colors="cycle",
    ref_line=True,
    labels=True,
    labeller=None,
    grid=None,
    figsize=None,
    ax=None,
    backend=None,
    ref_line_kwargs=None,
    bar_kwargs=None,
    vlines_kwargs=None,
    marker_vlines_kwargs=None,
    backend_kwargs=None,
    show=None,
):
    """Plot rank order statistics of chains.

    From the paper: Rank plots are histograms of the ranked posterior draws (ranked over all
    chains) plotted separately for each chain.
    If all of the chains are targeting the same posterior, we expect the ranks in each chain to be
    uniform, whereas if one chain has a different location or scale parameter, this will be
    reflected in the deviation from uniformity. If rank plots of all chains look similar, this
    indicates good mixing of the chains.

    This plot was introduced by Aki Vehtari, Andrew Gelman, Daniel Simpson, Bob Carpenter,
    Paul-Christian Bürkner (2021): Rank-normalization, folding, and localization:
    An improved R-hat for assessing convergence of MCMC. Bayesian analysis, 16(2):667-718.


    Parameters
    ----------
    data: obj
        Any object that can be converted to an :class:`arviz.InferenceData` object.
        Refer to documentation of :func:`arviz.convert_to_dataset` for details.
        Only the ``posterior`` group is used.
    var_names: string or list of variable names
        Variables to be plotted. Prefix the variables by ``~`` when you want to exclude
        them from the plot.
    filter_vars: {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        ``pandas.filter``.
    transform: callable
        Function applied to ``data`` before it is converted to a dataset
        (defaults to None, i.e. the identity function).
    coords: mapping, optional
        Coordinates of var_names to be plotted. Passed to :meth:`xarray.Dataset.sel`
    bins: None or passed to np.histogram
        Binning strategy used for histogram. By default uses twice the result of Sturges' formula.
        See :func:`numpy.histogram` documentation for other available arguments.
    kind: string
        If bars (defaults), ranks are represented as stacked histograms (one per chain). If vlines
        ranks are represented as vertical lines above or below ``ref_line``.
    colors: string or list of strings
        List with valid matplotlib colors, one color per chain. Alternatively, a single string
        can be passed. If the string is ``"cycle"``, one color per chain is chosen automatically
        from matplotlib's property cycle. If any other string is passed, e.g. 'k', 'C2' or
        'red', that color is used for all chains. Defaults to ``"cycle"``.
    ref_line: boolean
        Whether to include a dashed line showing where a uniform distribution would lie
    labels: bool
        whether to plot or not the x and y labels, defaults to True
    labeller : labeller instance, optional
        Class providing the method ``make_label_vert`` to generate the labels in the plot titles.
        Defaults to :class:`arviz.labels.BaseLabeller`.
        Read the :ref:`label_guide` for more details and usage examples.
    grid : tuple
        Number of rows and columns. Defaults to None, the rows and columns are
        automatically inferred.
    figsize: tuple
        Figure size. If None it will be defined automatically.
    ax: numpy array-like of matplotlib axes or bokeh figures, optional
        A 2D array of locations into which to plot the densities. If not supplied, ArviZ will create
        its own array of plot areas (and return it).
    backend: str, optional
        Select plotting backend {"matplotlib","bokeh"}. Defaults to ``rcParams["plot.backend"]``.
        The value is case-insensitive.
    ref_line_kwargs : dict, optional
        Reference line keyword arguments, passed to :meth:`mpl:matplotlib.axes.Axes.axhline` or
        :class:`bokeh:bokeh.models.Span`.
    bar_kwargs : dict, optional
        Bars keyword arguments, passed to :meth:`mpl:matplotlib.axes.Axes.bar` or
        :meth:`bokeh:bokeh.plotting.Figure.vbar`.
    vlines_kwargs : dict, optional
        Vlines keyword arguments, passed to :meth:`mpl:matplotlib.axes.Axes.vlines` or
        :meth:`bokeh:bokeh.plotting.Figure.multi_line`.
    marker_vlines_kwargs : dict, optional
        Marker for the vlines keyword arguments, passed to :meth:`mpl:matplotlib.axes.Axes.plot` or
        :meth:`bokeh:bokeh.plotting.Figure.circle`.
    backend_kwargs: dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or
        :func:`bokeh.plotting.figure`. For additional documentation
        check the plotting method of the backend.
    show: bool, optional
        Call backend show function.

    Returns
    -------
    axes: matplotlib axes or bokeh figures

    See Also
    --------
    plot_trace : Plot distribution (histogram or kernel density estimates) and
                 sampled values or rank plot.

    Examples
    --------
    Show a default rank plot

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('centered_eight')
        >>> az.plot_rank(data)

    Recreate Figure 13 from the arxiv preprint

    .. plot::
        :context: close-figs

        >>> data = az.load_arviz_data('centered_eight')
        >>> az.plot_rank(data, var_names='tau')

    Use vlines to compare results for centered vs noncentered models

    .. plot::
        :context: close-figs

        >>> import matplotlib.pyplot as plt
        >>> centered_data = az.load_arviz_data('centered_eight')
        >>> noncentered_data = az.load_arviz_data('non_centered_eight')
        >>> _, ax = plt.subplots(1, 2, figsize=(12, 3))
        >>> az.plot_rank(centered_data, var_names="mu", kind='vlines', ax=ax[0])
        >>> az.plot_rank(noncentered_data, var_names="mu", kind='vlines', ax=ax[1])

    Change the aesthetics using kwargs

    .. plot::
        :context: close-figs

        >>> az.plot_rank(noncentered_data, var_names="mu", kind="vlines",
        ...              vlines_kwargs={'lw':0}, marker_vlines_kwargs={'lw':3});
    """
    if transform is not None:
        data = transform(data)
    posterior_data = convert_to_dataset(data, group="posterior")
    if coords is not None:
        posterior_data = posterior_data.sel(**coords)
    var_names = _var_names(var_names, posterior_data, filter_vars)
    plotters = filter_plotters_list(
        list(
            xarray_var_iter(
                posterior_data,
                var_names=var_names,
                combined=True,
                dim_order=["chain", "draw"],
            )
        ),
        "plot_rank",
    )
    length_plotters = len(plotters)

    if bins is None:
        # Default binning: twice the bin count given by Sturges' formula.
        bins = _sturges_formula(posterior_data, mult=2)

    if labeller is None:
        labeller = BaseLabeller()

    rows, cols = default_grid(length_plotters, grid=grid)

    chains = len(posterior_data.chain)
    # Only compare against "cycle" when ``colors`` really is a string: comparing a
    # numpy array of colors with ``==`` is elementwise, and using the result in an
    # ``if`` would raise "truth value of an array is ambiguous".
    if isinstance(colors, str):
        if colors == "cycle":
            # Take one color per chain from matplotlib's property cycle, wrapping
            # around when there are more chains than cycle colors.
            colors = [
                prop
                for _, prop in zip(
                    range(chains), cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"])
                )
            ]
        else:
            # A single color string is broadcast to every chain.
            colors = [colors] * chains

    rankplot_kwargs = dict(
        axes=ax,
        length_plotters=length_plotters,
        rows=rows,
        cols=cols,
        figsize=figsize,
        plotters=plotters,
        bins=bins,
        kind=kind,
        colors=colors,
        ref_line=ref_line,
        labels=labels,
        labeller=labeller,
        ref_line_kwargs=ref_line_kwargs,
        bar_kwargs=bar_kwargs,
        vlines_kwargs=vlines_kwargs,
        marker_vlines_kwargs=marker_vlines_kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # Dispatch to the backend-specific implementation (matplotlib or bokeh).
    plot = get_plotting_function("plot_rank", "rankplot", backend)
    axes = plot(**rankplot_kwargs)
    return axes
|
arviz/plots/separationplot.py
DELETED
|
@@ -1,167 +0,0 @@
|
|
|
1
|
-
"""Separation plot for discrete outcome models."""
|
|
2
|
-
|
|
3
|
-
import warnings
|
|
4
|
-
|
|
5
|
-
import numpy as np
|
|
6
|
-
import xarray as xr
|
|
7
|
-
|
|
8
|
-
from ..data import InferenceData
|
|
9
|
-
from ..rcparams import rcParams
|
|
10
|
-
from .plot_utils import get_plotting_function
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
def plot_separation(
    idata=None,
    y=None,
    y_hat=None,
    y_hat_line=False,
    expected_events=False,
    figsize=None,
    textsize=None,
    color="C0",
    legend=True,
    ax=None,
    plot_kwargs=None,
    y_hat_line_kwargs=None,
    exp_events_kwargs=None,
    backend=None,
    backend_kwargs=None,
    show=None,
):
    """Separation plot for binary outcome models.

    Model predictions are sorted and plotted using a color code according to
    the observed data.

    Parameters
    ----------
    idata : InferenceData
        :class:`arviz.InferenceData` object. Required whenever ``y`` or ``y_hat`` is
        given as a str; otherwise it may be None.
    y : array, DataArray or str
        Observed data. If str, ``idata`` must be present and ``y`` is read from its
        ``observed_data`` group.
    y_hat : array, DataArray or str
        Model predictions for ``y``. It should have the same length as ``y``; a length
        mismatch only emits a ``UserWarning``. If str, the named variable is read from the
        ``posterior_predictive`` group of ``idata`` and averaged over the ``chain`` and
        ``draw`` dimensions. If None and ``y`` is a str, the same variable name as ``y``
        is used. Arrays and DataArrays are used as given.
    y_hat_line : bool, optional
        Plot the sorted ``y_hat`` predictions.
    expected_events : bool, optional
        Plot the total number of expected events.
    figsize : figure size tuple, optional
        Figure size. If None, the size is left to the plotting backend.
    textsize: int, optional
        Text size for labels. If None it will be autoscaled based on ``figsize``.
    color : str, optional
        Color to assign to the positive class. The negative class will be plotted using the
        same color and an `alpha=0.3` transparency.
    legend : bool, optional
        Show the legend of the figure.
    ax: axes, optional
        Matplotlib axes or bokeh figures.
    plot_kwargs : dict, optional
        Additional keywords passed to :meth:`mpl:matplotlib.axes.Axes.bar` or
        :meth:`bokeh:bokeh.plotting.Figure.vbar` for separation plot.
    y_hat_line_kwargs : dict, optional
        Additional keywords passed to ax.plot for ``y_hat`` line.
    exp_events_kwargs : dict, optional
        Additional keywords passed to ax.scatter for ``expected_events`` marker.
    backend: str, optional
        Select plotting backend {"matplotlib","bokeh"}. Defaults to ``rcParams["plot.backend"]``.
        The value is case-insensitive.
    backend_kwargs: dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or
        :func:`bokeh.plotting.figure`.
    show : bool, optional
        Call backend show function.

    Returns
    -------
    axes : matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If ``idata`` is not an InferenceData or None, if ``y``/``y_hat`` are not arrays
        or DataArrays when ``idata`` is None, or if their types are otherwise invalid.

    See Also
    --------
    plot_ppc : Plot for posterior/prior predictive checks.

    References
    ----------
    .. [1] Greenhill, B. *et al.*, The Separation Plot: A New Visual Method
       for Evaluating the Fit of Binary Models, *American Journal of
       Political Science*, (2011) see https://doi.org/10.1111/j.1540-5907.2011.00525.x

    Examples
    --------
    Separation plot for a logistic regression model.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> idata = az.load_arviz_data('classification10d')
        >>> az.plot_separation(idata=idata, y='outcome', y_hat='outcome', figsize=(8, 1))

    """
    # Legend label for the predictions; replaced by the variable name when y_hat
    # is looked up in ``idata``.
    label_y_hat = "y_hat"
    if idata is not None and not isinstance(idata, InferenceData):
        raise ValueError("idata must be of type InferenceData or None")

    if idata is None:
        # Without idata there is nothing to look names up in, so both inputs
        # must already be array-like.
        if not all(isinstance(arg, (np.ndarray, xr.DataArray)) for arg in (y, y_hat)):
            raise ValueError(
                "y and y_hat must be array or DataArray when idata is None "
                f"but they are of types {[type(arg) for arg in (y, y_hat)]}"
            )

    else:
        if y_hat is None and isinstance(y, str):
            # Reuse the observed variable's name for the posterior predictive lookup.
            label_y_hat = y
            y_hat = y
        elif y_hat is None:
            raise ValueError("y_hat cannot be None if y is not a str")

        if isinstance(y, str):
            y = idata.observed_data[y].values
        elif not isinstance(y, (np.ndarray, xr.DataArray)):
            raise ValueError(f"y must be of types array, DataArray or str, not {type(y)}")

        if isinstance(y_hat, str):
            label_y_hat = y_hat
            # Collapse the posterior predictive samples to their mean per observation.
            y_hat = idata.posterior_predictive[y_hat].mean(dim=("chain", "draw")).values
        elif not isinstance(y_hat, (np.ndarray, xr.DataArray)):
            raise ValueError(f"y_hat must be of types array, DataArray or str, not {type(y_hat)}")

    # A length mismatch is reported but deliberately not fatal.
    if len(y) != len(y_hat):
        warnings.warn(
            "y and y_hat must be the same length",
            UserWarning,
        )

    # One evenly spaced bar location per prediction on [0, 1].
    locs = np.linspace(0, 1, len(y_hat))
    # np.diff of fewer than two locations is empty and its mean is NaN (with a
    # RuntimeWarning), so fall back to a full-width bar in that degenerate case.
    width = np.diff(locs).mean() if len(locs) > 1 else 1.0

    separation_kwargs = dict(
        y=y,
        y_hat=y_hat,
        y_hat_line=y_hat_line,
        label_y_hat=label_y_hat,
        expected_events=expected_events,
        figsize=figsize,
        textsize=textsize,
        color=color,
        legend=legend,
        locs=locs,
        width=width,
        ax=ax,
        plot_kwargs=plot_kwargs,
        y_hat_line_kwargs=y_hat_line_kwargs,
        exp_events_kwargs=exp_events_kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # Dispatch to the backend-specific implementation (matplotlib or bokeh).
    plot = get_plotting_function("plot_separation", "separationplot", backend)
    axes = plot(**separation_kwargs)

    return axes
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
axes.prop_cycle: cycler('color', ['0045b9ff', '0045b999', '0045b966', '0045b933', '000000'])
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
axes.prop_cycle: cycler('color', ['bb5100ff', 'ec7a00ff', 'ec7a00aa', 'ec7a0066', '000000'])
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
axes.prop_cycle: cycler('color', ['008182ff', '008182aa', '00818266', '00818233', '000000'])
|
|
@@ -1,40 +0,0 @@
|
|
|
1
|
-
# This style is based on Seaborn-darkgrid parameters, the main differences are
|
|
2
|
-
# Default matplotlib font (no problem using Greek letters!)
|
|
3
|
-
# Larger font size for several elements
|
|
4
|
-
# Colorblind friendly color cycle
|
|
5
|
-
figure.figsize: 7.2, 4.8
|
|
6
|
-
figure.dpi: 100.0
|
|
7
|
-
figure.facecolor: white
|
|
8
|
-
figure.constrained_layout.use: True
|
|
9
|
-
text.color: .15
|
|
10
|
-
axes.labelcolor: .15
|
|
11
|
-
legend.frameon: False
|
|
12
|
-
legend.numpoints: 1
|
|
13
|
-
legend.scatterpoints: 1
|
|
14
|
-
xtick.direction: out
|
|
15
|
-
ytick.direction: out
|
|
16
|
-
xtick.color: .15
|
|
17
|
-
ytick.color: .15
|
|
18
|
-
axes.axisbelow: True
|
|
19
|
-
grid.linestyle: -
|
|
20
|
-
lines.solid_capstyle: round
|
|
21
|
-
|
|
22
|
-
axes.labelsize: 15
|
|
23
|
-
axes.titlesize: 16
|
|
24
|
-
xtick.labelsize: 14
|
|
25
|
-
ytick.labelsize: 14
|
|
26
|
-
legend.fontsize: 14
|
|
27
|
-
|
|
28
|
-
axes.grid: True
|
|
29
|
-
axes.facecolor: eeeeee
|
|
30
|
-
axes.edgecolor: white
|
|
31
|
-
axes.linewidth: 0
|
|
32
|
-
grid.color: white
|
|
33
|
-
image.cmap: viridis
|
|
34
|
-
xtick.major.size: 0
|
|
35
|
-
ytick.major.size: 0
|
|
36
|
-
xtick.minor.size: 0
|
|
37
|
-
ytick.minor.size: 0
|
|
38
|
-
|
|
39
|
-
# color-blind friendly cycle designed using https://colorcyclepicker.mpetroff.net/
|
|
40
|
-
axes.prop_cycle: cycler('color', ['2a2eec', 'fa7c17', '328c06', 'c10c90', '933708', '65e5f3', 'e6e135', '1ccd6a', 'bd8ad5', 'b16b57'])
|
|
@@ -1,88 +0,0 @@
|
|
|
1
|
-
## ***************************************************************************
|
|
2
|
-
## * FIGURE *
|
|
3
|
-
## ***************************************************************************
|
|
4
|
-
|
|
5
|
-
figure.facecolor: white # white figure background outside the axes
|
|
6
|
-
figure.edgecolor: None # no figure edge color
|
|
7
|
-
figure.titleweight: bold # weight of the figure title
|
|
8
|
-
figure.titlesize: 18
|
|
9
|
-
|
|
10
|
-
figure.figsize: 11.5, 5
|
|
11
|
-
figure.dpi: 300.0
|
|
12
|
-
figure.constrained_layout.use: True
|
|
13
|
-
|
|
14
|
-
## ***************************************************************************
|
|
15
|
-
## * FONT *
|
|
16
|
-
## ***************************************************************************
|
|
17
|
-
|
|
18
|
-
font.style: normal
|
|
19
|
-
font.variant: normal
|
|
20
|
-
font.weight: normal
|
|
21
|
-
font.stretch: normal
|
|
22
|
-
|
|
23
|
-
text.color: .15
|
|
24
|
-
|
|
25
|
-
## ***************************************************************************
|
|
26
|
-
## * AXES *
|
|
27
|
-
## ***************************************************************************
|
|
28
|
-
|
|
29
|
-
axes.facecolor: white
|
|
30
|
-
axes.edgecolor: .33 # axes edge color
|
|
31
|
-
axes.linewidth: 0.8 # edge line width
|
|
32
|
-
|
|
33
|
-
axes.grid: False # do not show grid
|
|
34
|
-
# axes.grid.axis: y # which axis the grid should apply to
|
|
35
|
-
axes.grid.which: major # grid lines at {major, minor, both} ticks
|
|
36
|
-
axes.axisbelow: True # keep grid layer in the back
|
|
37
|
-
|
|
38
|
-
grid.color: .8 # grid color
|
|
39
|
-
grid.linestyle: - # solid
|
|
40
|
-
grid.linewidth: 0.8 # in points
|
|
41
|
-
grid.alpha: 1.0 # transparency, between 0.0 and 1.0
|
|
42
|
-
|
|
43
|
-
lines.solid_capstyle: round
|
|
44
|
-
|
|
45
|
-
axes.spines.right: False # do not show right spine
|
|
46
|
-
axes.spines.top: False # do not show top spine
|
|
47
|
-
|
|
48
|
-
axes.titlesize: 16
|
|
49
|
-
axes.titleweight: bold # font weight of title
|
|
50
|
-
|
|
51
|
-
axes.labelsize: 14
|
|
52
|
-
axes.labelcolor: .15
|
|
53
|
-
axes.labelweight: normal # weight of the x and y labels
|
|
54
|
-
|
|
55
|
-
# color-blind friendly cycle designed using https://colorcyclepicker.mpetroff.net/
|
|
56
|
-
# see preview and check for colorblindness here https://coolors.co/107591-00c0bf-f69a48-fdcd49-8da798-a19368-525252-a6761d-7035b7-cf166e
|
|
57
|
-
axes.prop_cycle: cycler(color=['107591','00c0bf','f69a48','fdcd49','8da798','a19368','525252','a6761d','7035b7','cf166e'])
|
|
58
|
-
|
|
59
|
-
image.cmap: viridis
|
|
60
|
-
|
|
61
|
-
## ***************************************************************************
|
|
62
|
-
## * TICKS *
|
|
63
|
-
## ***************************************************************************
|
|
64
|
-
|
|
65
|
-
xtick.labelsize: 14
|
|
66
|
-
xtick.color: .15
|
|
67
|
-
xtick.top: False
|
|
68
|
-
xtick.bottom: True
|
|
69
|
-
xtick.direction: out
|
|
70
|
-
|
|
71
|
-
ytick.labelsize: 14
|
|
72
|
-
ytick.color: .15
|
|
73
|
-
ytick.left: True
|
|
74
|
-
ytick.right: False
|
|
75
|
-
ytick.direction: out
|
|
76
|
-
|
|
77
|
-
## ***************************************************************************
|
|
78
|
-
## * LEGEND *
|
|
79
|
-
## ***************************************************************************
|
|
80
|
-
|
|
81
|
-
legend.framealpha: 0.5
|
|
82
|
-
legend.frameon: False # do not draw on background patch
|
|
83
|
-
legend.fancybox: False # do not round corners
|
|
84
|
-
|
|
85
|
-
legend.numpoints: 1
|
|
86
|
-
legend.scatterpoints: 1
|
|
87
|
-
|
|
88
|
-
legend.fontsize: 14
|
|
@@ -1,88 +0,0 @@
|
|
|
1
|
-
## ***************************************************************************
|
|
2
|
-
## * FIGURE *
|
|
3
|
-
## ***************************************************************************
|
|
4
|
-
|
|
5
|
-
figure.facecolor: white # broken white outside box
|
|
6
|
-
figure.edgecolor: None # broken white outside box
|
|
7
|
-
figure.titleweight: bold # weight of the figure title
|
|
8
|
-
figure.titlesize: 18
|
|
9
|
-
|
|
10
|
-
figure.figsize: 11.5, 5
|
|
11
|
-
figure.dpi: 300.0
|
|
12
|
-
figure.constrained_layout.use: True
|
|
13
|
-
|
|
14
|
-
## ***************************************************************************
|
|
15
|
-
## * FONT *
|
|
16
|
-
## ***************************************************************************
|
|
17
|
-
|
|
18
|
-
font.style: normal
|
|
19
|
-
font.variant: normal
|
|
20
|
-
font.weight: normal
|
|
21
|
-
font.stretch: normal
|
|
22
|
-
|
|
23
|
-
text.color: .15
|
|
24
|
-
|
|
25
|
-
## ***************************************************************************
|
|
26
|
-
## * AXES *
|
|
27
|
-
## ***************************************************************************
|
|
28
|
-
|
|
29
|
-
axes.facecolor: white
|
|
30
|
-
axes.edgecolor: .33 # axes edge color
|
|
31
|
-
axes.linewidth: 0.8 # edge line width
|
|
32
|
-
|
|
33
|
-
axes.grid: True # show grid
|
|
34
|
-
# axes.grid.axis: y # which axis the grid should apply to
|
|
35
|
-
axes.grid.which: major # grid lines at {major, minor, both} ticks
|
|
36
|
-
axes.axisbelow: True # keep grid layer in the back
|
|
37
|
-
|
|
38
|
-
grid.color: .8 # grid color
|
|
39
|
-
grid.linestyle: - # solid
|
|
40
|
-
grid.linewidth: 0.8 # in points
|
|
41
|
-
grid.alpha: 1.0 # transparency, between 0.0 and 1.0
|
|
42
|
-
|
|
43
|
-
lines.solid_capstyle: round
|
|
44
|
-
|
|
45
|
-
axes.spines.right: False # do not show right spine
|
|
46
|
-
axes.spines.top: False # do not show top spine
|
|
47
|
-
|
|
48
|
-
axes.titlesize: 16
|
|
49
|
-
axes.titleweight: bold # font weight of title
|
|
50
|
-
|
|
51
|
-
axes.labelsize: 14
|
|
52
|
-
axes.labelcolor: .15
|
|
53
|
-
axes.labelweight: normal # weight of the x and y labels
|
|
54
|
-
|
|
55
|
-
# color-blind friendly cycle designed using https://colorcyclepicker.mpetroff.net/
|
|
56
|
-
# see preview and check for colorblindness here https://coolors.co/107591-00c0bf-f69a48-fdcd49-8da798-a19368-525252-a6761d-7035b7-cf166e
|
|
57
|
-
axes.prop_cycle: cycler(color=['107591','00c0bf','f69a48','fdcd49','8da798','a19368','525252','a6761d','7035b7','cf166e'])
|
|
58
|
-
|
|
59
|
-
image.cmap: viridis
|
|
60
|
-
|
|
61
|
-
## ***************************************************************************
|
|
62
|
-
## * TICKS *
|
|
63
|
-
## ***************************************************************************
|
|
64
|
-
|
|
65
|
-
xtick.labelsize: 14
|
|
66
|
-
xtick.color: .15
|
|
67
|
-
xtick.top: False
|
|
68
|
-
xtick.bottom: True
|
|
69
|
-
xtick.direction: out
|
|
70
|
-
|
|
71
|
-
ytick.labelsize: 14
|
|
72
|
-
ytick.color: .15
|
|
73
|
-
ytick.left: True
|
|
74
|
-
ytick.right: False
|
|
75
|
-
ytick.direction: out
|
|
76
|
-
|
|
77
|
-
## ***************************************************************************
|
|
78
|
-
## * LEGEND *
|
|
79
|
-
## ***************************************************************************
|
|
80
|
-
|
|
81
|
-
legend.framealpha: 0.5
|
|
82
|
-
legend.frameon: False # do not draw on background patch
|
|
83
|
-
legend.fancybox: False # do not round corners
|
|
84
|
-
|
|
85
|
-
legend.numpoints: 1
|
|
86
|
-
legend.scatterpoints: 1
|
|
87
|
-
|
|
88
|
-
legend.fontsize: 14
|
|
@@ -1,41 +0,0 @@
|
|
|
1
|
-
# This style is based on Seaborn-white parameters, the main differences are
|
|
2
|
-
# Default matplotlib font (no problem using Greek letters!)
|
|
3
|
-
# Larger font size for several elements
|
|
4
|
-
# Colorblind friendly color cycle
|
|
5
|
-
figure.figsize: 7.2, 4.8
|
|
6
|
-
figure.dpi: 100.0
|
|
7
|
-
figure.facecolor: white
|
|
8
|
-
figure.constrained_layout.use: True
|
|
9
|
-
text.color: .15
|
|
10
|
-
axes.labelcolor: .15
|
|
11
|
-
legend.frameon: False
|
|
12
|
-
legend.numpoints: 1
|
|
13
|
-
legend.scatterpoints: 1
|
|
14
|
-
xtick.direction: out
|
|
15
|
-
ytick.direction: out
|
|
16
|
-
xtick.color: .15
|
|
17
|
-
ytick.color: .15
|
|
18
|
-
axes.axisbelow: True
|
|
19
|
-
lines.solid_capstyle: round
|
|
20
|
-
|
|
21
|
-
axes.labelsize: 15
|
|
22
|
-
axes.titlesize: 16
|
|
23
|
-
xtick.labelsize: 14
|
|
24
|
-
ytick.labelsize: 14
|
|
25
|
-
legend.fontsize: 14
|
|
26
|
-
|
|
27
|
-
axes.grid: False
|
|
28
|
-
axes.facecolor: white
|
|
29
|
-
axes.edgecolor: 0
|
|
30
|
-
axes.linewidth: 1
|
|
31
|
-
axes.spines.top: False
|
|
32
|
-
axes.spines.right: False
|
|
33
|
-
image.cmap: cet_gray # perceptually uniform gray scale from colorcet (linear_grey_10_95_c0)
|
|
34
|
-
xtick.major.size: 0
|
|
35
|
-
ytick.major.size: 0
|
|
36
|
-
xtick.minor.size: 0
|
|
37
|
-
ytick.minor.size: 0
|
|
38
|
-
|
|
39
|
-
# First 4 colors are from colorcet and the last one is the "ArviZ-blue"
|
|
40
|
-
# [to_hex(_linear_grey_0_100_c0[i]) for i in np.linspace(0, 195, 4).astype(int)]
|
|
41
|
-
axes.prop_cycle: cycler('color', ["000000", "4a4a4a", "7e7f7f", "b8b8b8", "2a2eec"])
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
axes.prop_cycle: cycler('color', ['259516ff', '259516aa', '25951666', '25951633', '000000'])
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
axes.prop_cycle: cycler('color', ['fd5800ff', 'fd5800aa', 'fd580066', 'fd580033', '000000'])
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
axes.prop_cycle: cycler('color', ['0d0887', '8e0ca4', 'de6164', 'fdc627', '000000'])
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
axes.prop_cycle: cycler('color', ['820076ff', '820076aa', '82007666', '82007633', '000000'])
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
axes.prop_cycle: cycler('color', ['bf0700ff', 'bf0700aa', 'bf070066', 'bf070033', '000000'])
|