arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/plots/khatplot.py
DELETED
|
@@ -1,236 +0,0 @@
|
|
|
1
|
-
"""Pareto tail indices plot."""
|
|
2
|
-
|
|
3
|
-
import logging
|
|
4
|
-
import warnings
|
|
5
|
-
|
|
6
|
-
import numpy as np
|
|
7
|
-
from xarray import DataArray
|
|
8
|
-
|
|
9
|
-
from ..rcparams import rcParams
|
|
10
|
-
from ..stats import ELPDData
|
|
11
|
-
from ..utils import get_coords
|
|
12
|
-
from .plot_utils import format_coords_as_labels, get_plotting_function
|
|
13
|
-
|
|
14
|
-
_log = logging.getLogger(__name__)
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
def plot_khat(
    khats,
    color="C0",
    xlabels=False,
    show_hlines=False,
    show_bins=False,
    bin_format="{1:.1f}%",
    annotate=False,
    threshold=None,
    hover_label=False,
    hover_format="{1}",
    figsize=None,
    textsize=None,
    coords=None,
    legend=False,
    markersize=None,
    ax=None,
    hlines_kwargs=None,
    backend=None,
    backend_kwargs=None,
    show=None,
    **kwargs
):
    r"""Plot Pareto tail indices :math:`\hat{k}` for diagnosing convergence in PSIS-LOO.

    Parameters
    ----------
    khats : ELPDData
        The input Pareto tail indices to be plotted. :class:`xarray.DataArray` and
        :class:`numpy.ndarray` inputs are still accepted but emit a ``FutureWarning``.
    color : str or array_like, default "C0"
        Colors of the scatter plot, if color is a str all dots will have the same color,
        if it is the size of the observations, each dot will have the specified color,
        otherwise, it will be interpreted as a list of the dims to be used for the color
        code. If Matplotlib c argument is passed, it will override the color argument.
    xlabels : bool, default False
        Use coords as xticklabels.
    show_hlines : bool, default False
        Show the horizontal lines, by default at the values [0, 0.5, 0.7, 1].
    show_bins : bool, default False
        Show the percentage of khats falling in each bin, as delimited by hlines.
    bin_format : str, default "{1:.1f}%"
        The string is used as formatting guide calling ``bin_format.format(count, pct)``.
    annotate : bool, default False
        Deprecated, use `threshold` instead. When truthy, a warning is logged and its
        value is used as `threshold` (overriding any `threshold` given).
    threshold : float, optional
        Show the labels of k values larger than `threshold`. If ``None`` (default), no
        observations will be highlighted.
    hover_label : bool, default False
        Show the datapoint label when hovering over it with the mouse. Requires an interactive
        backend.
    hover_format : str, default "{1}"
        String used to format the hover label via ``hover_format.format(idx, coord_label)``
    figsize : (float, float), optional
        Figure size. If ``None`` it will be defined automatically.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If ``None`` it will be autoscaled
        based on `figsize`.
    coords : mapping, optional
        Coordinates of points to plot. **All** values are used for computation, but only a
        subset can be plotted for convenience. See :ref:`this section <common_coords>` for
        usage examples.
    legend : bool, default False
        Include a legend to the plot. Only taken into account when color argument is a dim name.
    markersize : int, optional
        markersize for scatter plot. Defaults to ``None`` in which case it will
        be chosen based on autoscaling for figsize.
    ax : axes, optional
        Matplotlib axes or bokeh figures.
    hlines_kwargs : dict, optional
        Additional keywords passed to
        :meth:`matplotlib.axes.Axes.hlines`.
    backend : {"matplotlib", "bokeh"}, default "matplotlib"
        Select plotting backend.
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or :class:`bokeh.plotting.figure`.
        For additional documentation check the plotting method of the backend.
    show : bool, optional
        Call backend show function.
    kwargs :
        Additional keywords passed to
        :meth:`matplotlib.axes.Axes.scatter`.

    Returns
    -------
    axes : matplotlib_axes or bokeh_figures

    Raises
    ------
    ValueError
        If `khats` is not an ELPDData, DataArray or ndarray (or an ELPDData whose
        ``pareto_k`` is not a DataArray).

    See Also
    --------
    psislw : Pareto smoothed importance sampling (PSIS).

    Examples
    --------
    Plot estimated pareto shape parameters showing how many fall in each category.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> radon = az.load_arviz_data("radon")
        >>> loo_radon = az.loo(radon, pointwise=True)
        >>> az.plot_khat(loo_radon, show_bins=True)

    Show xlabels

    .. plot::
        :context: close-figs

        >>> centered_eight = az.load_arviz_data("centered_eight")
        >>> khats = az.loo(centered_eight, pointwise=True).pareto_k
        >>> az.plot_khat(khats, xlabels=True, threshold=1)

    Use custom color scheme

    .. plot::
        :context: close-figs

        >>> counties = radon.posterior.County[radon.constant_data.county_idx].values
        >>> colors = [
        ...     "blue" if county[-1] in ("A", "N") else "green" for county in counties
        ... ]
        >>> az.plot_khat(loo_radon, color=colors)

    Notes
    -----
    The Generalized Pareto distribution (GPD) diagnoses convergence rates for importance
    sampling. GPD has parameters offset, scale, and shape. The shape parameter (:math:`k`)
    tells the distribution's number of finite moments. The pre-asymptotic convergence rate
    of importance sampling can be estimated based on the fractional number of finite moments
    of the importance ratio distribution. GPD is fitted to the largest importance ratios and
    interprets the estimated shape parameter :math:`k`, i.e., :math:`\hat{k}` can then be
    used as a diagnostic (most importantly if :math:`\hat{k} > 0.7`, then the convergence
    rate is impractically low). See [1]_.

    References
    ----------
    .. [1] Vehtari, A., Simpson, D., Gelman, A., Yao, Y., Gabry, J. (2024).
        Pareto Smoothed Importance Sampling. Journal of Machine Learning
        Research, 25(72):1-58.

    """
    if annotate:
        _log.warning("annotate will be deprecated, please use threshold instead")
        threshold = annotate

    if coords is None:
        coords = {}

    if color is None:
        color = "C0"

    if isinstance(khats, np.ndarray):
        warnings.warn(
            "support for arrays will be deprecated, please use ELPDData. "
            "The reason for this, is that we need to know the numbers of draws "
            "sampled from the posterior",
            FutureWarning,
        )
        khats = khats.flatten()
        # A bare array carries no coordinates, so coordinate-based labels and
        # legends cannot be produced.
        xlabels = False
        legend = False
        dims = []
        good_k = None
    else:
        good_k = None
        from_elpd = isinstance(khats, ELPDData)
        if from_elpd:
            good_k = khats.good_k
            khats = khats.pareto_k
        # Validate before warning so unsupported inputs are not told that
        # "DataArray support" is being deprecated.
        if not isinstance(khats, DataArray):
            raise ValueError("Incorrect khat data input. Check the documentation")
        if not from_elpd:
            warnings.warn(
                "support for DataArrays will be deprecated, please use ELPDData. "
                "The reason for this, is that we need to know the numbers of draws "
                "sampled from the posterior",
                FutureWarning,
            )

        khats = get_coords(khats, coords)
        dims = khats.dims

    n_data_points = khats.size
    xdata = np.arange(n_data_points)
    if isinstance(khats, DataArray):
        coord_labels = format_coords_as_labels(khats)
    else:
        coord_labels = xdata.astype(str)

    plot_khat_kwargs = dict(
        hover_label=hover_label,
        hover_format=hover_format,
        ax=ax,
        figsize=figsize,
        xdata=xdata,
        khats=khats,
        good_k=good_k,
        kwargs=kwargs,
        threshold=threshold,
        coord_labels=coord_labels,
        show_hlines=show_hlines,
        show_bins=show_bins,
        hlines_kwargs=hlines_kwargs,
        xlabels=xlabels,
        legend=legend,
        color=color,
        dims=dims,
        textsize=textsize,
        markersize=markersize,
        n_data_points=n_data_points,
        bin_format=bin_format,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    plot = get_plotting_function("plot_khat", "khatplot", backend)
    axes = plot(**plot_khat_kwargs)
    return axes
|
arviz/plots/lmplot.py
DELETED
|
@@ -1,380 +0,0 @@
|
|
|
1
|
-
"""Plot regression figure."""
|
|
2
|
-
|
|
3
|
-
import warnings
|
|
4
|
-
from numbers import Integral
|
|
5
|
-
from itertools import repeat
|
|
6
|
-
|
|
7
|
-
import xarray as xr
|
|
8
|
-
import numpy as np
|
|
9
|
-
from xarray.core.dataarray import DataArray
|
|
10
|
-
|
|
11
|
-
from ..sel_utils import xarray_var_iter
|
|
12
|
-
from ..rcparams import rcParams
|
|
13
|
-
from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def _repeat_flatten_list(lst, n):
    """Return a flat list made of the elements of ``lst`` repeated ``n`` times in order.

    ``_repeat_flatten_list([a, b], 2)`` gives ``[a, b, a, b]``; a non-positive ``n``
    gives an empty list.
    """
    flattened = []
    for _ in repeat(None, n):
        flattened.extend(lst)
    return flattened
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
def plot_lm(
    y,
    idata=None,
    x=None,
    y_model=None,
    y_hat=None,
    num_samples=50,
    kind_pp="samples",
    kind_model="lines",
    xjitter=False,
    plot_dim=None,
    backend=None,
    y_kwargs=None,
    y_hat_plot_kwargs=None,
    y_hat_fill_kwargs=None,
    y_model_plot_kwargs=None,
    y_model_fill_kwargs=None,
    y_model_mean_kwargs=None,
    backend_kwargs=None,
    show=None,
    figsize=None,
    textsize=None,
    axes=None,
    legend=True,
    grid=True,
):
    """Posterior predictive and mean plots for regression-like data.

    Parameters
    ----------
    y : str or DataArray or ndarray
        If str, variable name from ``observed_data``.
    idata : InferenceData, Optional
        Optional only if none of ``y``, ``x``, ``y_model`` or ``y_hat`` is given
        as a variable name.
    x : str, tuple of strings, DataArray or array-like, optional
        If str or tuple, variable name from ``constant_data``.
        If ndarray, could be 1D, or 2D for multiple plots.
        If None, coords name of ``y`` (``y`` should be DataArray).
    y_model : str or DataArray, Optional
        If str, variable name from ``posterior``.
        Its dimensions should be same as ``y`` plus added chains and draws.
    y_hat : str or DataArray, Optional
        If str, variable name from ``posterior_predictive``.
        Its dimensions should be same as ``y`` plus added chains and draws.
        Defaults to ``y`` when ``y`` is a str.
    num_samples : int, Optional, Default 50
        Significant if ``kind_pp`` is "samples" or ``kind_model`` is "lines".
        Number of samples to be drawn from the posterior predictive (``y_hat``)
        and/or the posterior (``y_model``).
    kind_pp : {"samples", "hdi"}, Default "samples"
        Options to visualize uncertainty in data.
    kind_model : {"lines", "hdi"}, Default "lines"
        Options to visualize uncertainty in mean of the data.
    plot_dim : str or tuple of str, Optional
        Necessary if ``y`` is multidimensional.
    backend : str, Optional
        Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib".
    y_kwargs : dict, optional
        Passed to :meth:`mpl:matplotlib.axes.Axes.plot` in matplotlib
        and :meth:`bokeh:bokeh.plotting.Figure.circle` in bokeh
    y_hat_plot_kwargs : dict, optional
        Passed to :meth:`mpl:matplotlib.axes.Axes.plot` in matplotlib
        and :meth:`bokeh:bokeh.plotting.Figure.circle` in bokeh
    y_hat_fill_kwargs : dict, optional
        Passed to :func:`arviz.plot_hdi`
    y_model_plot_kwargs : dict, optional
        Passed to :meth:`mpl:matplotlib.axes.Axes.plot` in matplotlib
        and :meth:`bokeh:bokeh.plotting.Figure.line` in bokeh
    y_model_fill_kwargs : dict, optional
        Significant if ``kind_model`` is "hdi". Passed to :func:`arviz.plot_hdi`
    y_model_mean_kwargs : dict, optional
        Passed to :meth:`mpl:matplotlib.axes.Axes.plot` in matplotlib
        and :meth:`bokeh:bokeh.plotting.Figure.line` in bokeh
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used. Passed to
        :func:`matplotlib.pyplot.subplots` or
        :func:`bokeh.plotting.figure`.
    figsize : (float, float), optional
        Figure size. If None it will be defined automatically.
    textsize : float, optional
        Text size scaling factor for labels, titles and lines. If None it will be
        autoscaled based on ``figsize``.
    axes : 2D numpy array-like of matplotlib_axes or bokeh_figures, optional
        A 2D array of locations into which to plot the densities. If not supplied, Arviz will create
        its own array of plot areas (and return it).
    show : bool, optional
        Call backend show function.
    legend : bool, optional
        Add legend to figure. By default True.
    grid : bool, optional
        Add grid to figure. By default True.

    Returns
    -------
    axes: matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If ``kind_pp``/``kind_model``/``plot_dim`` are invalid, if ``plot_dim`` is
        missing for multidimensional ``y``, or if a variable name is given without
        ``idata``.
    TypeError
        If ``num_samples`` is not an integer between 1 and the available number of
        samples.

    See Also
    --------
    plot_ts : Plot timeseries data
    plot_ppc : Plot for posterior/prior predictive checks

    Examples
    --------
    Plot regression default plot

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> import numpy as np
        >>> import xarray as xr
        >>> idata = az.load_arviz_data('regression1d')
        >>> x = xr.DataArray(np.linspace(0, 1, 100))
        >>> idata.posterior["y_model"] = idata.posterior["intercept"] + idata.posterior["slope"]*x
        >>> az.plot_lm(idata=idata, y="y", x=x)

    Plot regression data and mean uncertainty

    .. plot::
        :context: close-figs

        >>> az.plot_lm(idata=idata, y="y", x=x, y_model="y_model")

    Plot regression data and mean uncertainty in hdi form

    .. plot::
        :context: close-figs

        >>> az.plot_lm(
        ...     idata=idata, y="y", x=x, y_model="y_model", kind_pp="hdi", kind_model="hdi"
        ... )

    Plot regression data for multi-dimensional y using plot_dim

    .. plot::
        :context: close-figs

        >>> data = az.from_dict(
        ...     observed_data = { "y": np.random.normal(size=(5, 7)) },
        ...     posterior_predictive = {"y": np.random.randn(4, 1000, 5, 7) / 2},
        ...     dims={"y": ["dim1", "dim2"]},
        ...     coords={"dim1": range(5), "dim2": range(7)}
        ... )
        >>> az.plot_lm(idata=data, y="y", plot_dim="dim1")
    """
    if kind_pp not in ("samples", "hdi"):
        raise ValueError("kind_pp should be either samples or hdi")

    if kind_model not in ("lines", "hdi"):
        raise ValueError("kind_model should be either lines or hdi")

    if y_hat is None and isinstance(y, str):
        y_hat = y

    # Variables given by name are looked up in idata; fail early with a clear
    # message instead of an AttributeError on None further down.
    if idata is None and (
        isinstance(y, str)
        or isinstance(x, (str, tuple))
        or isinstance(y_model, str)
        or isinstance(y_hat, str)
    ):
        raise ValueError(
            "idata is required when y, x, y_model or y_hat are given as variable names"
        )

    if isinstance(y, str):
        y = idata.observed_data[y]
    elif not isinstance(y, DataArray):
        y = xr.DataArray(y)

    if len(y.dims) > 1 and plot_dim is None:
        raise ValueError("Argument plot_dim is needed in case of multidimensional data")

    x_var_names = None
    if isinstance(x, str):
        x = idata.constant_data[x]
        x_skip_dims = x.dims
    elif isinstance(x, tuple):
        x_var_names = x
        x = idata.constant_data
        x_skip_dims = x.dims
    elif isinstance(x, DataArray):
        x_skip_dims = x.dims
    elif x is None:
        x = y.coords[y.dims[0]] if plot_dim is None else y.coords[plot_dim]
        x_skip_dims = x.dims
    else:
        x = xr.DataArray(x)
        x_skip_dims = [x.dims[-1]]

    # If posterior is present in idata and y_model is there, get its values.
    # Membership (not hasattr) is used so Dataset methods/attrs such as "mean"
    # are not mistaken for variables.
    if isinstance(y_model, str):
        if "posterior" not in idata.groups():
            warnings.warn("Posterior not found in idata", UserWarning)
            y_model = None
        elif y_model in idata.posterior:
            y_model = idata.posterior[y_model]
        else:
            warnings.warn("y_model not found in posterior", UserWarning)
            y_model = None

    # If posterior_predictive is present in idata and y_hat is there, get its values
    if isinstance(y_hat, str):
        if "posterior_predictive" not in idata.groups():
            warnings.warn("posterior_predictive not found in idata", UserWarning)
            y_hat = None
        elif y_hat in idata.posterior_predictive:
            y_hat = idata.posterior_predictive[y_hat]
        else:
            warnings.warn("y_hat not found in posterior_predictive", UserWarning)
            y_hat = None

    # Check if num_samples is valid and generate num_samples random indexes.
    # Only needed if kind_pp="samples" or kind_model="lines"; hdi plots use all draws.
    pp_sample_ix = None
    subsampled = [
        var
        for var, needs_samples in (
            (y_hat, kind_pp == "samples"),
            (y_model, kind_model == "lines"),
        )
        if var is not None and needs_samples
    ]
    if subsampled:
        # The same indexes are applied to every subsampled quantity, so they must
        # be valid for the smallest one.
        total_pp_samples = min(var.sizes["chain"] * var.sizes["draw"] for var in subsampled)

        if (
            not isinstance(num_samples, Integral)
            or num_samples < 1
            or num_samples > total_pp_samples
        ):
            raise TypeError(f"`num_samples` must be an integer between 1 and {total_pp_samples}.")

        pp_sample_ix = np.random.choice(total_pp_samples, size=num_samples, replace=False)

    # crucial step in case of multidim y
    if plot_dim is None:
        skip_dims = list(y.dims)
    elif isinstance(plot_dim, str):
        skip_dims = [plot_dim]
    elif isinstance(plot_dim, tuple):
        skip_dims = list(plot_dim)
    else:
        raise ValueError("plot_dim should be a str or a tuple of str")

    # Generate x axis plotters.
    x = filter_plotters_list(
        plotters=list(
            xarray_var_iter(
                x,
                var_names=x_var_names,
                skip_dims=set(x_skip_dims),
                combined=True,
            )
        ),
        plot_kind="plot_lm",
    )

    # Generate y axis plotters
    y = filter_plotters_list(
        plotters=list(
            xarray_var_iter(
                y,
                skip_dims=set(skip_dims),
                combined=True,
            )
        ),
        plot_kind="plot_lm",
    )

    # If there are multiple x and multidimensional y, we need total of len(x)*len(y) graphs
    len_y = len(y)
    len_x = len(x)
    length_plotters = len_x * len_y
    y = _repeat_flatten_list(y, len_x)
    x = _repeat_flatten_list(x, len_y)

    # Filter out the required values to generate plotters
    if y_hat is not None:
        if kind_pp == "samples":
            y_hat = y_hat.stack(__sample__=("chain", "draw"))[..., pp_sample_ix]
            skip_dims += ["__sample__"]

        y_hat = [
            tup
            for _, tup in zip(
                range(len_y),
                xarray_var_iter(
                    y_hat,
                    skip_dims=set(skip_dims),
                    combined=True,
                ),
            )
        ]

        y_hat = _repeat_flatten_list(y_hat, len_x)

    # Filter out the required values to generate plotters
    if y_model is not None:
        if kind_model == "lines":
            var_name = y_model.name if y_model.name else "y_model"
            data = y_model.values

            # Merge chain and draw into a single leading sample axis
            total_samples = data.shape[0] * data.shape[1]
            data = data.reshape(total_samples, *data.shape[2:])

            if pp_sample_ix is not None:
                data = data[pp_sample_ix]

            if plot_dim is not None:
                # For plot_dim case, transpose to get dimension first
                data = data.transpose(1, 0, 2)[..., 0]

            # Create plotter tuple(s)
            if plot_dim is not None:
                y_model = [(var_name, {}, {}, data) for _ in range(length_plotters)]
            else:
                y_model = [(var_name, {}, {}, data)]
                y_model = _repeat_flatten_list(y_model, len_x)

        elif kind_model == "hdi":
            var_name = y_model.name if y_model.name else "y_model"
            data = y_model.values

            if plot_dim is not None:
                # First transpose to get plot_dim first
                data = data.transpose(2, 0, 1, 3)
                # For plot_dim case, we just want HDI for first dimension
                data = data[..., 0]

                # Reshape to (samples, points)
                data = data.transpose(1, 2, 0).reshape(-1, data.shape[0])
                y_model = [(var_name, {}, {}, data) for _ in range(length_plotters)]

            else:
                data = data.reshape(-1, data.shape[-1])
                y_model = [(var_name, {}, {}, data)]
                y_model = _repeat_flatten_list(y_model, len_x)

        if len(y_model) == 1:
            y_model = _repeat_flatten_list(y_model, len_x)

    rows, cols = default_grid(length_plotters)

    lmplot_kwargs = dict(
        x=x,
        y=y,
        y_model=y_model,
        y_hat=y_hat,
        num_samples=num_samples,
        kind_pp=kind_pp,
        kind_model=kind_model,
        length_plotters=length_plotters,
        xjitter=xjitter,
        rows=rows,
        cols=cols,
        y_kwargs=y_kwargs,
        y_hat_plot_kwargs=y_hat_plot_kwargs,
        y_hat_fill_kwargs=y_hat_fill_kwargs,
        y_model_plot_kwargs=y_model_plot_kwargs,
        y_model_fill_kwargs=y_model_fill_kwargs,
        y_model_mean_kwargs=y_model_mean_kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
        figsize=figsize,
        textsize=textsize,
        axes=axes,
        legend=legend,
        grid=grid,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    plot = get_plotting_function("plot_lm", "lmplot", backend)
    ax = plot(**lmplot_kwargs)
    return ax
|