arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -357
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.1.dist-info/METADATA +0 -263
- arviz-0.23.1.dist-info/RECORD +0 -183
- arviz-0.23.1.dist-info/top_level.txt +0 -1
|
@@ -1,656 +0,0 @@
|
|
|
1
|
-
"""Matplotlib forestplot."""
|
|
2
|
-
|
|
3
|
-
from collections import OrderedDict, defaultdict
|
|
4
|
-
from itertools import tee
|
|
5
|
-
|
|
6
|
-
import matplotlib.pyplot as plt
|
|
7
|
-
import numpy as np
|
|
8
|
-
from matplotlib.colors import to_rgba
|
|
9
|
-
from matplotlib.lines import Line2D
|
|
10
|
-
|
|
11
|
-
from ....stats import hdi
|
|
12
|
-
from ....stats.density_utils import get_bins, histogram, kde
|
|
13
|
-
from ....stats.diagnostics import _ess, _rhat
|
|
14
|
-
from ....sel_utils import xarray_var_iter
|
|
15
|
-
from ...plot_utils import _scale_fig_size
|
|
16
|
-
from . import backend_kwarg_defaults, backend_show
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
def pairwise(iterable):
    """Return overlapping consecutive pairs of ``iterable``.

    ``[a, b, c, ...] -> (a, b), (b, c), ...`` (recipe from the itertools docs).
    Inputs with fewer than two items produce no pairs.
    """
    current, ahead = tee(iterable)
    # Advance the look-ahead copy by a single element; a no-op when it is empty.
    for _ in ahead:
        break
    return zip(current, ahead)
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
def plot_forest(
    ax,
    datasets,
    var_names,
    model_names,
    combined,
    combine_dims,
    colors,
    figsize,
    width_ratios,
    linewidth,
    markersize,
    kind,
    ncols,
    hdi_prob,
    quartiles,
    rope,
    ridgeplot_overlap,
    ridgeplot_alpha,
    ridgeplot_kind,
    ridgeplot_truncate,
    ridgeplot_quantiles,
    textsize,
    legend,
    labeller,
    ess,
    r_hat,
    backend_kwargs,
    backend_config,  # pylint: disable=unused-argument
    show,
):
    """Matplotlib forest plot.

    Draws the main panel (a forest plot of HDI intervals or a ridge plot of
    densities) on the first axes and, when requested, an ESS panel and an
    R-hat panel on the following axes.

    Parameters
    ----------
    ax : Axes, array of Axes, or None
        Axes to draw on. If None, a ``1 x ncols`` grid of subplots with
        ``width_ratios`` and a shared y axis is created.
    kind : str
        Either "forestplot" or "ridgeplot".
    ess, r_hat : bool
        Whether to draw the ESS / R-hat panels (in that order, after the
        main panel).
    figsize, linewidth, markersize : optional
        When None, derived from the number of rows and ``textsize``.
    backend_kwargs : dict, optional
        Merged over ``backend_kwarg_defaults()`` and passed to
        ``plt.subplots`` when ``ax`` is None.
    show : bool or None
        Passed to ``backend_show`` to decide whether to call ``plt.show()``.

    The remaining arguments are forwarded to :class:`PlotHandler` and to its
    ``forestplot`` / ``ridgeplot`` methods.

    Returns
    -------
    numpy.ndarray
        The axes as a 1-D array (``np.atleast_1d``).

    Raises
    ------
    TypeError
        If ``kind`` is neither "forestplot" nor "ridgeplot". The check runs
        before any figure is created.
    """
    plot_handler = PlotHandler(
        datasets,
        var_names=var_names,
        model_names=model_names,
        combined=combined,
        combine_dims=combine_dims,
        colors=colors,
        labeller=labeller,
    )

    if figsize is None:
        if kind == "ridgeplot":
            figsize = (min(14, sum(width_ratios) * 4), plot_handler.fig_height() * 1.2)
        else:
            figsize = (min(12, sum(width_ratios) * 2), plot_handler.fig_height())

    (figsize, _, titlesize, xt_labelsize, auto_linewidth, auto_markersize) = _scale_fig_size(
        figsize, textsize, 1.1, 1
    )

    if linewidth is None:
        linewidth = auto_linewidth

    if markersize is None:
        markersize = auto_markersize

    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    # Validate ``kind`` before creating a figure, so that an invalid value does
    # not leave an empty figure registered with pyplot.
    if kind not in ("forestplot", "ridgeplot"):
        raise TypeError(
            "Argument 'kind' must be one of 'forestplot' "
            f"or 'ridgeplot' (you provided {kind})"
        )

    if ax is None:
        _, axes = plt.subplots(
            nrows=1,
            ncols=ncols,
            figsize=figsize,
            gridspec_kw={"width_ratios": width_ratios},
            sharey=True,
            **backend_kwargs,
        )
    else:
        axes = ax

    axes = np.atleast_1d(axes)
    if kind == "forestplot":
        plot_handler.forestplot(
            hdi_prob,
            quartiles,
            xt_labelsize,
            titlesize,
            linewidth,
            markersize,
            axes[0],
            rope,
        )
    else:
        plot_handler.ridgeplot(
            hdi_prob,
            ridgeplot_overlap,
            linewidth,
            markersize,
            ridgeplot_alpha,
            ridgeplot_kind,
            ridgeplot_truncate,
            ridgeplot_quantiles,
            axes[0],
        )

    # Optional diagnostic panels follow the main panel, in this order.
    idx = 1
    if ess:
        plot_handler.plot_neff(axes[idx], xt_labelsize, titlesize, markersize)
        idx += 1

    if r_hat:
        plot_handler.plot_rhat(axes[idx], xt_labelsize, titlesize, markersize)
        idx += 1

    for ax_ in axes:
        if kind == "ridgeplot":
            ax_.grid(False)
        else:
            ax_.grid(False, axis="y")
        # Remove ticklines on y-axes
        ax_.tick_params(axis="y", left=False, right=False)

        for loc, spine in ax_.spines.items():
            if loc in ["left", "right"]:
                spine.set_visible(False)

        # Alternating background bands only help when several models are drawn.
        if len(plot_handler.data) > 1:
            plot_handler.make_bands(ax_)

    labels, ticks = plot_handler.labels_and_ticks()
    axes[0].set_yticks(ticks)
    axes[0].set_yticklabels(labels)
    all_plotters = list(plot_handler.plotters.values())
    y_max = plot_handler.y_max() - all_plotters[-1].group_offset
    if kind == "ridgeplot":  # space at the top
        y_max += ridgeplot_overlap
    axes[0].set_ylim(-all_plotters[0].group_offset, y_max)
    if legend:
        plot_handler.legend(ax=axes[0])

    if backend_show(show):
        plt.show()

    return axes
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
class PlotHandler:
    """Class to handle logic from ForestPlot.

    Holds the datasets and one ``VarHandler`` per variable. Draws the main
    forest/ridge panel, the ESS and R-hat panels, the legend and the
    background bands.

    Model names, variable names and colors are stored reversed because y
    values grow upwards. Reversing places the first model and the first
    variable at the top of the plot.
    """

    def __init__(self, datasets, var_names, model_names, combined, combine_dims, colors, labeller):
        self.data = datasets

        if model_names is None:
            if len(self.data) > 1:
                model_names = [f"Model {idx}" for idx, _ in enumerate(self.data)]
            else:
                model_names = [None]
        elif len(model_names) != len(self.data):
            raise ValueError("The number of model names does not match the number of models")

        self.model_names = list(reversed(model_names))  # y-values are upside down

        if var_names is None:
            # Ordered union of the data variables of all datasets, in
            # first-seen order. A plain ``set`` union produced a hash-dependent
            # order that varied between runs, and it was not reversed like the
            # single-dataset and user-supplied cases.
            all_var_names = {}
            for datum in self.data:
                all_var_names.update(dict.fromkeys(datum.data_vars))
            self.var_names = list(reversed(list(all_var_names)))
        else:
            self.var_names = list(reversed(var_names))  # y-values are upside down

        self.combined = combined
        self.combine_dims = combine_dims

        if isinstance(colors, str):
            if colors == "cycle":
                # TODO: Use matplotlib prop cycle instead
                colors = [f"C{idx}" for idx, _ in enumerate(self.data)]
            else:
                colors = [colors for _ in self.data]
        else:
            # Keep one color per model. Extra colors would otherwise be
            # reversed along with the rest, and the models would be paired
            # with the *last* colors of the list.
            colors = list(colors)[: len(self.data)]

        self.colors = list(reversed(colors))  # y-values are upside down
        self.labeller = labeller

        self.plotters = self.make_plotters()

    def make_plotters(self):
        """Initialize an object for each variable to be plotted.

        Each ``VarHandler`` starts where the previous one ends (its ``y_max``).
        """
        plotters, y = {}, 0
        for var_name in self.var_names:
            plotters[var_name] = VarHandler(
                var_name,
                self.data,
                y,
                model_names=self.model_names,
                combined=self.combined,
                combine_dims=self.combine_dims,
                colors=self.colors,
                labeller=self.labeller,
            )
            y = plotters[var_name].y_max()
        return plotters

    def labels_and_ticks(self):
        """Collect y tick labels and positions from all plotters.

        Within a plotter, rows that share a label (e.g. the same element
        drawn for several models or chains) get a single tick at the mean of
        their y positions.

        Returns
        -------
        labels, ticks : numpy.ndarray
        """
        labels, ticks = [], []
        for plotter in self.plotters.values():
            sub_labels, sub_ticks, _, _ = plotter.labels_ticks_and_vals()
            label_to_ticks = defaultdict(list)
            for label, tick in zip(sub_labels, sub_ticks):
                label_to_ticks[label].append(tick)
            for label, label_ticks in label_to_ticks.items():
                labels.append(label)
                ticks.append(np.mean(label_ticks))
        return np.array(labels), np.array(ticks)

    def legend(self, ax):
        """Add legend with colorcoded model info."""
        handles = [Line2D([], [], color=c) for c in self.colors]
        ax.legend(handles=handles, labels=self.model_names)

    def display_multiple_ropes(self, rope, ax, y, linewidth, var_name, selection):
        """Display ROPE when more than one interval is provided.

        ``rope[var_name]`` is iterated as a sequence of dicts. A ``"rope"``
        entry holds the ``(lo, hi)`` pair, and every other key/value must
        match ``selection`` for the interval to be drawn on row ``y``.
        """
        for sel in rope.get(var_name, []):
            # pylint: disable=line-too-long
            if all(k in selection and selection[k] == v for k, v in sel.items() if k != "rope"):
                vals = sel["rope"]
                ax.plot(
                    vals,
                    (y + 0.05, y + 0.05),
                    lw=linewidth * 2,
                    color="C2",
                    solid_capstyle="round",
                    zorder=0,
                    alpha=0.7,
                )
        return ax

    def ridgeplot(
        self,
        hdi_prob,
        mult,
        linewidth,
        markersize,
        alpha,
        ridgeplot_kind,
        ridgeplot_truncate,
        ridgeplot_quantiles,
        ax,
    ):
        """Draw ridgeplot for each plotter.

        Parameters
        ----------
        hdi_prob : float
            Probability for the highest density interval.
        mult : float
            How much to multiply height by. Set this to greater than 1 to have some overlap.
        linewidth : float
            Width of line on border of ridges
        markersize : float
            Size of the quantile markers
        alpha : float
            Transparency of ridges (None means 1.0)
        ridgeplot_kind : string
            By default ("auto") continuous variables are plotted using KDEs and discrete ones using
            histograms. To override this use "hist" to plot histograms and "density" for KDEs
        ridgeplot_truncate: bool
            Whether to truncate densities according to the value of hdi_prob.
        ridgeplot_quantiles: list
            Quantiles in ascending order used to segment the KDE. Use [.25, .5, .75] for quartiles.
            None draws no quantile markers.
        ax : Axes
            Axes to draw on

        Returns
        -------
        Axes
        """
        if alpha is None:
            alpha = 1.0
        zorder = 0
        for plotter in self.plotters.values():
            for x, y_min, y_max, hdi_, y_q, color in plotter.ridgeplot(
                hdi_prob, mult, ridgeplot_kind
            ):
                # ``y_q`` (normalized cumulative density) is aligned with the
                # full grid ``x``. Keep that grid and the scalar ridge baseline,
                # because the histogram branch below may truncate ``x``.
                x_full, baseline = x, y_min
                in_hdi = (x >= hdi_[0]) & (x <= hdi_[1])
                # Fully transparent ridges get an outline in the model color.
                border = color if alpha == 0 else "k"
                if x.dtype.kind == "i":
                    # Integer grid: draw bars.
                    if ridgeplot_truncate:
                        facecolor = to_rgba(color, alpha)
                        y_max = y_max[in_hdi]
                        x = x[in_hdi]
                    else:
                        # Bars outside the HDI are drawn unfilled.
                        facecolor = [to_rgba(color, alpha) if ci else "None" for ci in in_hdi]
                    y_min = np.ones_like(x) * baseline
                    ax.bar(
                        x,
                        y_max - y_min,
                        bottom=y_min,
                        linewidth=linewidth,
                        ec=border,
                        color=facecolor,
                        alpha=None,
                        zorder=zorder,
                    )
                else:
                    tr_x = x[in_hdi]
                    tr_y_min = np.ones_like(tr_x) * baseline
                    tr_y_max = y_max[in_hdi]
                    if ridgeplot_truncate:
                        ax.plot(
                            tr_x, tr_y_max, "-", linewidth=linewidth, color=border, zorder=zorder
                        )
                        ax.plot(
                            tr_x, tr_y_min, "-", linewidth=linewidth, color=border, zorder=zorder
                        )
                    else:
                        y_min = np.ones_like(x) * baseline
                        ax.plot(x, y_max, "-", linewidth=linewidth, color=border, zorder=zorder)
                        ax.plot(x, y_min, "-", linewidth=linewidth, color=border, zorder=zorder)
                    # Only the HDI part of the ridge is filled.
                    ax.fill_between(
                        tr_x, tr_y_max, tr_y_min, alpha=alpha, color=color, zorder=zorder
                    )

                if ridgeplot_quantiles is not None:
                    # Index into the untruncated grid, and clamp so that a
                    # cumulative sum that never reaches ``quant`` (e.g. 1.0 with
                    # round-off) cannot index past the end.
                    last = len(x_full) - 1
                    quantiles = [
                        x_full[min(int(np.sum(y_q < quant)), last)]
                        for quant in ridgeplot_quantiles
                    ]
                    ax.plot(
                        quantiles,
                        np.ones_like(quantiles) * baseline,
                        "d",
                        mfc=border,
                        mec=border,
                        ms=markersize,
                    )
                # Each later ridge is drawn underneath the previous ones.
                zorder -= 1
        return ax

    def forestplot(
        self, hdi_prob, quartiles, xt_labelsize, titlesize, linewidth, markersize, ax, rope
    ):
        """Draw forestplot for each plotter.

        Parameters
        ----------
        hdi_prob : float
            Probability for the highest density interval. Width of each line.
        quartiles : bool
            Whether to mark quartiles
        xt_labelsize : float
            Size of tick text
        titlesize : float
            Size of title text
        linewidth : float
            Width of forestplot line
        markersize : float
            Size of marker in center of forestplot line
        ax : Axes
            Axes to draw on
        rope : None, dict or sequence of length 2
            None draws no ROPE. A dict maps variable names to lists of
            ``{"rope": (lo, hi), <coord>: <value>, ...}`` entries drawn per
            row. A ``(lo, hi)`` pair is shaded across the whole panel.

        Returns
        -------
        Axes

        Raises
        ------
        ValueError
            If ``rope`` is not None, a dict, or of length 2.
        """
        # Percentiles to compute. The outermost pair is replaced by the HDI
        # bounds in ``VarHandler.treeplot``.
        endpoint = 100 * (1 - hdi_prob) / 2
        if quartiles:
            qlist = [endpoint, 25, 50, 75, 100 - endpoint]
        else:
            qlist = [endpoint, 50, 100 - endpoint]

        for plotter in self.plotters.values():
            for y, selection, values, color in plotter.treeplot(qlist, hdi_prob):
                if isinstance(rope, dict):
                    self.display_multiple_ropes(rope, ax, y, linewidth, plotter.var_name, selection)

                # Nested intervals: the outermost is the thinnest line,
                # and inner intervals get progressively thicker.
                mid = len(values) // 2
                param_iter = zip(
                    np.linspace(2 * linewidth, linewidth, mid, endpoint=True)[-1::-1], range(mid)
                )
                for width, j in param_iter:
                    ax.hlines(y, values[j], values[-(j + 1)], linewidth=width, color=color)
                ax.plot(
                    values[mid],
                    y,
                    "o",
                    mfc=ax.get_facecolor(),
                    markersize=markersize * 0.75,
                    color=color,
                )
        ax.tick_params(labelsize=xt_labelsize)
        ax.set_title(f"{hdi_prob:.1%} HDI", fontsize=titlesize, wrap=True)
        if rope is None or isinstance(rope, dict):
            return ax
        if len(rope) == 2:
            # ``ymin``/``ymax`` of ``axvspan`` are axes fractions, so 0..1
            # spans the full height of the panel.
            ax.axvspan(rope[0], rope[1], 0, 1, color="C2", alpha=0.5)
            return ax
        raise ValueError(
            "Argument `rope` must be None, a dictionary like "
            '{"var_name": [{"rope": (lo, hi)}]}, or an '
            "iterable of length 2"
        )

    def plot_neff(self, ax, xt_labelsize, titlesize, markersize):
        """Draw effective n for each plotter."""
        for plotter in self.plotters.values():
            for y, ess, color in plotter.ess():
                if ess is not None:
                    ax.plot(
                        ess,
                        y,
                        "o",
                        color=color,
                        clip_on=False,
                        markersize=markersize,
                        markeredgecolor="k",
                    )
        ax.set_xlim(left=0)
        ax.set_title("ess", fontsize=titlesize, wrap=True)
        ax.tick_params(labelsize=xt_labelsize)
        return ax

    def plot_rhat(self, ax, xt_labelsize, titlesize, markersize):
        """Draw r-hat for each plotter (rows without an r-hat are skipped)."""
        for plotter in self.plotters.values():
            for y, r_hat, color in plotter.r_hat():
                if r_hat is not None:
                    ax.plot(r_hat, y, "o", color=color, markersize=markersize, markeredgecolor="k")
        ax.set_xlim(left=0.9, right=2.1)
        ax.set_xticks([1, 2])
        ax.tick_params(labelsize=xt_labelsize)
        ax.set_title("r_hat", fontsize=titlesize, wrap=True)
        return ax

    def make_bands(self, ax):
        """Draw shaded horizontal bands for each plotter.

        A new band boundary is placed halfway between two rows whenever the
        row's model color index drops below the previous row's. That happens
        when iteration moves on to the next label group.

        NOTE(review): this relies on each model having a distinct color.
        With a repeated color (e.g. ``colors="k"``), ``list.index`` always
        returns the first occurrence and no boundaries are detected.
        """
        y_vals, y_prev, is_zero = [0], None, False
        prev_color_index = 0
        for plotter in self.plotters.values():
            for y, *_, color in plotter.iterator():
                color_index = self.colors.index(color)
                if color_index < prev_color_index:
                    if not is_zero and y_prev is not None:
                        y_vals.append((y + y_prev) * 0.5)
                        is_zero = True
                else:
                    is_zero = False
                prev_color_index = color_index
                y_prev = y

        offset = plotter.group_offset  # pylint: disable=undefined-loop-variable

        y_vals.append(y_prev + offset)
        # Every other band is shaded (alpha 0 / 0.1 alternating).
        for idx, (y_start, y_stop) in enumerate(pairwise(y_vals)):
            ax.axhspan(y_start, y_stop, color="k", alpha=0.1 * (idx % 2))
        return ax

    def fig_height(self):
        """Figure out the height of this plot."""
        # hand-tuned
        return (
            4
            + len(self.data) * len(self.var_names)
            - 1
            + 0.1 * sum(1 for j in self.plotters.values() for _ in j.iterator())
        )

    def y_max(self):
        """Get maximum y value for the plot."""
        return max(p.y_max() for p in self.plotters.values())
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
# pylint: disable=too-many-instance-attributes
class VarHandler:
    """Handle individual variable logic.

    Lays out the rows of one variable, starting at ``y_start``. There is one
    row per label, model and (unless ``combined``) chain. The class also
    computes the per-row data used by the forest, ridge, ESS and R-hat
    panels.
    """

    def __init__(
        self, var_name, data, y_start, model_names, combined, combine_dims, colors, labeller
    ):
        self.var_name = var_name
        self.data = data
        self.y_start = y_start
        self.model_names = model_names
        self.combined = combined
        self.combine_dims = combine_dims
        self.colors = colors
        self.labeller = labeller
        self.model_color = dict(zip(self.model_names, self.colors))
        # Hand-tuned spacing. NOTE(review): the divisor is the largest *chain
        # coordinate value*, not the number of chains; with 0-based chain
        # coordinates that is n_chains - 1. Confirm this is intended.
        max_chains = max(datum.chain.max().values for datum in data)
        self.chain_offset = len(data) * 0.45 / max(1, max_chains)
        self.var_offset = 1.5 * self.chain_offset
        self.group_offset = 2 * self.var_offset

    def iterator(self):
        """Iterate over models and chains for each variable.

        Yields
        ------
        tuple
            ``(y, row_label, label, selection, values, color)`` for every row.
            ``y`` increases by ``chain_offset`` per row, by ``var_offset``
            after each model and by ``group_offset`` after each label.
        """
        if self.combined:
            grouped_data = [[(0, datum)] for datum in self.data]
            skip_dims = self.combine_dims.union({"chain"})
        else:
            grouped_data = [datum.groupby("chain", squeeze=False) for datum in self.data]
            skip_dims = self.combine_dims

        label_dict = OrderedDict()
        # Selection of the first row that produced each label. Storing it per
        # label, rather than indexing a running list of all selections, keeps
        # it correct when later models or chains contribute labels.
        label_selection = {}
        n_selections = 0
        for name, grouped_datum in zip(self.model_names, grouped_data):
            for _, sub_data in grouped_datum:
                datum_iter = xarray_var_iter(
                    sub_data.squeeze(),
                    var_names=[self.var_name],
                    skip_dims=skip_dims,
                    reverse_selections=True,
                )
                datum_list = list(datum_iter)
                for _, selection, isel, values in datum_list:
                    n_selections += 1
                    # The variable name is shown for scalar rows and for the
                    # last row of each ``datum_list`` block (running count is
                    # a multiple of the block length); other rows omit it.
                    if not selection or not n_selections % len(datum_list):
                        var_name = self.var_name
                    else:
                        var_name = ""
                    label = self.labeller.make_label_flat(var_name, selection, isel)
                    if label not in label_dict:
                        label_dict[label] = OrderedDict()
                        label_selection[label] = selection
                    label_dict[label].setdefault(name, []).append(values)

        y = self.y_start
        for label, model_data in label_dict.items():
            selection = label_selection[label]
            for model_name, value_list in model_data.items():
                row_label = self.labeller.make_model_label(model_name, label)
                color = self.model_color[model_name]
                for values in value_list:
                    yield y, row_label, label, selection, values, color
                    y += self.chain_offset
                y += self.var_offset
            y += self.group_offset

    def labels_ticks_and_vals(self):
        """Get labels, ticks, values, and colors for the variable.

        Rows are grouped by row label, in first-seen order.
        """
        y_ticks = defaultdict(list)
        for y, label, _, _, vals, color in self.iterator():
            y_ticks[label].append((y, vals, color))
        labels, ticks, vals, colors = [], [], [], []
        for label, all_data in y_ticks.items():
            for tick, tick_vals, color in all_data:
                labels.append(label)
                ticks.append(tick)
                vals.append(np.array(tick_vals))
                colors.append(color)  # the colors are all the same
        return labels, ticks, vals, colors

    def treeplot(self, qlist, hdi_prob):
        """Get data for each treeplot for the variable.

        Yields ``(y, selection, ntiles, color)``. ``ntiles`` holds the
        percentiles in ``qlist``, with its first and last entries replaced by
        the HDI bounds.
        """
        for y, _, _, selection, values, color in self.iterator():
            flat = values.flatten()
            ntiles = np.percentile(flat, qlist)
            ntiles[0], ntiles[-1] = hdi(flat, hdi_prob, multimodal=False)
            yield y, selection, ntiles, color

    def ridgeplot(self, hdi_prob, mult, ridgeplot_kind):
        """Get data for each ridgeplot for the variable.

        Yields ``(x, y, y + mult * pdf / scaling, hdi, cumulative_pdf, color)``.
        ``scaling`` is the largest density over all rows of this variable.

        Raises
        ------
        ValueError
            If ``ridgeplot_kind`` is not "auto", "hist" or "density".
        """
        if ridgeplot_kind not in ("auto", "hist", "density"):
            raise ValueError(
                "Argument 'ridgeplot_kind' must be one of 'auto', 'hist' or 'density' "
                f"(you provided {ridgeplot_kind})"
            )
        xvals, hdi_vals, yvals, pdfs, pdfs_q, colors = [], [], [], [], [], []
        for y, *_, values, color in self.iterator():
            yvals.append(y)
            colors.append(color)
            values = values.flatten()
            values = values[np.isfinite(values)]

            if hdi_prob != 1:
                hdi_ = hdi(values, hdi_prob, multimodal=False)
            else:
                hdi_ = values.min(), values.max()

            if ridgeplot_kind == "auto":
                # Integer-valued samples are treated as discrete.
                kind = "hist" if np.all(np.mod(values, 1) == 0) else "density"
            else:
                kind = ridgeplot_kind

            if kind == "hist":
                _, density, x = histogram(values, bins=get_bins(values))
                x = x[:-1]
            else:
                x, density = kde(values)

            density_q = density.cumsum() / density.sum()

            xvals.append(x)
            pdfs.append(density)
            pdfs_q.append(density_q)
            hdi_vals.append(hdi_)

        scaling = max(np.max(j) for j in pdfs)
        for y, x, hdi_val, pdf, pdf_q, color in zip(yvals, xvals, hdi_vals, pdfs, pdfs_q, colors):
            yield x, y, mult * pdf / scaling + y, hdi_val, pdf_q, color

    def ess(self):
        """Get effective n data for the variable."""
        _, y_vals, values, colors = self.labels_ticks_and_vals()
        for y, value, color in zip(y_vals, values, colors):
            yield y, _ess(value), color

    def r_hat(self):
        """Get rhat data for the variable.

        Yields None instead of a value for rows that are not 2-D or have
        fewer than two entries along the first axis.
        """
        _, y_vals, values, colors = self.labels_ticks_and_vals()
        for y, value, color in zip(y_vals, values, colors):
            if value.ndim != 2 or value.shape[0] < 2:
                yield y, None, color
            else:
                yield y, _rhat(value), color

    def y_max(self):
        """Get max y value for the variable, including trailing padding."""
        end_y = max(y for y, *_ in self.iterator())

        if self.combined:
            end_y += self.group_offset

        return end_y + 2 * self.group_offset
|
|
@@ -1,48 +0,0 @@
|
|
|
1
|
-
"""Matplotlib hdiplot."""
|
|
2
|
-
|
|
3
|
-
import matplotlib.pyplot as plt
|
|
4
|
-
from matplotlib import _pylab_helpers
|
|
5
|
-
|
|
6
|
-
from ...plot_utils import _scale_fig_size, vectorized_to_hex
|
|
7
|
-
from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
def plot_hdi(ax, x_data, y_data, color, figsize, plot_kwargs, fill_kwargs, backend_kwargs, show):
    """Matplotlib HDI plot.

    Fills the band between ``y_data[:, 0]`` and ``y_data[:, 1]`` over
    ``x_data``. It also draws ``y_data`` with ``ax.plot``, using ``alpha=0``
    unless ``plot_kwargs`` overrides it.

    Parameters
    ----------
    ax : Axes or None
        Axes to draw on. When None, the current axes of the active pyplot
        figure are reused. If no figure is open, a new single-axes figure is
        created using ``backend_kwargs``.
    x_data, y_data : array-like
        ``y_data`` is indexed as ``y_data[:, 0]`` / ``y_data[:, 1]``, so it
        must be 2-D with at least two columns.
    color : color
        Default color for both the line and the fill.
    figsize : tuple or None
        Scaled with ``_scale_fig_size``. Used only as a default for
        ``backend_kwargs["figsize"]``.
    plot_kwargs, fill_kwargs : dict or None
        Extra keyword arguments for ``ax.plot`` / ``ax.fill_between``.
    backend_kwargs : dict or None
        Merged over ``backend_kwarg_defaults()``. ``squeeze`` is forced to True.
    show : bool or None
        Passed to ``backend_show`` to decide whether to call ``plt.show()``.

    Returns
    -------
    Axes
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot")
    plot_kwargs["color"] = vectorized_to_hex(plot_kwargs.get("color", color))
    plot_kwargs.setdefault("alpha", 0)

    fill_kwargs = matplotlib_kwarg_dealiaser(fill_kwargs, "fill_between")
    fill_kwargs["color"] = vectorized_to_hex(fill_kwargs.get("color", color))
    fill_kwargs.setdefault("alpha", 0.5)

    figsize, *_ = _scale_fig_size(figsize, None)
    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs["squeeze"] = True

    if ax is None:
        # Public-API equivalent of ``_pylab_helpers.Gcf.get_active()``: reuse
        # the current axes only when pyplot already has an open figure, so
        # that no figure is created implicitly here.
        if plt.get_fignums():
            ax = plt.gca()
        else:
            _, ax = create_axes_grid(
                1,
                backend_kwargs=backend_kwargs,
            )

    ax.plot(x_data, y_data, **plot_kwargs)
    ax.fill_between(x_data, y_data[:, 0], y_data[:, 1], **fill_kwargs)

    if backend_show(show):
        plt.show()

    return ax
|