arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -357
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.1.dist-info/METADATA +0 -263
- arviz-0.23.1.dist-info/RECORD +0 -183
- arviz-0.23.1.dist-info/top_level.txt +0 -1
arviz/plots/parallelplot.py
DELETED
|
@@ -1,204 +0,0 @@
|
|
|
1
|
-
"""Parallel coordinates plot showing posterior points with and without divergences marked."""
|
|
2
|
-
|
|
3
|
-
import numpy as np
|
|
4
|
-
from scipy.stats import rankdata
|
|
5
|
-
|
|
6
|
-
from ..data import convert_to_dataset
|
|
7
|
-
from ..labels import BaseLabeller
|
|
8
|
-
from ..sel_utils import xarray_to_ndarray
|
|
9
|
-
from ..rcparams import rcParams
|
|
10
|
-
from ..stats.stats_utils import stats_variance_2d as svar
|
|
11
|
-
from ..utils import _numba_var, _var_names, get_coords
|
|
12
|
-
from .plot_utils import get_plotting_function
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
def plot_parallel(
    data,
    var_names=None,
    filter_vars=None,
    coords=None,
    figsize=None,
    textsize=None,
    legend=True,
    colornd="k",
    colord="C1",
    shadend=0.025,
    labeller=None,
    ax=None,
    norm_method=None,
    backend=None,
    backend_config=None,
    backend_kwargs=None,
    show=None,
):
    """
    Plot parallel coordinates plot showing posterior points with and without divergences.

    Described by https://arxiv.org/abs/1709.01449

    Parameters
    ----------
    data: obj
        Any object that can be converted to an :class:`arviz.InferenceData` object
        refer to documentation of :func:`arviz.convert_to_dataset` for details
    var_names: list of variable names
        Variables to be plotted, if `None` all variables are plotted. Can be used to change the
        order of the plotted variables. Prefix the variables by ``~`` when you want to exclude
        them from the plot.
    filter_vars: {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        ``pandas.filter``.
    coords: mapping, optional
        Coordinates of ``var_names`` to be plotted.
        Passed to :meth:`xarray.Dataset.sel`.
    figsize: tuple
        Figure size. If None it will be defined automatically.
    textsize: float
        Text size scaling factor for labels, titles and lines. If None it will be autoscaled based
        on ``figsize``.
    legend: bool
        Flag for plotting legend (defaults to True)
    colornd: valid matplotlib color
        color for non-divergent points. Defaults to 'k'
    colord: valid matplotlib color
        color for divergent points. Defaults to 'C1'
    shadend: float
        Alpha blending value for non-divergent points, between 0 (invisible) and 1 (opaque).
        Defaults to .025
    labeller : labeller instance, optional
        Class providing the method ``make_label_vert`` to generate the labels in the plot.
        Read the :ref:`label_guide` for more details and usage examples.
    ax: axes, optional
        Matplotlib axes or bokeh figures.
    norm_method: str
        Method for normalizing the data. Methods include normal, minmax and rank.
        Defaults to none. Normalization always produces a floating point array, and
        variables that are constant across draws are mapped to 0 for "normal" and
        "minmax" instead of producing NaNs.
    backend: str, optional
        Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib".
    backend_config: dict, optional
        Currently specifies the bounds to use for bokeh axes.
        Defaults to value set in ``rcParams``.
    backend_kwargs: dict, optional
        These are kwargs specific to the backend being used, passed to
        :func:`matplotlib.pyplot.subplots` or
        :func:`bokeh.plotting.figure`.
    show: bool, optional
        Call backend show function.

    Returns
    -------
    axes: matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If fewer than two variables are selected, or if ``norm_method`` is not one of
        ``None``, "normal", "minmax" or "rank".

    See Also
    --------
    plot_pair : Plot a scatter, kde and/or hexbin matrix with (optional) marginals on the diagonal.
    plot_trace : Plot distribution (histogram or kernel density estimates) and sampled values
                 or rank plot

    Examples
    --------
    Plot default parallel plot

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('centered_eight')
        >>> az.plot_parallel(data, var_names=["mu", "tau"])


    Plot parallel plot with normalization

    .. plot::
        :context: close-figs

        >>> az.plot_parallel(data, var_names=["theta", "tau", "mu"], norm_method="normal")

    Plot parallel plot with minmax

    .. plot::
        :context: close-figs

        >>> ax = az.plot_parallel(data, var_names=["theta", "tau", "mu"], norm_method="minmax")
        >>> ax.set_xticklabels(ax.get_xticklabels(), rotation=45)

    Plot parallel plot with rank

    .. plot::
        :context: close-figs

        >>> ax = az.plot_parallel(data, var_names=["theta", "tau", "mu"], norm_method="rank")
        >>> ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
    """
    if coords is None:
        coords = {}

    if labeller is None:
        labeller = BaseLabeller()

    # Get diverging draws and combine chains
    divergent_data = convert_to_dataset(data, group="sample_stats")
    _, diverging_mask = xarray_to_ndarray(
        divergent_data,
        var_names=("diverging",),
        combined=True,
    )
    diverging_mask = np.squeeze(diverging_mask)

    # Get posterior draws and combine chains
    posterior_data = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, posterior_data, filter_vars)
    var_names, _posterior = xarray_to_ndarray(
        get_coords(posterior_data, coords),
        var_names=var_names,
        combined=True,
        label_fun=labeller.make_label_vert,
    )
    if len(var_names) < 2:
        raise ValueError("Number of variables to be plotted must be 2 or greater.")

    if norm_method is not None:
        # Axis 0 indexes the plotted variables and axis 1 the draws, so per-variable
        # statistics are computed over axis 1. They are re-expanded on that axis so
        # they broadcast against the draws. New arrays are built rather than
        # writing into ``_posterior``; in-place writes would truncate integer-valued
        # posteriors and mutate the array returned by ``xarray_to_ndarray``.
        if norm_method == "normal":
            mean = np.mean(_posterior, axis=1)
            if _posterior.ndim <= 2:
                standard_deviation = np.sqrt(_numba_var(svar, np.var, _posterior, axis=1))
            else:
                standard_deviation = np.std(_posterior, axis=1)
            # A constant variable has zero spread; divide by 1 so it maps to 0, not NaN.
            scale = np.where(standard_deviation == 0, 1.0, standard_deviation)
            _posterior = (_posterior - np.expand_dims(mean, axis=1)) / np.expand_dims(
                scale, axis=1
            )
        elif norm_method == "minmax":
            min_elem = np.min(_posterior, axis=1)
            max_elem = np.max(_posterior, axis=1)
            span = max_elem - min_elem
            # Same guard as above: a constant variable maps to 0 instead of NaN.
            span = np.where(span == 0, 1.0, span)
            _posterior = (_posterior - np.expand_dims(min_elem, axis=1)) / np.expand_dims(
                span, axis=1
            )
        elif norm_method == "rank":
            _posterior = rankdata(_posterior, axis=1, method="average")
        else:
            raise ValueError(f"{norm_method} is not supported. Use normal, minmax or rank.")

    parallel_kwargs = dict(
        ax=ax,
        colornd=colornd,
        colord=colord,
        shadend=shadend,
        diverging_mask=diverging_mask,
        posterior=_posterior,
        textsize=textsize,
        var_names=var_names,
        legend=legend,
        figsize=figsize,
        backend_kwargs=backend_kwargs,
        backend_config=backend_config,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    plot = get_plotting_function("plot_parallel", "parallelplot", backend)
    ax = plot(**parallel_kwargs)

    return ax
|