arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -357
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.1.dist-info/METADATA +0 -263
- arviz-0.23.1.dist-info/RECORD +0 -183
- arviz-0.23.1.dist-info/top_level.txt +0 -1
arviz/plots/plot_utils.py
DELETED
|
@@ -1,599 +0,0 @@
|
|
|
1
|
-
"""Utilities for plotting."""
|
|
2
|
-
|
|
3
|
-
import importlib
|
|
4
|
-
import warnings
|
|
5
|
-
from typing import Any, Dict
|
|
6
|
-
|
|
7
|
-
import matplotlib as mpl
|
|
8
|
-
import numpy as np
|
|
9
|
-
import packaging
|
|
10
|
-
from matplotlib.colors import to_hex
|
|
11
|
-
from scipy.stats import mode, rankdata
|
|
12
|
-
from scipy.interpolate import CubicSpline
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
from ..rcparams import rcParams
|
|
16
|
-
from ..stats.density_utils import kde
|
|
17
|
-
from ..stats import hdi
|
|
18
|
-
|
|
19
|
-
# Type alias for a dictionary of keyword arguments (argument names mapped to arbitrary values).
KwargSpec = Dict[str, Any]
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
def make_2d(ary):
    """Convert any array-like into a 2d numpy array.

    The first dimension is kept and every dimension after the first is
    raveled (in Fortran order) into the second one.

    Parameters
    ----------
    ary : array-like
        Input data. Scalars, lists, tuples and numpy arrays are accepted.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(ary.shape[0], prod(ary.shape[1:]))``: ``(n, 1)`` for
        1d input and ``(1, 1)`` for scalars.
    """
    # Convert before reshaping so plain sequences (which have no ``.reshape``)
    # and scalars are accepted, not only ndarrays.
    ary = np.atleast_1d(ary)
    dim_0 = ary.shape[0]
    return ary.reshape(dim_0, -1, order="F")
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
def _scale_fig_size(figsize, textsize, rows=1, cols=1):
    """Scale figure size and text/line properties to a subplot grid.

    Parameters
    ----------
    figsize : tuple of float or None
        Figure ``(width, height)`` in inches. When None, matplotlib's
        ``figure.figsize`` is multiplied by the number of columns/rows
        (and by 1.15 when there is more than one panel).
    textsize : float or None
        Base font size. When given, everything is scaled relative to the
        ``xtick.labelsize`` rcParam.
    rows : int
        Number of rows.
    cols : int
        Number of columns.

    Returns
    -------
    figsize : tuple of float
        ``(width, height)`` of the figure in inches.
    ax_labelsize : float
        Font size for axes labels.
    titlesize : float
        Font size for titles.
    xt_labelsize : float
        Font size for axes ticks.
    linewidth : float
        Line width.
    markersize : float
        Marker size.
    """
    rc = mpl.rcParams
    base_width, base_height = tuple(rc["figure.figsize"])

    def _numeric(key, fallback):
        # Font sizes may be given as names such as "medium"; use a fixed size then.
        setting = rc[key]
        return fallback if isinstance(setting, str) else setting

    base_label = _numeric("axes.labelsize", 15)
    base_title = _numeric("axes.titlesize", 16)
    base_tick = _numeric("xtick.labelsize", 14)
    base_line = rc["lines.linewidth"]
    base_marker = rc["lines.markersize"]

    single_panel = rows == cols == 1
    if figsize is not None:
        width, height = figsize
    else:
        grid_factor = 1 if single_panel else 1.15
        width = base_width * cols * grid_factor
        height = base_height * rows * grid_factor

    if textsize is not None:
        scale = textsize / base_tick
    elif single_panel:
        scale = ((width * height) / (base_width * base_height)) ** 0.5
    else:
        scale = 1

    return (
        (width, height),
        base_label * scale,
        base_title * scale,
        base_tick * scale,
        base_line * scale,
        base_marker * scale,
    )
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
def default_grid(n_items, grid=None, max_cols=4, min_cols=3):
    """Choose a ``(rows, cols)`` subplot layout for ``n_items`` panels.

    Without an explicit ``grid`` a single row is used when ``n_items`` fits in
    ``max_cols`` columns. Otherwise the layout aims for roughly
    sqrt(n_items) columns (clamped to ``[min_cols, max_cols]``), preferring a
    nearby column count that divides ``n_items`` exactly.

    Parameters
    ----------
    n_items : int
        Number of panels required.
    grid : tuple of (int, int), optional
        User-provided number of rows and columns; validated and returned.
    max_cols : int
        Maximum number of columns, inclusive.
    min_cols : int
        Minimum number of columns, inclusive.

    Returns
    -------
    (int, int)
        Rows and columns, so that rows * columns >= n_items.

    Raises
    ------
    ValueError
        If ``grid`` has fewer cells than ``n_items``.
    """
    if grid is not None:
        rows, cols = grid
        n_cells = rows * cols
        if n_cells < n_items:
            raise ValueError("The number of rows times columns is less than the number of subplots")
        # A whole spare row means the grid is larger than necessary.
        if n_cells - n_items >= cols:
            warnings.warn("The number of rows times columns is larger than necessary")
        return rows, cols

    if n_items <= max_cols:
        return 1, n_items

    def clamp(n_cols):
        return np.clip(n_cols, min_cols, max_cols)

    ideal = clamp(round(n_items**0.5))
    for delta in (0, 1, -1, 2, -2):
        n_cols = clamp(ideal + delta)
        n_rows, remainder = divmod(n_items, n_cols)
        if not remainder:
            return n_rows, n_cols
    # No nearby exact divisor: keep the ideal width and add a partial row.
    return n_items // ideal + 1, ideal
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
def format_sig_figs(value, default=None):
    """Get a default number of significant figures.

    Gives the number of digits in the integer part of ``value`` or ``default``,
    whichever is bigger. Zero gives 1; non-finite values (NaN, +/-inf) give
    ``default`` because they have no meaningful number of digits.

    Parameters
    ----------
    value : float
        Number whose magnitude determines the digit count.
    default : int, optional
        Minimum number of significant figures. Defaults to 2.

    Returns
    -------
    int

    Examples
    --------
    0.1234 --> 0.12
    1.234 --> 1.2
    12.34 --> 12
    123.4 --> 123
    """
    if default is None:
        default = 2
    if value == 0:
        return 1
    if not np.isfinite(value):
        # int() of log10(nan) / log10(inf) raises, so fall back to the default.
        return default
    return max(int(np.log10(np.abs(value))) + 1, default)
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
def round_num(n, round_to):
    """Format ``n`` as a string with ``round_to`` significant figures.

    The number of significant figures is taken from ``format_sig_figs``, so
    it grows beyond ``round_to`` when the integer part of ``n`` needs more digits.

    Parameters
    ----------
    n : float
        Number to round.
    round_to : int
        Number of significant figures.

    Returns
    -------
    str
    """
    digits = format_sig_figs(n, round_to)
    return f"{n:.{digits}g}"
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
def color_from_dim(dataarray, dim_name):
    """Return colors and color mapping of a DataArray using coord values as color code.

    Parameters
    ----------
    dataarray : xarray.DataArray
    dim_name : str
        dimension whose coordinates will be used as color code.

    Returns
    -------
    colors : list of floats
        Colors (as floats in [0, 1) for use with a cmap) for each element in the dataarray.
    color_mapping : dict of coord_value -> float
        Mapping from coord values to corresponding color. Values are assigned
        colors in order of first appearance, so the mapping is reproducible.
    """
    present_dims = dataarray.dims
    coord_values = dataarray[dim_name].values
    # dict.fromkeys keeps first-appearance order; iterating a set of strings is
    # subject to hash randomization and would reshuffle colors between sessions.
    unique_coords = list(dict.fromkeys(coord_values))
    n_unique = len(unique_coords)
    color_mapping = {coord: num / n_unique for num, coord in enumerate(unique_coords)}
    if len(present_dims) > 1:
        # Walk every element through the product index of all dims and pick the
        # component belonging to ``dim_name``.
        multi_coords = dataarray.coords.to_index()
        coord_idx = present_dims.index(dim_name)
        colors = [color_mapping[coord[coord_idx]] for coord in multi_coords]
    else:
        colors = [color_mapping[coord] for coord in coord_values]
    return colors, color_mapping
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
def vectorized_to_hex(c_values, keep_alpha=False):
    """Convert a color (including vector of colors) to hex.

    Parameters
    ----------
    c_values : Matplotlib color or iterable of Matplotlib colors
        A single color specification, or an iterable of them.
    keep_alpha : bool, default False
        Whether alpha values should be kept in the final hex values.

    Returns
    -------
    rgba_hex : str or list of str
        A single hex string when ``c_values`` is one color, otherwise a list
        with one hex string per color.
    """
    try:
        return to_hex(c_values, keep_alpha=keep_alpha)
    except ValueError:
        # Not convertible as a single color: treat it as a collection of colors.
        return [to_hex(color, keep_alpha=keep_alpha) for color in c_values]
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
def format_coords_as_labels(dataarray, skip_dims=None):
    """Format 1d or multi-d dataarray coords as strings.

    Parameters
    ----------
    dataarray : xarray.DataArray
        DataArray whose coordinates will be converted to labels.
    skip_dims : str or list-like, optional
        Dimensions whose values should not be included in the labels.

    Returns
    -------
    numpy.ndarray of str
        One label per coordinate combination (deduplicated when ``skip_dims``
        is given). Values of multi-dimensional combinations are joined with
        ``", "``. An empty array is returned when there are no coordinates.
    """
    coord_index = dataarray.coords.to_index()
    if skip_dims is not None:
        # Dropping levels can produce repeated combinations; keep each once.
        coord_index = coord_index.droplevel(skip_dims).drop_duplicates()
    coord_labels = coord_index.values
    if len(coord_labels) == 0:
        # Nothing to format; also avoids indexing coord_labels[0] below.
        return np.array([], dtype=str)
    if isinstance(coord_labels[0], tuple):
        return np.array([", ".join(f"{item}" for item in label) for label in coord_labels])
    return np.array([f"{label}" for label in coord_labels])
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
def set_xticklabels(ax, coord_labels):
    """Set xticklabels to label list using Matplotlib default formatter.

    The major locator is configured with ``nbins=9`` and steps 1/2/5/10. Its
    ticks are truncated to integers and only those that are valid label indices
    are kept. If that leaves more ticks than labels, every label is shown instead.

    Parameters
    ----------
    ax : matplotlib axes
        Axes whose x axis is labelled.
    coord_labels : numpy.ndarray
        Labels for x positions ``0 .. len(coord_labels) - 1``; must support
        integer-array indexing.
    """
    n_labels = len(coord_labels)
    ax.xaxis.get_major_locator().set_params(nbins=9, steps=[1, 2, 5, 10])
    candidate_ticks = ax.get_xticks().astype(np.int64)
    in_range = (candidate_ticks >= 0) & (candidate_ticks < n_labels)
    tick_positions = candidate_ticks[in_range]
    if len(tick_positions) <= n_labels:
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(coord_labels[tick_positions])
    else:
        ax.set_xticks(np.arange(n_labels))
        ax.set_xticklabels(coord_labels)
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
def filter_plotters_list(plotters, plot_kind):
    """Cut list of plotters so that it is at most of length "plot.max_subplots".

    A ``UserWarning`` is emitted when items are dropped; a ``None`` limit means
    no truncation.

    Parameters
    ----------
    plotters : sequence
        Items to plot, one per subplot.
    plot_kind : str
        Name of the calling plot, used in the warning message.

    Returns
    -------
    sequence
        ``plotters`` itself when within the limit, otherwise its first
        ``plot.max_subplots`` items.
    """
    limit = rcParams["plot.max_subplots"]
    n_plotters = len(plotters)
    if limit is None or n_plotters <= limit:
        return plotters
    warnings.warn(
        f"rcParams['plot.max_subplots'] ({limit}) is smaller than the number "
        f"of variables to plot ({n_plotters}) in {plot_kind}, generating only "
        f"{limit} plots",
        UserWarning,
    )
    return plotters[:limit]
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
def get_plotting_function(plot_name, plot_module, backend):
    """Return plotting function for correct backend.

    Parameters
    ----------
    plot_name : str
        Name of the function to fetch from the backend module.
    plot_module : str
        Module name inside ``arviz.plots.backends.<backend>``.
    backend : str or None
        ``"matplotlib"``/``"mpl"`` or ``"bokeh"`` (case-insensitive). When None,
        ``rcParams["plot.backend"]`` is used.

    Returns
    -------
    callable
        The backend-specific plotting function.

    Raises
    ------
    KeyError
        If the backend is not supported.
    ImportError
        If the bokeh backend is requested but Bokeh 1.4.0+ is not installed.
    """
    _backend = {
        "mpl": "matplotlib",
        "bokeh": "bokeh",
        "matplotlib": "matplotlib",
    }

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    try:
        backend = _backend[backend]
    except KeyError as err:
        raise KeyError(
            f"Backend {backend} is not implemented. Try backend in {set(_backend.values())}"
        ) from err

    if backend == "bokeh":
        # ``import packaging`` alone does not load the ``packaging.version``
        # submodule, so import the parser explicitly.
        from packaging.version import parse as parse_version

        try:
            import bokeh
        except ImportError as err:
            raise ImportError(
                "'bokeh' backend needs Bokeh (1.4.0+) installed. Please upgrade or install"
            ) from err
        # Explicit check rather than ``assert``, which is stripped under ``python -O``.
        if parse_version(bokeh.__version__) < parse_version("1.4.0"):
            raise ImportError(
                "'bokeh' backend needs Bokeh (1.4.0+) installed. Please upgrade or install"
            )

    # Perform import of plotting method
    # TODO: Convert module import to top level for all plots
    module = importlib.import_module(f"arviz.plots.backends.{backend}.{plot_module}")
    return getattr(module, plot_name)
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
def calculate_point_estimate(point_estimate, values, bw="default", circular=False, skipna=False):
    """Validate and calculate the point estimate.

    Parameters
    ----------
    point_estimate : Optional[str]
        Plot point estimate per variable. Values should be 'mean', 'median', 'mode',
        'auto' or None. 'auto' falls back to the default set in
        ``rcParams["plot.point_estimate"]``.
    values : 1-d array
    bw: Optional[float or str]
        If numeric, indicates the bandwidth and must be positive.
        If str, indicates the method to estimate the bandwidth and must be
        one of "scott", "silverman", "isj" or "experimental" when `circular` is False
        and "taylor" (for now) when `circular` is True.
        Defaults to "default" which means "experimental" when variable is not circular
        and "taylor" when it is. Only used for the 'mode' of float data.
    circular: Optional[bool]
        If True, it interprets the values passed are from a circular variable measured in radians
        and a circular KDE is used. Only valid for 1D KDE. Defaults to False.
    skipna : bool, default False
        If True, ignore NaN values when computing the mean or the median.

    Returns
    -------
    point_value : float, int or None
        Best estimate of the data distribution; None when ``point_estimate`` is None.

    Raises
    ------
    ValueError
        If ``point_estimate`` is not one of the accepted values.
    """
    point_value = None
    if point_estimate == "auto":
        point_estimate = rcParams["plot.point_estimate"]
    elif point_estimate not in ("mean", "median", "mode", None):
        raise ValueError(
            f"Point estimate should be 'mean', 'median', 'mode' or None, not {point_estimate}"
        )
    if point_estimate == "mean":
        point_value = np.nanmean(values) if skipna else np.mean(values)
    elif point_estimate == "mode":
        if values.dtype.kind == "f":
            # Continuous data: the mode is the peak of the KDE.
            if bw == "default":
                bw = "taylor" if circular else "experimental"
            x, density = kde(values, circular=circular, bw=bw)
            point_value = x[np.argmax(density)]
        else:
            # Depending on the SciPy version ``mode(...).mode`` is a scalar or a
            # length-1 array; ravel so both convert cleanly to a Python int.
            point_value = int(np.ravel(mode(values, axis=None).mode)[0])
    elif point_estimate == "median":
        point_value = np.nanmedian(values) if skipna else np.median(values)
    return point_value
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
def plot_point_interval(
    ax,
    values,
    point_estimate,
    hdi_prob,
    quartiles,
    linewidth,
    markersize,
    markercolor,
    marker,
    rotated,
    intervalcolor,
    backend="matplotlib",
):
    """Plot point intervals.

    Translates the data and represents them as point and interval summaries:
    an HDI line (plus a thicker interquartile line when ``quartiles`` is True)
    drawn at position 0, and optionally a marker at the point estimate.

    Parameters
    ----------
    ax : axes
        Matplotlib axes or Bokeh figure, matching ``backend``.
    values : array-like
        Values to plot. Must support ``.flatten()`` (e.g. a numpy array).
    point_estimate : str
        Plot point estimate per variable; passed to ``calculate_point_estimate``.
        A falsy value skips the marker.
    hdi_prob : float
        Probability mass of the highest density interval drawn as the outer interval.
    quartiles : bool
        If True then the quartile interval will be plotted with the HDI.
    linewidth : int
        Line width throughout.
    markersize : int
        Markersize throughout.
    markercolor: string
        Color of the marker.
    marker: string
        Shape of the marker. Only used by the matplotlib backend; the Bokeh
        branch always draws a circle.
    rotated : bool
        If True, draw the interval vertically at x=0 instead of horizontally at y=0.
    intervalcolor : string
        Color of the interval.
    backend : string, optional
        Matplotlib or Bokeh. Any value other than "matplotlib" takes the Bokeh branch.

    Returns
    -------
    ax
        The axes/figure that was passed in.
    """
    endpoint = (1 - hdi_prob) / 2
    if quartiles:
        qlist_interval = [endpoint, 0.25, 0.75, 1 - endpoint]
    else:
        qlist_interval = [endpoint, 1 - endpoint]
    quantiles_interval = np.quantile(values, qlist_interval)

    # The outermost bounds are replaced by the HDI; only the inner quartile
    # bounds (if any) stay equal-tailed quantiles.
    quantiles_interval[0], quantiles_interval[-1] = hdi(
        values.flatten(), hdi_prob, multimodal=False
    )
    # Bounds are paired from the outside in: j=0 is the HDI, j=1 the quartiles.
    mid = len(quantiles_interval) // 2
    # Reversed widths: with quartiles the HDI gets linewidth and the quartile
    # interval 2 * linewidth; without them the lone HDI gets 2 * linewidth.
    param_iter = zip(np.linspace(2 * linewidth, linewidth, mid, endpoint=True)[-1::-1], range(mid))

    if backend == "matplotlib":
        for width, j in param_iter:
            if rotated:
                ax.vlines(
                    0,
                    quantiles_interval[j],
                    quantiles_interval[-(j + 1)],
                    linewidth=width,
                    color=intervalcolor,
                )
            else:
                ax.hlines(
                    0,
                    quantiles_interval[j],
                    quantiles_interval[-(j + 1)],
                    linewidth=width,
                    color=intervalcolor,
                )

        if point_estimate:
            point_value = calculate_point_estimate(point_estimate, values)
            if rotated:
                ax.plot(
                    0,
                    point_value,
                    marker,
                    markersize=markersize,
                    color=markercolor,
                )
            else:
                ax.plot(
                    point_value,
                    0,
                    marker,
                    markersize=markersize,
                    color=markercolor,
                )
    else:
        # Bokeh API: line segments via ``ax.line`` and the marker via ``ax.scatter``.
        for width, j in param_iter:
            if rotated:
                ax.line(
                    [0, 0],
                    [quantiles_interval[j], quantiles_interval[-(j + 1)]],
                    line_width=width,
                    color=intervalcolor,
                )
            else:
                ax.line(
                    [quantiles_interval[j], quantiles_interval[-(j + 1)]],
                    [0, 0],
                    line_width=width,
                    color=intervalcolor,
                )

        if point_estimate:
            point_value = calculate_point_estimate(point_estimate, values)
            if rotated:
                ax.scatter(
                    x=0,
                    y=point_value,
                    marker="circle",
                    size=markersize,
                    fill_color=markercolor,
                )
            else:
                ax.scatter(
                    x=point_value,
                    y=0,
                    marker="circle",
                    size=markersize,
                    fill_color=markercolor,
                )

    return ax
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
def is_valid_quantile(value):
    """Check if value is a number strictly between 0 and 1.

    Parameters
    ----------
    value : object
        Anything ``float()`` may accept, such as numbers or numeric strings.

    Returns
    -------
    bool
        True if ``value`` converts to a float in the open interval (0, 1).
        False otherwise, including when it cannot be converted at all.
    """
    try:
        value = float(value)
    except (TypeError, ValueError):
        # float(None) and other non-numeric objects raise TypeError, not ValueError.
        return False
    return 0 < value < 1
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
def sample_reference_distribution(dist, shape):
    """Generate samples from a scipy distribution and a KDE for each column.

    Parameters
    ----------
    dist : scipy distribution
        Object exposing ``rvs(size=...)``.
    shape : tuple of int
        Passed as ``size`` to ``dist.rvs``; a KDE is computed for each of the
        ``shape[1]`` columns of the draws.

    Returns
    -------
    x_ss : numpy.ndarray
        KDE evaluation grids, one column per sampled column.
    densities : numpy.ndarray
        KDE density values, one column per sampled column.
    """
    draws = dist.rvs(size=shape)
    kde_results = [kde(draws[:, column]) for column in range(shape[1])]
    grids = np.array([grid for grid, _ in kde_results])
    dens = np.array([density for _, density in kde_results])
    return grids.T, dens.T
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
def set_bokeh_circular_ticks_labels(ax, hist, labels):
    """Place ticks and ticklabels on Bokeh's circular histogram.

    Draws radial tick lines, concentric reference circles and the tick labels
    around a circular histogram centered at (0, 0).

    Parameters
    ----------
    ax : Bokeh figure
        Figure holding the circular histogram.
    hist : array-like
        Histogram bar heights; their maximum sets the extent of the ticks,
        circles and label positions.
    labels : sequence of str
        Tick labels. NOTE(review): the label coordinates below are hard-coded
        for 8 labels at 0, 45, ..., 315 degrees counter-clockwise from the
        positive x axis; other lengths would misalign -- confirm with callers.

    Returns
    -------
    ax
        The same figure that was passed in.
    """
    # One zero-width wedge (start == end angle) per label, rendered as a radial tick line.
    ticks = np.linspace(-np.pi, np.pi, len(labels), endpoint=False)
    ax.annular_wedge(
        x=0,
        y=0,
        inner_radius=0,
        outer_radius=np.max(hist) * 1.1,
        start_angle=ticks,
        end_angle=ticks,
        line_color="grey",
    )

    # Unfilled circles at 4 evenly spaced radii from 0 to 110% of the tallest bar.
    radii_circles = np.linspace(0, np.max(hist) * 1.1, 4)
    ax.scatter(0, 0, marker="circle", radius=radii_circles, fill_color=None, line_color="grey")

    # Labels sit just outside the histogram; pos_2 is the x/y component of
    # pos_1 along a diagonal (cos 45 deg = sqrt(2) / 2).
    offset = np.max(hist * 1.05) * 0.15
    ticks_labels_pos_1 = np.max(hist * 1.05)
    ticks_labels_pos_2 = ticks_labels_pos_1 * np.sqrt(2) / 2

    ax.text(
        # x coordinates for 0, 45, 90, ..., 315 degrees
        [
            ticks_labels_pos_1 + offset,
            ticks_labels_pos_2 + offset,
            0,
            -ticks_labels_pos_2 - offset,
            -ticks_labels_pos_1 - offset,
            -ticks_labels_pos_2 - offset,
            0,
            ticks_labels_pos_2 + offset,
        ],
        # y coordinates for the same angles
        [
            0,
            ticks_labels_pos_2 + offset / 2,
            ticks_labels_pos_1 + offset,
            ticks_labels_pos_2 + offset / 2,
            0,
            -ticks_labels_pos_2 - offset,
            -ticks_labels_pos_1 - offset,
            -ticks_labels_pos_2 - offset,
        ],
        text=labels,
        text_align="center",
    )

    return ax
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
def compute_ranks(ary):
    """Compute ranks for continuous and discrete variables.

    Integer arrays are first replaced by a cubic spline through their values,
    evaluated at slightly shifted positions. This presumably breaks ties
    between equal integers (TODO confirm intent). Constant or single-element
    integer arrays skip this step, because the spline needs at least two
    distinct knots, and are ranked directly.

    Parameters
    ----------
    ary : numpy.ndarray
        Values to rank; any shape.

    Returns
    -------
    numpy.ndarray
        Average ranks (starting at 1) with the same shape as ``ary``.
    """
    if ary.dtype.kind == "i" and ary.size > 1:
        flat = ary.ravel()
        min_ary, max_ary = flat.min(), flat.max()
        # CubicSpline requires strictly increasing knots, impossible when min == max.
        if max_ary > min_ary:
            x = np.linspace(min_ary, max_ary, flat.size)
            csi = CubicSpline(x, flat)
            ary = csi(np.linspace(min_ary + 0.001, max_ary - 0.001, flat.size)).reshape(ary.shape)
    ranks = rankdata(ary, method="average").reshape(ary.shape)

    return ranks
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
def _init_kwargs_dict(kwargs):
    """Return a kwargs dict that the caller may mutate freely.

    Parameters
    ----------
    kwargs : dict or None
        kwargs dict to initialize.

    Returns
    -------
    dict
        A shallow copy of ``kwargs`` (via its ``copy`` method), or a new empty
        dict when ``kwargs`` is None.
    """
    if kwargs is None:
        return {}
    return kwargs.copy()
|