arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/plots/posteriorplot.py
DELETED
|
@@ -1,298 +0,0 @@
|
|
|
1
|
-
"""Plot posterior densities."""
|
|
2
|
-
|
|
3
|
-
from ..data import convert_to_dataset
|
|
4
|
-
from ..labels import BaseLabeller
|
|
5
|
-
from ..sel_utils import xarray_var_iter
|
|
6
|
-
from ..utils import _var_names, get_coords
|
|
7
|
-
from ..rcparams import rcParams
|
|
8
|
-
from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def plot_posterior(
    data,
    var_names=None,
    filter_vars=None,
    combine_dims=None,
    transform=None,
    coords=None,
    grid=None,
    figsize=None,
    textsize=None,
    hdi_prob=None,
    multimodal=False,
    skipna=False,
    round_to=None,
    point_estimate="auto",
    group="posterior",
    rope=None,
    ref_val=None,
    rope_color="C2",
    ref_val_color="C1",
    kind=None,
    bw="default",
    circular=False,
    bins=None,
    labeller=None,
    ax=None,
    backend=None,
    backend_kwargs=None,
    show=None,
    **kwargs
):
    r"""Plot Posterior densities in the style of John K. Kruschke's book.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an :class:`arviz.InferenceData` object.
        Refer to the documentation of :func:`arviz.convert_to_dataset` for details.
    var_names : list of str, optional
        Variables to be plotted. Prefix the variables with ``~`` when you want to
        exclude them from the plot.
    filter_vars : {None, "like", "regex"}, optional, default None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        ``pandas.filter``.
    combine_dims : set_like of str, optional
        List of dimensions to reduce. Defaults to reducing only the "chain" and "draw" dimensions.
        See the :ref:`this section <common_combine_dims>` for usage examples.
    transform : callable, optional
        Function applied to the converted dataset before plotting
        (defaults to None, i.e. the identity function).
    coords : mapping, optional
        Coordinates of var_names to be plotted. Passed to :meth:`xarray.Dataset.sel`.
    grid : tuple, optional
        Number of rows and columns. Defaults to None, the rows and columns are
        automatically inferred.
    figsize : tuple, optional
        Figure size. If None it will be defined automatically.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If None it will be autoscaled based
        on ``figsize``.
    hdi_prob : float or "hide", optional
        Probability mass of the highest density interval to plot; must be in the
        interval (0, 1]. Use "hide" to hide the highest density interval.
        Defaults to ``rcParams["stats.ci_prob"]``.
    multimodal : bool, default False
        If True it may compute more than one credible interval if the distribution is
        multimodal and the modes are well separated.
    skipna : bool, default False
        If True ignores nan values when computing the hdi and point estimates.
    round_to : int, optional
        Controls formatting of floats. Defaults to 2 or the integer part, whichever is bigger.
    point_estimate : {"mean", "median", "mode", None, "auto"}, optional
        Plot point estimate per variable. Defaults to "auto", i.e. it falls back to
        ``rcParams["plot.point_estimate"]``.
    group : str, optional
        Specifies which InferenceData group should be plotted. Defaults to "posterior".
    rope : list, tuple or dictionary of {str: tuples or lists}, optional
        Lower and upper values of the Region Of Practical Equivalence, either as a single
        pair applied to every plot or as a dictionary per variable/coordinate.
        See :ref:`this section <common_rope>` for usage examples.
    ref_val : float, list or dictionary, optional
        Display the percentage below and above the values in ref_val. Must be None (default),
        a constant, a list or a dictionary (see the examples below). If a list is provided,
        its length should match the number of plots (one per variable and coordinate
        combination).
    rope_color : str, optional
        Specifies the color of ROPE and displayed percentage within ROPE.
    ref_val_color : str, optional
        Specifies the color of the displayed percentage.
    kind : {"kde", "hist"}, optional
        Type of plot to display. For discrete variables this argument is ignored and
        a histogram is always used. Defaults to ``rcParams["plot.density_kind"]``.
    bw : float or str, optional
        If numeric, indicates the bandwidth and must be positive.
        If str, indicates the method to estimate the bandwidth and must be
        one of "scott", "silverman", "isj" or "experimental" when `circular` is False
        and "taylor" (for now) when `circular` is True.
        Defaults to "default" which means "experimental" when variable is not circular
        and "taylor" when it is. Only works if `kind == kde`.
    circular : bool, optional
        If True, it interprets the values passed are from a circular variable measured in radians
        and a circular KDE is used. Only valid for 1D KDE. Defaults to False.
        Only works if `kind == kde`.
    bins : integer or sequence or "auto", optional
        Controls the number of bins, accepts the same keywords :func:`matplotlib.pyplot.hist` does.
        Only works if `kind == hist`. If None (default) it will use `auto` for continuous variables
        and `range(xmin, xmax + 1)` for discrete variables.
    labeller : labeller instance, optional
        Class providing the method ``make_label_vert`` to generate the labels in the plot titles.
        Defaults to :class:`arviz.labels.BaseLabeller`.
        Read the :ref:`label_guide` for more details and usage examples.
    ax : numpy array-like of matplotlib axes or bokeh figures, optional
        A 2D array of locations into which to plot the densities. If not supplied, ArviZ will
        create its own array of plot areas (and return it).
    backend : {"matplotlib", "bokeh"}, optional
        Select plotting backend. Defaults to ``rcParams["plot.backend"]``.
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or :func:`bokeh.plotting.figure`.
    show : bool, optional
        Call backend show function.
    **kwargs
        Passed as-is to :func:`matplotlib.pyplot.hist` or :func:`matplotlib.pyplot.plot` function
        depending on the value of `kind`.

    Returns
    -------
    axes : matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If ``hdi_prob`` is numeric and outside (0, 1], or ``point_estimate`` is not one of
        "mean", "median", "mode", None or "auto".
    TypeError
        If ``hdi_prob`` is a string other than "hide".

    See Also
    --------
    plot_dist : Plot distribution as histogram or kernel density estimates.
    plot_density : Generate KDE plots for continuous variables and histograms for discrete ones.
    plot_forest : Forest plot to compare HDI intervals from a number of distributions.

    Examples
    --------
    Show a default kernel density plot following style of John Kruschke

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('centered_eight')
        >>> az.plot_posterior(data)

    Plot subset variables by specifying variable name exactly

    .. plot::
        :context: close-figs

        >>> az.plot_posterior(data, var_names=['mu'])

    Plot Region of Practical Equivalence (rope) and select variables with regular expressions

    .. plot::
        :context: close-figs

        >>> az.plot_posterior(data, var_names=['mu', '^the'], filter_vars="regex", rope=(-1, 1))

    Plot Region of Practical Equivalence for selected distributions

    .. plot::
        :context: close-figs

        >>> rope = {'mu': [{'rope': (-2, 2)}], 'theta': [{'school': 'Choate', 'rope': (2, 4)}]}
        >>> az.plot_posterior(data, var_names=['mu', 'theta'], rope=rope)

    Using `coords` argument to plot only a subset of data

    .. plot::
        :context: close-figs

        >>> coords = {"school": ["Choate","Phillips Exeter"]}
        >>> az.plot_posterior(data, var_names=["mu", "theta"], coords=coords)

    Add reference lines

    .. plot::
        :context: close-figs

        >>> az.plot_posterior(data, var_names=['mu', 'theta'], ref_val=0)

    Show point estimate of distribution

    .. plot::
        :context: close-figs

        >>> az.plot_posterior(data, var_names=['mu', 'theta'], point_estimate='mode')

    Show reference values using variable names and coordinates

    .. plot::
        :context: close-figs

        >>> az.plot_posterior(data, ref_val={"theta": [{"school": "Deerfield", "ref_val": 4},
        ...                                            {"school": "Choate", "ref_val": 3}]})

    Show reference values using a list

    .. plot::
        :context: close-figs

        >>> az.plot_posterior(data, ref_val=[1] + [5] * 8 + [1])

    Plot posterior as a histogram

    .. plot::
        :context: close-figs

        >>> az.plot_posterior(data, var_names=['mu'], kind='hist')

    Change size of highest density interval

    .. plot::
        :context: close-figs

        >>> az.plot_posterior(data, var_names=['mu'], hdi_prob=.75)
    """
    data = convert_to_dataset(data, group=group)
    if transform is not None:
        data = transform(data)
    var_names = _var_names(var_names, data, filter_vars)

    if coords is None:
        coords = {}

    if labeller is None:
        labeller = BaseLabeller()

    if hdi_prob is None:
        hdi_prob = rcParams["stats.ci_prob"]
    elif hdi_prob not in (None, "hide"):
        # Any other string would otherwise reach the numeric comparison below and
        # fail with an uninformative "'>=' not supported" TypeError.
        if isinstance(hdi_prob, str):
            raise TypeError("hdi_prob must be a number in the interval (0, 1] or 'hide'")
        if not 1 >= hdi_prob > 0:
            raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

    if point_estimate == "auto":
        point_estimate = rcParams["plot.point_estimate"]
    elif point_estimate not in {"mean", "median", "mode", None}:
        raise ValueError("The value of point_estimate must be either mean, median, mode or None.")

    if kind is None:
        kind = rcParams["plot.density_kind"]

    # One plotter entry per (variable, selection) combination; "chain"/"draw"
    # (or ``combine_dims``) are reduced so each entry becomes one subplot.
    plotters = filter_plotters_list(
        list(
            xarray_var_iter(
                get_coords(data, coords), var_names=var_names, combined=True, skip_dims=combine_dims
            )
        ),
        "plot_posterior",
    )
    length_plotters = len(plotters)
    rows, cols = default_grid(length_plotters, grid=grid)

    posteriorplot_kwargs = dict(
        ax=ax,
        length_plotters=length_plotters,
        rows=rows,
        cols=cols,
        figsize=figsize,
        plotters=plotters,
        bw=bw,
        circular=circular,
        bins=bins,
        kind=kind,
        point_estimate=point_estimate,
        round_to=round_to,
        hdi_prob=hdi_prob,
        multimodal=multimodal,
        skipna=skipna,
        textsize=textsize,
        ref_val=ref_val,
        rope=rope,
        ref_val_color=ref_val_color,
        rope_color=rope_color,
        labeller=labeller,
        kwargs=kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    # TODO: Add backend kwargs
    plot = get_plotting_function("plot_posterior", "posteriorplot", backend)
    ax = plot(**posteriorplot_kwargs)
    return ax
|
arviz/plots/ppcplot.py
DELETED
|
@@ -1,369 +0,0 @@
|
|
|
1
|
-
"""Posterior/Prior predictive plot."""
|
|
2
|
-
|
|
3
|
-
import logging
|
|
4
|
-
import warnings
|
|
5
|
-
from numbers import Integral
|
|
6
|
-
|
|
7
|
-
import numpy as np
|
|
8
|
-
|
|
9
|
-
from ..labels import BaseLabeller
|
|
10
|
-
from ..sel_utils import xarray_var_iter
|
|
11
|
-
from ..rcparams import rcParams
|
|
12
|
-
from ..utils import _var_names
|
|
13
|
-
from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
|
|
14
|
-
|
|
15
|
-
# Module-level logger, named after this module via ``__name__``.
_log = logging.getLogger(__name__)
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
def plot_ppc(
|
|
19
|
-
data,
|
|
20
|
-
kind="kde",
|
|
21
|
-
alpha=None,
|
|
22
|
-
mean=True,
|
|
23
|
-
observed=None,
|
|
24
|
-
observed_rug=False,
|
|
25
|
-
color=None,
|
|
26
|
-
colors=None,
|
|
27
|
-
grid=None,
|
|
28
|
-
figsize=None,
|
|
29
|
-
textsize=None,
|
|
30
|
-
data_pairs=None,
|
|
31
|
-
var_names=None,
|
|
32
|
-
filter_vars=None,
|
|
33
|
-
coords=None,
|
|
34
|
-
flatten=None,
|
|
35
|
-
flatten_pp=None,
|
|
36
|
-
num_pp_samples=None,
|
|
37
|
-
random_seed=None,
|
|
38
|
-
jitter=None,
|
|
39
|
-
animated=False,
|
|
40
|
-
animation_kwargs=None,
|
|
41
|
-
legend=True,
|
|
42
|
-
labeller=None,
|
|
43
|
-
ax=None,
|
|
44
|
-
backend=None,
|
|
45
|
-
backend_kwargs=None,
|
|
46
|
-
group="posterior",
|
|
47
|
-
show=None,
|
|
48
|
-
):
|
|
49
|
-
"""
|
|
50
|
-
Plot for posterior/prior predictive checks.
|
|
51
|
-
|
|
52
|
-
Parameters
|
|
53
|
-
----------
|
|
54
|
-
data : InferenceData
|
|
55
|
-
:class:`arviz.InferenceData` object containing the observed and posterior/prior
|
|
56
|
-
predictive data.
|
|
57
|
-
kind : str, default "kde"
|
|
58
|
-
Type of plot to display ("kde", "cumulative", or "scatter").
|
|
59
|
-
alpha : float, optional
|
|
60
|
-
Opacity of posterior/prior predictive density curves.
|
|
61
|
-
Defaults to 0.2 for ``kind = kde`` and cumulative, for scatter defaults to 0.7.
|
|
62
|
-
mean : bool, default True
|
|
63
|
-
Whether or not to plot the mean posterior/prior predictive distribution.
|
|
64
|
-
observed : bool, optional
|
|
65
|
-
Whether or not to plot the observed data. Defaults to True for ``group = posterior``
|
|
66
|
-
and False for ``group = prior``.
|
|
67
|
-
observed_rug : bool, default False
|
|
68
|
-
Whether or not to plot a rug plot for the observed data. Only valid if `observed` is
|
|
69
|
-
`True` and for kind `kde` or `cumulative`.
|
|
70
|
-
color : list, optional
|
|
71
|
-
List with valid matplotlib colors corresponding to the posterior/prior predictive
|
|
72
|
-
distribution, observed data and mean of the posterior/prior predictive distribution.
|
|
73
|
-
Defaults to ["C0", "k", "C1"].
|
|
74
|
-
grid : tuple, optional
|
|
75
|
-
Number of rows and columns. Defaults to None, the rows and columns are
|
|
76
|
-
automatically inferred.
|
|
77
|
-
figsize : tuple, optional
|
|
78
|
-
Figure size. If None, it will be defined automatically.
|
|
79
|
-
textsize : float, optional
|
|
80
|
-
Text size scaling factor for labels, titles and lines. If None, it will be
|
|
81
|
-
autoscaled based on ``figsize``.
|
|
82
|
-
data_pairs : dict, optional
|
|
83
|
-
Dictionary containing relations between observed data and posterior/prior predictive data.
|
|
84
|
-
Dictionary structure:
|
|
85
|
-
|
|
86
|
-
- key = data var_name
|
|
87
|
-
- value = posterior/prior predictive var_name
|
|
88
|
-
|
|
89
|
-
For example, ``data_pairs = {'y' : 'y_hat'}``
|
|
90
|
-
If None, it will assume that the observed data and the posterior/prior
|
|
91
|
-
predictive data have the same variable name.
|
|
92
|
-
var_names : list of str, optional
|
|
93
|
-
Variables to be plotted, if `None` all variable are plotted. Prefix the
|
|
94
|
-
variables by ``~`` when you want to exclude them from the plot.
|
|
95
|
-
filter_vars : {None, "like", "regex"}, default None
|
|
96
|
-
If `None` (default), interpret var_names as the real variables names. If "like",
|
|
97
|
-
interpret var_names as substrings of the real variables names. If "regex",
|
|
98
|
-
interpret var_names as regular expressions on the real variables names. A la
|
|
99
|
-
``pandas.filter``.
|
|
100
|
-
coords : dict, optional
|
|
101
|
-
Dictionary mapping dimensions to selected coordinates to be plotted.
|
|
102
|
-
Dimensions without a mapping specified will include all coordinates for
|
|
103
|
-
that dimension. Defaults to including all coordinates for all
|
|
104
|
-
dimensions if None.
|
|
105
|
-
flatten : list
|
|
106
|
-
List of dimensions to flatten in ``observed_data``. Only flattens across the coordinates
|
|
107
|
-
specified in the ``coords`` argument. Defaults to flattening all of the dimensions.
|
|
108
|
-
flatten_pp : list
|
|
109
|
-
List of dimensions to flatten in posterior_predictive/prior_predictive. Only flattens
|
|
110
|
-
across the coordinates specified in the ``coords`` argument. Defaults to flattening all
|
|
111
|
-
of the dimensions. Dimensions should match flatten excluding dimensions for ``data_pairs``
|
|
112
|
-
parameters. If ``flatten`` is defined and ``flatten_pp`` is None, then
|
|
113
|
-
``flatten_pp = flatten``.
|
|
114
|
-
num_pp_samples : int
|
|
115
|
-
The number of posterior/prior predictive samples to plot. For ``kind`` = 'scatter' and
|
|
116
|
-
``animation = False`` if defaults to a maximum of 5 samples and will set jitter to 0.7.
|
|
117
|
-
unless defined. Otherwise it defaults to all provided samples.
|
|
118
|
-
random_seed : int
|
|
119
|
-
Random number generator seed passed to ``numpy.random.seed`` to allow
|
|
120
|
-
reproducibility of the plot. By default, no seed will be provided
|
|
121
|
-
and the plot will change each call if a random sample is specified
|
|
122
|
-
by ``num_pp_samples``.
|
|
123
|
-
jitter : float, default 0
|
|
124
|
-
If ``kind`` is "scatter", jitter will add random uniform noise to the height
|
|
125
|
-
of the ppc samples and observed data.
|
|
126
|
-
animated : bool, default False
|
|
127
|
-
Create an animation of one posterior/prior predictive sample per frame.
|
|
128
|
-
Only works with matploblib backend.
|
|
129
|
-
To run animations inside a notebook you have to use the `nbAgg` matplotlib's backend.
|
|
130
|
-
Try with `%matplotlib notebook` or `%matplotlib nbAgg`. You can switch back to the
|
|
131
|
-
default matplotlib's backend with `%matplotlib inline` or `%matplotlib auto`.
|
|
132
|
-
If switching back and forth between matplotlib's backend, you may need to run twice the cell
|
|
133
|
-
with the animation.
|
|
134
|
-
If you experience problems rendering the animation try setting
|
|
135
|
-
``animation_kwargs({'blit':False})`` or changing the matplotlib's backend (e.g. to TkAgg)
|
|
136
|
-
If you run the animation from a script write ``ax, ani = az.plot_ppc(.)``
|
|
137
|
-
animation_kwargs : dict
|
|
138
|
-
Keywords passed to :class:`matplotlib.animation.FuncAnimation`. Ignored with
|
|
139
|
-
matplotlib backend.
|
|
140
|
-
legend : bool, default True
|
|
141
|
-
Add legend to figure.
|
|
142
|
-
labeller : labeller, optional
|
|
143
|
-
Class providing the method ``make_pp_label`` to generate the labels in the plot titles.
|
|
144
|
-
Read the :ref:`label_guide` for more details and usage examples.
|
|
145
|
-
ax : numpy array-like of matplotlib_axes or bokeh figures, optional
|
|
146
|
-
A 2D array of locations into which to plot the densities. If not supplied, Arviz will create
|
|
147
|
-
its own array of plot areas (and return it).
|
|
148
|
-
backend : str, optional
|
|
149
|
-
Select plotting backend {"matplotlib","bokeh"}. Default to "matplotlib".
|
|
150
|
-
backend_kwargs : dict, optional
|
|
151
|
-
These are kwargs specific to the backend being used, passed to
|
|
152
|
-
:func:`matplotlib.pyplot.subplots` or :func:`bokeh.plotting.figure`.
|
|
153
|
-
For additional documentation check the plotting method of the backend.
|
|
154
|
-
group : {"prior", "posterior"}, optional
|
|
155
|
-
Specifies which InferenceData group should be plotted. Defaults to 'posterior'.
|
|
156
|
-
Other value can be 'prior'.
|
|
157
|
-
show : bool, optional
|
|
158
|
-
Call backend show function.
|
|
159
|
-
|
|
160
|
-
Returns
|
|
161
|
-
-------
|
|
162
|
-
axes : matplotlib_axes or bokeh_figures
|
|
163
|
-
ani : matplotlib.animation.FuncAnimation, optional
|
|
164
|
-
Only provided if `animated` is ``True``.
|
|
165
|
-
|
|
166
|
-
See Also
|
|
167
|
-
--------
|
|
168
|
-
plot_bpv : Plot Bayesian p-value for observed data and Posterior/Prior predictive.
|
|
169
|
-
plot_loo_pit : Plot for posterior predictive checks using cross validation.
|
|
170
|
-
plot_lm : Posterior predictive and mean plots for regression-like data.
|
|
171
|
-
plot_ts : Plot timeseries data.
|
|
172
|
-
|
|
173
|
-
Examples
|
|
174
|
-
--------
|
|
175
|
-
Plot the observed data KDE overlaid on posterior predictive KDEs.
|
|
176
|
-
|
|
177
|
-
.. plot::
|
|
178
|
-
:context: close-figs
|
|
179
|
-
|
|
180
|
-
>>> import arviz as az
|
|
181
|
-
>>> data = az.load_arviz_data('radon')
|
|
182
|
-
>>> az.plot_ppc(data, data_pairs={"y":"y"})
|
|
183
|
-
|
|
184
|
-
Plot the overlay with empirical CDFs.
|
|
185
|
-
|
|
186
|
-
.. plot::
|
|
187
|
-
:context: close-figs
|
|
188
|
-
|
|
189
|
-
>>> az.plot_ppc(data, kind='cumulative')
|
|
190
|
-
|
|
191
|
-
Use the ``coords`` and ``flatten`` parameters to plot selected variable dimensions
|
|
192
|
-
across multiple plots. We will now modify the dimension ``obs_id`` to contain
|
|
193
|
-
indicate the name of the county where the measure was taken. The change has to
|
|
194
|
-
be done on both ``posterior_predictive`` and ``observed_data`` groups, which is
|
|
195
|
-
why we will use :meth:`~arviz.InferenceData.map` to apply the same function to
|
|
196
|
-
both groups. Afterwards, we will select the counties to be plotted with the
|
|
197
|
-
``coords`` arg.
|
|
198
|
-
|
|
199
|
-
.. plot::
|
|
200
|
-
:context: close-figs
|
|
201
|
-
|
|
202
|
-
>>> obs_county = data.posterior["County"][data.constant_data["county_idx"]]
|
|
203
|
-
>>> data = data.assign_coords(obs_id=obs_county, groups="observed_vars")
|
|
204
|
-
>>> az.plot_ppc(data, coords={'obs_id': ['ANOKA', 'BELTRAMI']}, flatten=[])
|
|
205
|
-
|
|
206
|
-
Plot the overlay using a stacked scatter plot that is particularly useful
|
|
207
|
-
when the sample sizes are small.
|
|
208
|
-
|
|
209
|
-
.. plot::
|
|
210
|
-
:context: close-figs
|
|
211
|
-
|
|
212
|
-
>>> az.plot_ppc(data, kind='scatter', flatten=[],
|
|
213
|
-
>>> coords={'obs_id': ['AITKIN', 'BELTRAMI']})
|
|
214
|
-
|
|
215
|
-
Plot random posterior predictive sub-samples.
|
|
216
|
-
|
|
217
|
-
.. plot::
|
|
218
|
-
:context: close-figs
|
|
219
|
-
|
|
220
|
-
>>> az.plot_ppc(data, num_pp_samples=30, random_seed=7)
|
|
221
|
-
"""
|
|
222
|
-
if group not in ("posterior", "prior"):
|
|
223
|
-
raise TypeError("`group` argument must be either `posterior` or `prior`")
|
|
224
|
-
|
|
225
|
-
for groups in (f"{group}_predictive", "observed_data"):
|
|
226
|
-
if not hasattr(data, groups):
|
|
227
|
-
raise TypeError(f'`data` argument must have the group "{groups}" for ppcplot')
|
|
228
|
-
|
|
229
|
-
if kind.lower() not in ("kde", "cumulative", "scatter"):
|
|
230
|
-
raise TypeError("`kind` argument must be either `kde`, `cumulative`, or `scatter`")
|
|
231
|
-
|
|
232
|
-
if colors is None:
|
|
233
|
-
colors = ["C0", "k", "C1"]
|
|
234
|
-
|
|
235
|
-
if isinstance(colors, str):
|
|
236
|
-
raise TypeError("colors should be a list with 3 items.")
|
|
237
|
-
|
|
238
|
-
if len(colors) != 3:
|
|
239
|
-
raise ValueError("colors should be a list with 3 items.")
|
|
240
|
-
|
|
241
|
-
if color is not None:
|
|
242
|
-
warnings.warn("color has been deprecated in favor of colors", FutureWarning)
|
|
243
|
-
colors[0] = color
|
|
244
|
-
|
|
245
|
-
if data_pairs is None:
|
|
246
|
-
data_pairs = {}
|
|
247
|
-
|
|
248
|
-
if backend is None:
|
|
249
|
-
backend = rcParams["plot.backend"]
|
|
250
|
-
backend = backend.lower()
|
|
251
|
-
if backend == "bokeh" and animated:
|
|
252
|
-
raise TypeError("Animation option is only supported with matplotlib backend.")
|
|
253
|
-
|
|
254
|
-
observed_data = data.observed_data
|
|
255
|
-
|
|
256
|
-
if group == "posterior":
|
|
257
|
-
predictive_dataset = data.posterior_predictive
|
|
258
|
-
if observed is None:
|
|
259
|
-
observed = True
|
|
260
|
-
elif group == "prior":
|
|
261
|
-
predictive_dataset = data.prior_predictive
|
|
262
|
-
if observed is None:
|
|
263
|
-
observed = False
|
|
264
|
-
|
|
265
|
-
if var_names is None:
|
|
266
|
-
var_names = list(observed_data.data_vars)
|
|
267
|
-
var_names = _var_names(var_names, observed_data, filter_vars)
|
|
268
|
-
pp_var_names = [data_pairs.get(var, var) for var in var_names]
|
|
269
|
-
pp_var_names = _var_names(pp_var_names, predictive_dataset, filter_vars)
|
|
270
|
-
|
|
271
|
-
if flatten_pp is None:
|
|
272
|
-
if flatten is None:
|
|
273
|
-
flatten_pp = list(predictive_dataset.dims)
|
|
274
|
-
else:
|
|
275
|
-
flatten_pp = flatten
|
|
276
|
-
if flatten is None:
|
|
277
|
-
flatten = list(observed_data.dims)
|
|
278
|
-
|
|
279
|
-
if coords is None:
|
|
280
|
-
coords = {}
|
|
281
|
-
else:
|
|
282
|
-
coords = coords.copy()
|
|
283
|
-
|
|
284
|
-
if labeller is None:
|
|
285
|
-
labeller = BaseLabeller()
|
|
286
|
-
|
|
287
|
-
if random_seed is not None:
|
|
288
|
-
np.random.seed(random_seed)
|
|
289
|
-
|
|
290
|
-
total_pp_samples = predictive_dataset.sizes["chain"] * predictive_dataset.sizes["draw"]
|
|
291
|
-
if num_pp_samples is None:
|
|
292
|
-
if kind == "scatter" and not animated:
|
|
293
|
-
num_pp_samples = min(5, total_pp_samples)
|
|
294
|
-
else:
|
|
295
|
-
num_pp_samples = total_pp_samples
|
|
296
|
-
|
|
297
|
-
if (
|
|
298
|
-
not isinstance(num_pp_samples, Integral)
|
|
299
|
-
or num_pp_samples < 1
|
|
300
|
-
or num_pp_samples > total_pp_samples
|
|
301
|
-
):
|
|
302
|
-
raise TypeError(f"`num_pp_samples` must be an integer between 1 and {total_pp_samples}.")
|
|
303
|
-
|
|
304
|
-
pp_sample_ix = np.random.choice(total_pp_samples, size=num_pp_samples, replace=False)
|
|
305
|
-
|
|
306
|
-
for key in coords.keys():
|
|
307
|
-
coords[key] = np.where(np.isin(observed_data[key], coords[key]))[0]
|
|
308
|
-
|
|
309
|
-
obs_plotters = filter_plotters_list(
|
|
310
|
-
list(
|
|
311
|
-
xarray_var_iter(
|
|
312
|
-
observed_data.isel(coords),
|
|
313
|
-
skip_dims=set(flatten),
|
|
314
|
-
var_names=var_names,
|
|
315
|
-
combined=True,
|
|
316
|
-
dim_order=["chain", "draw"],
|
|
317
|
-
)
|
|
318
|
-
),
|
|
319
|
-
"plot_ppc",
|
|
320
|
-
)
|
|
321
|
-
length_plotters = len(obs_plotters)
|
|
322
|
-
pp_plotters = [
|
|
323
|
-
tup
|
|
324
|
-
for _, tup in zip(
|
|
325
|
-
range(length_plotters),
|
|
326
|
-
xarray_var_iter(
|
|
327
|
-
predictive_dataset.isel(coords),
|
|
328
|
-
var_names=pp_var_names,
|
|
329
|
-
skip_dims=set(flatten_pp),
|
|
330
|
-
combined=True,
|
|
331
|
-
dim_order=["chain", "draw"],
|
|
332
|
-
),
|
|
333
|
-
)
|
|
334
|
-
]
|
|
335
|
-
rows, cols = default_grid(length_plotters, grid=grid)
|
|
336
|
-
|
|
337
|
-
ppcplot_kwargs = dict(
|
|
338
|
-
ax=ax,
|
|
339
|
-
length_plotters=length_plotters,
|
|
340
|
-
rows=rows,
|
|
341
|
-
cols=cols,
|
|
342
|
-
figsize=figsize,
|
|
343
|
-
animated=animated,
|
|
344
|
-
obs_plotters=obs_plotters,
|
|
345
|
-
pp_plotters=pp_plotters,
|
|
346
|
-
predictive_dataset=predictive_dataset,
|
|
347
|
-
pp_sample_ix=pp_sample_ix,
|
|
348
|
-
kind=kind,
|
|
349
|
-
alpha=alpha,
|
|
350
|
-
colors=colors,
|
|
351
|
-
jitter=jitter,
|
|
352
|
-
textsize=textsize,
|
|
353
|
-
mean=mean,
|
|
354
|
-
observed=observed,
|
|
355
|
-
observed_rug=observed_rug,
|
|
356
|
-
total_pp_samples=total_pp_samples,
|
|
357
|
-
legend=legend,
|
|
358
|
-
labeller=labeller,
|
|
359
|
-
group=group,
|
|
360
|
-
animation_kwargs=animation_kwargs,
|
|
361
|
-
num_pp_samples=num_pp_samples,
|
|
362
|
-
backend_kwargs=backend_kwargs,
|
|
363
|
-
show=show,
|
|
364
|
-
)
|
|
365
|
-
|
|
366
|
-
# TODO: Add backend kwargs
|
|
367
|
-
plot = get_plotting_function("plot_ppc", "ppcplot", backend)
|
|
368
|
-
axes = plot(**ppcplot_kwargs)
|
|
369
|
-
return axes
|