arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -357
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.1.dist-info/METADATA +0 -263
- arviz-0.23.1.dist-info/RECORD +0 -183
- arviz-0.23.1.dist-info/top_level.txt +0 -1
|
@@ -1,772 +0,0 @@
|
|
|
1
|
-
# pylint: disable=all
|
|
2
|
-
"""Bokeh forestplot."""
|
|
3
|
-
from collections import OrderedDict, defaultdict
|
|
4
|
-
from itertools import cycle, tee
|
|
5
|
-
|
|
6
|
-
import bokeh.plotting as bkp
|
|
7
|
-
import matplotlib.pyplot as plt
|
|
8
|
-
import numpy as np
|
|
9
|
-
from bokeh.models import Band, ColumnDataSource, DataRange1d
|
|
10
|
-
from bokeh.models.annotations import Title, Legend
|
|
11
|
-
from bokeh.models.tickers import FixedTicker
|
|
12
|
-
|
|
13
|
-
from ....sel_utils import xarray_var_iter
|
|
14
|
-
from ....rcparams import rcParams
|
|
15
|
-
from ....stats import hdi
|
|
16
|
-
from ....stats.density_utils import get_bins, histogram, kde
|
|
17
|
-
from ....stats.diagnostics import _ess, _rhat
|
|
18
|
-
from ...plot_utils import _scale_fig_size
|
|
19
|
-
from .. import show_layout
|
|
20
|
-
from . import backend_kwarg_defaults
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def pairwise(iterable):
    """Return an iterator over consecutive overlapping pairs of *iterable*.

    itertools recipe: ``[a, b, c, ...] -> (a, b), (b, c), ...``. Inputs with fewer
    than two items produce no pairs.
    """
    heads, tails = tee(iterable)
    # Shift the second copy forward by one so zip lines up item i with item i + 1.
    next(tails, None)
    return zip(heads, tails)
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
def plot_forest(
    ax,
    datasets,
    var_names,
    model_names,
    combined,
    combine_dims,
    colors,
    figsize,
    width_ratios,
    linewidth,
    markersize,
    kind,
    ncols,
    hdi_prob,
    quartiles,
    rope,
    ridgeplot_overlap,
    ridgeplot_alpha,
    ridgeplot_kind,
    ridgeplot_truncate,
    ridgeplot_quantiles,
    textsize,
    legend,
    labeller,
    ess,
    r_hat,
    backend_config,
    backend_kwargs,
    show,
):
    """Bokeh forest plot.

    Unless ``ax`` is given, one row of ``ncols`` bokeh figures is created. The first
    panel holds the forest plot or ridge plot. Optional ESS and R-hat panels follow,
    in that order.

    Returns
    -------
    numpy.ndarray
        The figures used, wrapped with ``np.atleast_2d``.

    Raises
    ------
    TypeError
        If ``kind`` is neither ``"forestplot"`` nor ``"ridgeplot"``.
    """
    plot_handler = PlotHandler(
        datasets,
        var_names=var_names,
        model_names=model_names,
        combined=combined,
        combine_dims=combine_dims,
        colors=colors,
        labeller=labeller,
    )

    # Validate before any figure is created.
    if kind not in ("forestplot", "ridgeplot"):
        raise TypeError(
            f"Argument 'kind' must be one of 'forestplot' or 'ridgeplot' (you provided {kind})"
        )

    if figsize is None:
        if kind == "ridgeplot":
            figsize = (min(14, sum(width_ratios) * 3), plot_handler.fig_height() * 3)
        else:
            figsize = (min(12, sum(width_ratios) * 2), plot_handler.fig_height())

    (figsize, _, _, _, auto_linewidth, auto_markersize) = _scale_fig_size(figsize, textsize, 1.1, 1)

    if linewidth is None:
        linewidth = auto_linewidth

    if markersize is None:
        markersize = auto_markersize

    if backend_config is None:
        backend_config = {}

    backend_config = {
        **backend_kwarg_defaults(
            ("bounds_x_range", "plot.bokeh.bounds_x_range"),
            ("bounds_y_range", "plot.bokeh.bounds_y_range"),
        ),
        **backend_config,
    }

    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(
            ("dpi", "plot.bokeh.figure.dpi"),
        ),
        **backend_kwargs,
    }
    dpi = backend_kwargs.pop("dpi")

    if ax is None:
        axes = []

        for i, width_r in zip(range(ncols), width_ratios):
            backend_kwargs_i = backend_kwargs.copy()
            backend_kwargs_i.setdefault("height", int(figsize[1] * dpi))
            backend_kwargs_i.setdefault(
                "width", int(figsize[0] * (width_r / sum(width_ratios)) * dpi * 1.25)
            )
            ax = bkp.figure(
                **backend_kwargs_i,
            )
            if i == 0:
                backend_kwargs.setdefault("y_range", ax.y_range)
            axes.append(ax)
    else:
        axes = ax

    axes = np.atleast_2d(axes)

    all_plotters = list(plot_handler.plotters.values())
    y_max = plot_handler.y_max() - all_plotters[-1].group_offset
    if kind == "ridgeplot":  # space at the top
        y_max += ridgeplot_overlap

    # A single y range is shared by all panels, so the ESS / R-hat rows stay aligned
    # with the main panel when panning or zooming.
    y_range = DataRange1d(bounds=backend_config["bounds_y_range"], min_interval=2)
    y_range._property_values["start"] = -all_plotters[  # pylint: disable=protected-access
        0
    ].group_offset
    y_range._property_values["end"] = y_max  # pylint: disable=protected-access

    # Panels are styled and given their range objects BEFORE anything is drawn.
    # plot_neff / plot_rhat write their x-limits into ``ax.x_range``. Replacing the
    # range objects afterwards (the previous ordering) silently discarded those limits.
    for i, ax_ in enumerate(axes.ravel()):
        if kind == "ridgeplot":
            ax_.xgrid.grid_line_color = None
        ax_.ygrid.grid_line_color = None

        if i != 0:
            ax_.yaxis.visible = False

        ax_.outline_line_color = None
        ax_.x_range = DataRange1d(bounds=backend_config["bounds_x_range"], min_interval=1)
        ax_.y_range = y_range

    plotted = defaultdict(list)

    if kind == "forestplot":
        plot_handler.forestplot(
            hdi_prob,
            quartiles,
            linewidth,
            markersize,
            axes[0, 0],
            rope,
            plotted,
        )
    else:
        plot_handler.ridgeplot(
            hdi_prob,
            ridgeplot_overlap,
            linewidth,
            markersize,
            ridgeplot_alpha,
            ridgeplot_kind,
            ridgeplot_truncate,
            ridgeplot_quantiles,
            axes[0, 0],
            plotted,
        )

    # Diagnostic panels occupy the columns after the main panel, in this order.
    idx = 1
    if ess:
        plotted_ess = defaultdict(list)
        plot_handler.plot_neff(axes[0, idx], markersize, plotted_ess)
        if legend:
            plot_handler.legend(axes[0, idx], plotted_ess)
        idx += 1

    if r_hat:
        plotted_r_hat = defaultdict(list)
        plot_handler.plot_rhat(axes[0, idx], markersize, plotted_r_hat)
        if legend:
            plot_handler.legend(axes[0, idx], plotted_r_hat)
        idx += 1

    labels, ticks = plot_handler.labels_and_ticks()
    # Whole-number ticks are written as ints so their string keys match the tick labels.
    ticks = [int(tick) if float(tick).is_integer() else tick for tick in ticks]

    axes[0, 0].yaxis.ticker = FixedTicker(ticks=ticks)
    axes[0, 0].yaxis.major_label_overrides = dict(zip(map(str, ticks), map(str, labels)))

    if legend:
        plot_handler.legend(axes[0, 0], plotted)
    show_layout(axes, show)

    return axes
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
class PlotHandler:
    """Class to handle logic from ForestPlot."""

    # pylint: disable=inconsistent-return-statements

    def __init__(self, datasets, var_names, model_names, combined, combine_dims, colors, labeller):
        """Store the datasets and options, and build one VarHandler per variable.

        Parameters
        ----------
        datasets : list
            One dataset per model; each must expose ``data_vars``.
        var_names : list of str or None
            Variables to plot; when None, every variable found in ``datasets``.
        model_names : list of str or None
            One name per dataset; generated ("Model 0", ...) when None and there is
            more than one dataset.
        combined : bool
            Forwarded to each VarHandler.
        combine_dims : set
            Forwarded to each VarHandler.
        colors : str or list
            ``"cycle"`` for the matplotlib colour cycle, a single colour string used for
            every model, or one colour per model.
        labeller
            Object used by VarHandler to build row labels.

        Raises
        ------
        ValueError
            If ``model_names`` does not have one entry per dataset.
        """
        self.data = datasets

        if model_names is None:
            if len(self.data) > 1:
                model_names = [f"Model {idx}" for idx, _ in enumerate(self.data)]
            else:
                model_names = [""]
        elif len(model_names) != len(self.data):
            raise ValueError("The number of model names does not match the number of models")

        # NOTE(review): model names and colours are reversed here but ``datasets`` is not;
        # presumably the caller already passes the datasets in reversed order — confirm.
        self.model_names = list(reversed(model_names))  # y-values are upside down

        if var_names is None:
            self.var_names = self._default_var_names(self.data)
        else:
            self.var_names = list(reversed(var_names))  # y-values are upside down

        self.combined = combined
        self.combine_dims = combine_dims

        # The isinstance guard avoids an elementwise ``==`` when colors is array-like.
        if isinstance(colors, str):
            if colors == "cycle":
                colors = [
                    prop
                    for _, prop in zip(
                        range(len(self.data)),
                        cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"]),
                    )
                ]
            else:
                colors = [colors for _ in self.data]

        self.colors = list(reversed(colors))  # y-values are upside down
        self.labeller = labeller

        self.plotters = self.make_plotters()

    @staticmethod
    def _default_var_names(datasets):
        """Return every variable name in ``datasets``, reversed for the upside-down y axis.

        Names are deduplicated in first-seen order (first dataset first), so the result
        is deterministic. A plain ``set`` union depends on string hashing and can change
        order between interpreter runs.
        """
        first_seen = OrderedDict.fromkeys(
            name for datum in datasets for name in datum.data_vars
        )
        return list(reversed(first_seen))  # y-values are upside down

    @staticmethod
    def _rope_color():
        """Return the third colour of the matplotlib property cycle (wrapping if shorter)."""
        return [
            color
            for _, color in zip(range(3), cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"]))
        ][2]

    def make_plotters(self):
        """Create one VarHandler per variable, stacked vertically.

        Each handler starts at the ``y_max()`` of the handler before it.
        """
        plotters, y = {}, 0
        for var_name in self.var_names:
            plotters[var_name] = VarHandler(
                var_name,
                self.data,
                y,
                model_names=self.model_names,
                combined=self.combined,
                combine_dims=self.combine_dims,
                colors=self.colors,
                labeller=self.labeller,
            )
            y = plotters[var_name].y_max()
        return plotters

    def labels_and_ticks(self):
        """Collect one tick per distinct label from every plotter.

        Rows that share a label are collapsed into one tick, placed at the mean of
        their y positions.

        Returns
        -------
        labels, ticks : numpy.ndarray
            Concatenated labels and matching y positions over all plotters.
        """
        labels, idxs = [], []
        for plotter in self.plotters.values():
            sub_labels, sub_idxs, _, _, _ = plotter.labels_ticks_and_vals()
            labels_to_idxs = defaultdict(list)
            for label, idx in zip(sub_labels, sub_idxs):
                labels_to_idxs[label].append(idx)
            labels.append(list(labels_to_idxs))
            idxs.append([np.mean(all_idx) for all_idx in labels_to_idxs.values()])
        return np.concatenate(labels), np.concatenate(idxs)

    def legend(self, ax, plotted):
        """Attach an interactive, click-to-hide legend above ``ax``.

        ``plotted`` maps each model name to the glyph renderers drawn for that model.
        """
        legend = Legend(items=list(plotted.items()), orientation="vertical", location="top_left")
        ax.add_layout(legend, "above")
        ax.legend.click_policy = "hide"

    def display_multiple_ropes(
        self, rope, ax, y, linewidth, var_name, selection, plotted, model_name
    ):
        """Draw each ROPE in ``rope[var_name]`` whose coordinates match ``selection``.

        ``rope[var_name]`` is a list of dicts. Each dict holds a ``"rope"`` (lo, hi)
        pair plus optional coordinate values. The interval is drawn on row ``y`` only
        when every one of those coordinate values equals the entry in ``selection``.
        """
        rope_color = self._rope_color()
        for sel in rope.get(var_name, []):
            # pylint: disable=line-too-long
            if all(k in selection and selection[k] == v for k, v in sel.items() if k != "rope"):
                vals = sel["rope"]
                plotted[model_name].append(
                    ax.line(
                        vals,
                        (y + 0.05, y + 0.05),
                        line_width=linewidth * 2,
                        color=rope_color,
                        line_alpha=0.7,
                    )
                )
        return ax

    def ridgeplot(
        self,
        hdi_prob,
        mult,
        linewidth,
        markersize,
        alpha,
        ridgeplot_kind,
        ridgeplot_truncate,
        ridgeplot_quantiles,
        ax,
        plotted,
    ):
        """Draw ridgeplot for each plotter.

        Parameters
        ----------
        hdi_prob : float
            Probability for the highest density interval.
        mult : float
            How much to multiply height by. Set this to greater than 1 to have some overlap.
        linewidth : float
            Width of line on border of ridges
        markersize : float
            Size of marker in center of forestplot line
        alpha : float
            Transparency of ridges
        ridgeplot_kind : string
            By default ("auto") continuous variables are plotted using KDEs and discrete ones using
            histograms. To override this use "hist" to plot histograms and "density" for KDEs
        ridgeplot_truncate: bool
            Whether to truncate densities according to the value of hdi_prob. Defaults to True
        ridgeplot_quantiles: list
            Quantiles in ascending order used to segment the KDE. Use [.25, .5, .75] for quartiles.
            Defaults to None.
        ax : Axes
            Axes to draw on
        plotted : dict
            Contains glyphs for each model
        """
        if alpha is None:
            alpha = 1.0
        for plotter in list(self.plotters.values())[::-1]:
            for x, y_min, y_max, hdi_, y_q, color, model_name in plotter.ridgeplot(
                hdi_prob, mult, ridgeplot_kind
            ):
                if alpha == 0:
                    border = color
                    facecolor = None
                else:
                    border = "black"
                    facecolor = color
                # ``y_min`` is the ridge's baseline; ``y_max`` is the absolute top of the
                # curve (the density branch below plots it directly as a coordinate).
                baseline = y_min
                # ``y_q`` is aligned with the untruncated grid; keep it for the quantile markers.
                x_full = x
                in_hdi = (x >= hdi_[0]) & (x <= hdi_[1])
                if x.dtype.kind == "i":
                    # Per-ridge alpha. Rebinding ``alpha`` itself would leak this ridge's
                    # per-bar list into every later ridge.
                    fill_alpha = alpha
                    if ridgeplot_truncate:
                        y_max = y_max[in_hdi]
                        x = x[in_hdi]
                    else:
                        facecolor = color
                        fill_alpha = [alpha if ci else 0 for ci in in_hdi]
                    y_min = np.ones_like(x) * baseline
                    plotted[model_name].append(
                        ax.vbar(
                            x=x,
                            top=y_max,  # bokeh's ``top`` is a y coordinate, not a height
                            bottom=y_min,
                            width=0.9,
                            line_color=border,
                            color=facecolor,
                            fill_alpha=fill_alpha,
                        )
                    )
                else:
                    tr_x = x[in_hdi]
                    tr_y_min = np.ones_like(tr_x) * baseline
                    tr_y_max = y_max[in_hdi]
                    y_min = np.ones_like(x) * baseline
                    patch = ax.patch(
                        np.concatenate([tr_x, tr_x[::-1]]),
                        np.concatenate([tr_y_min, tr_y_max[::-1]]),
                        fill_color=color,
                        fill_alpha=alpha,
                        line_width=0,
                    )
                    patch.level = "overlay"
                    plotted[model_name].append(patch)
                    # NOTE(review): for densities, truncate=True draws the full outline,
                    # while the histogram branch above drops bins outside the HDI. The two
                    # branches look inverted relative to each other; confirm intended.
                    if ridgeplot_truncate:
                        line_x, line_top, line_bottom = x, y_max, y_min
                    else:
                        line_x, line_top, line_bottom = tr_x, tr_y_max, tr_y_min
                    plotted[model_name].append(
                        ax.line(
                            line_x,
                            line_top,
                            line_dash="solid",
                            line_width=linewidth,
                            line_color=border,
                        )
                    )
                    plotted[model_name].append(
                        ax.line(
                            line_x,
                            line_bottom,
                            line_dash="solid",
                            line_width=linewidth,
                            line_color=border,
                        )
                    )
                if ridgeplot_quantiles is not None:
                    quantiles = [x_full[np.sum(y_q < quant)] for quant in ridgeplot_quantiles]
                    plotted[model_name].append(
                        ax.diamond(
                            quantiles,
                            np.ones_like(quantiles) * baseline,
                            line_color="black",
                            fill_color="black",
                            size=markersize,
                        )
                    )

        return ax

    def forestplot(self, hdi_prob, quartiles, linewidth, markersize, ax, rope, plotted):
        """Draw forestplot for each plotter.

        Parameters
        ----------
        hdi_prob : float
            Probability for the highest density interval; sets the extent of the outer line.
        quartiles : bool
            Whether to mark quartiles
        linewidth : float
            Width of forestplot line
        markersize : float
            Size of marker in center of forestplot line
        ax : Axes
            Axes to draw on
        rope : None, dict or sequence of length 2
            None for no ROPE. A ``(lo, hi)`` pair is drawn as one shaded band across the
            plot. A dict ``{var_name: [{"rope": (lo, hi), **coords}, ...]}`` is drawn on
            each matching row.
        plotted : dict
            Contains glyphs for each model

        Raises
        ------
        ValueError
            If ``rope`` is not one of the accepted forms.
        """
        if rope is None or isinstance(rope, dict):
            pass
        elif len(rope) == 2:
            # The band's vertical extent is twice the plot height, so it covers the
            # whole visible area.
            cds = ColumnDataSource(
                {
                    "x": rope,
                    "lower": [-2 * self.y_max(), -2 * self.y_max()],
                    "upper": [self.y_max() * 2, self.y_max() * 2],
                }
            )

            band = Band(
                base="x",
                lower="lower",
                upper="upper",
                fill_color=self._rope_color(),
                line_alpha=0.5,
                source=cds,
            )

            ax.renderers.append(band)
        else:
            raise ValueError(
                "Argument `rope` must be None, a dictionary like "
                '{"var_name": [{"rope": (lo, hi)}]}, or an '
                "iterable of length 2"
            )
        # Quantiles to be calculated
        endpoint = 100 * (1 - hdi_prob) / 2
        if quartiles:
            qlist = [endpoint, 25, 50, 75, 100 - endpoint]
        else:
            qlist = [endpoint, 50, 100 - endpoint]

        for plotter in self.plotters.values():
            for y, model_name, selection, values, color in plotter.treeplot(qlist, hdi_prob):
                if isinstance(rope, dict):
                    self.display_multiple_ropes(
                        rope, ax, y, linewidth, plotter.var_name, selection, plotted, model_name
                    )

                # Nested intervals: the outermost pair (j = 0) gets the thinnest line.
                mid = len(values) // 2
                param_iter = zip(
                    np.linspace(2 * linewidth, linewidth, mid, endpoint=True)[-1::-1], range(mid)
                )
                for width, j in param_iter:
                    plotted[model_name].append(
                        ax.line(
                            [values[j], values[-(j + 1)]],
                            [y, y],
                            line_width=width,
                            line_color=color,
                        )
                    )
                plotted[model_name].append(
                    ax.scatter(
                        x=values[mid],
                        y=y,
                        marker="circle",
                        size=markersize * 0.75,
                        fill_color=color,
                    )
                )
        _title = Title()
        _title.text = f"{hdi_prob:.1%} HDI"
        ax.title = _title

        return ax

    def plot_neff(self, ax, markersize, plotted):
        """Draw effective n for each plotter; the x axis spans 0 to 1.07 * max ESS."""
        max_ess = 0
        for plotter in self.plotters.values():
            for y, ess, color, model_name in plotter.ess():
                if ess is not None:
                    plotted[model_name].append(
                        ax.scatter(
                            x=ess,
                            y=y,
                            marker="circle",
                            fill_color=color,
                            size=markersize,
                            line_color="black",
                        )
                    )
                if ess > max_ess:
                    max_ess = ess
        ax.x_range._property_values["start"] = 0  # pylint: disable=protected-access
        ax.x_range._property_values["end"] = 1.07 * max_ess  # pylint: disable=protected-access

        _title = Title()
        _title.text = "ess"
        ax.title = _title

        ax.xaxis[0].ticker.desired_num_ticks = 3

        return ax

    def plot_rhat(self, ax, markersize, plotted):
        """Draw r-hat for each plotter; the x axis spans 0.9 to 2.1."""
        for plotter in self.plotters.values():
            for y, r_hat, color, model_name in plotter.r_hat():
                if r_hat is not None:
                    plotted[model_name].append(
                        ax.scatter(
                            x=r_hat,
                            y=y,
                            marker="circle",
                            fill_color=color,
                            size=markersize,
                            line_color="black",
                        )
                    )
        ax.x_range._property_values["start"] = 0.9  # pylint: disable=protected-access
        ax.x_range._property_values["end"] = 2.1  # pylint: disable=protected-access

        _title = Title()
        _title.text = "r_hat"
        ax.title = _title

        ax.xaxis[0].ticker.desired_num_ticks = 3

        return ax

    def fig_height(self):
        """Figure out the height of this plot."""
        # hand-tuned
        return (
            4
            + len(self.data) * len(self.var_names)
            - 1
            + 0.1 * sum(1 for j in self.plotters.values() for _ in j.iterator())
        )

    def y_max(self):
        """Get maximum y value for the plot."""
        return max(p.y_max() for p in self.plotters.values())
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
class VarHandler:
|
|
623
|
-
"""Handle individual variable logic."""
|
|
624
|
-
|
|
625
|
-
def __init__(
    self, var_name, data, y_start, model_names, combined, combine_dims, colors, labeller
):
    """Store the per-variable plotting state and derive the vertical spacing.

    ``model_color`` pairs ``model_names`` with ``colors`` position by position.
    ``chain_offset`` is the step between consecutive rows, ``var_offset`` is added
    after each model's rows within a label, and ``group_offset`` after each label.
    """
    # Configuration, kept as given.
    self.var_name, self.data, self.y_start = var_name, data, y_start
    self.model_names, self.colors = model_names, colors
    self.combined, self.combine_dims = combined, combine_dims
    self.labeller = labeller
    self.model_color = dict(zip(self.model_names, self.colors))

    # Spacing grows with the number of datasets and shrinks with the largest value of
    # any dataset's ``chain`` coordinate (floored at 1).
    largest_chain = max(datum.chain.max().values for datum in data)
    chain_step = len(data) * 0.45 / max(1, largest_chain)
    self.chain_offset = chain_step
    self.var_offset = 1.5 * chain_step
    self.group_offset = 2 * self.var_offset
|
|
641
|
-
|
|
642
|
-
def iterator(self):
    """Yield one row per (label, model, chain/group) of this variable.

    Each item is ``(y, row_label, model_name, label, selection, values, color)``.
    ``y`` starts at ``self.y_start``. It grows by ``chain_offset`` after each row,
    by ``var_offset`` after each model within a label, and by ``group_offset`` after
    each label. The ``selection`` yielded for the n-th label is the n-th selection
    seen while collecting rows.
    """
    if self.combined:
        # One pseudo-group per dataset; chains are folded in through the skipped dims.
        per_model_groups = [[(0, datum)] for datum in self.data]
        dims_to_skip = self.combine_dims.union({"chain"})
    else:
        per_model_groups = [datum.groupby("chain", squeeze=False) for datum in self.data]
        dims_to_skip = self.combine_dims

    # label -> OrderedDict(model name -> list of value arrays), in first-seen order
    rows_by_label = OrderedDict()
    seen_selections = []
    for model, groups in zip(self.model_names, per_model_groups):
        for _, chunk in groups:
            entries = list(
                xarray_var_iter(
                    chunk.squeeze(),
                    var_names=[self.var_name],
                    skip_dims=dims_to_skip,
                    reverse_selections=True,
                )
            )
            n_entries = len(entries)
            for _, selection, isel, values in entries:
                seen_selections.append(selection)
                # The variable name is only put in the label when there is no selection,
                # or when the running selection count is a multiple of this group's size.
                show_name = not selection or len(seen_selections) % n_entries == 0
                label = self.labeller.make_label_flat(
                    self.var_name if show_name else "", selection, isel
                )
                rows_by_label.setdefault(label, OrderedDict()).setdefault(model, []).append(
                    values
                )

    y = self.y_start
    for position, (label, per_model) in enumerate(rows_by_label.items()):
        for model, value_arrays in per_model.items():
            row_label = self.labeller.make_model_label(model, label)
            for values in value_arrays:
                yield (
                    y,
                    row_label,
                    model,
                    label,
                    seen_selections[position],
                    values,
                    self.model_color[model],
                )
                y += self.chain_offset
            y += self.var_offset
        y += self.group_offset
|
|
686
|
-
|
|
687
|
-
def labels_ticks_and_vals(self):
    """Get labels, ticks, values, and colors for the variable.

    Rows from ``self.iterator()`` are regrouped by their row label. Labels keep
    the order in which they first appear; rows keep their order within a label.

    Returns
    -------
    tuple of lists
        ``(labels, ticks, vals, colors, model_names)``, all the same length,
        one entry per row. ``vals`` entries are converted with ``np.array``.
        Colors come from the iterator and may differ between models.
    """
    rows_by_label = defaultdict(list)
    for y_pos, row_label, model, _, _, row_values, row_color in self.iterator():
        rows_by_label[row_label].append((y_pos, row_values, row_color, model))

    flat_rows = [
        (row_label, row) for row_label, rows in rows_by_label.items() for row in rows
    ]
    labels = [row_label for row_label, _ in flat_rows]
    ticks = [row[0] for _, row in flat_rows]
    vals = [np.array(row[1]) for _, row in flat_rows]
    colors = [row[2] for _, row in flat_rows]
    model_names = [row[3] for _, row in flat_rows]
    return labels, ticks, vals, colors, model_names
|
|
701
|
-
|
|
702
|
-
def treeplot(self, qlist, hdi_prob):
    """Yield the quantile summary drawn for each forest-plot row.

    Percentiles at ``qlist`` are computed on the flattened row values. The
    outermost two entries are then replaced by the bounds returned by ``hdi``
    at ``hdi_prob`` (unimodal).

    Yields
    ------
    tuple
        ``(y, model_name, selection, ntiles, color)`` for each row.
    """
    for y_pos, _row_label, model, _label, sel, row_values, row_color in self.iterator():
        flat = row_values.flatten()
        ntiles = np.percentile(flat, qlist)
        lower, upper = hdi(flat, hdi_prob, multimodal=False)
        ntiles[0] = lower
        ntiles[-1] = upper
        yield y_pos, model, sel, ntiles, row_color
|
|
708
|
-
|
|
709
|
-
def ridgeplot(self, hdi_prob, mult, ridgeplot_kind):
    """Get data for each ridgeplot for the variable.

    Parameters
    ----------
    hdi_prob : float
        Probability passed to ``hdi`` for the interval of each ridge. When it
        equals ``1`` the min/max of the finite values is used instead.
    mult : float
        Height multiplier applied to every density curve after all curves are
        scaled by the largest density value across rows.
    ridgeplot_kind : {"auto", "hist", "density"}
        ``"hist"`` uses a histogram and ``"density"`` uses ``kde``. ``"auto"``
        picks ``"hist"`` when all finite values are integral, else ``"density"``.

    Yields
    ------
    tuple
        ``(x, y, mult * pdf / scaling + y, hdi, pdf_cumulative, color, model_name)``
        per row, where ``pdf_cumulative`` is the normalized cumulative sum of ``pdf``.

    Raises
    ------
    ValueError
        If ``ridgeplot_kind`` is not one of the supported values.
    """
    # Without this check an unknown kind matches neither branch below and the
    # generator fails later with an opaque UnboundLocalError on ``x``/``density``.
    if ridgeplot_kind not in ("auto", "hist", "density"):
        raise ValueError(
            'ridgeplot_kind must be one of "auto", "hist" or "density", '
            f"got {ridgeplot_kind!r}"
        )

    xvals, hdi_vals, yvals, pdfs, pdfs_q, colors, model_names = [], [], [], [], [], [], []

    for y, _, model_name, *_, values, color in self.iterator():
        yvals.append(y)
        colors.append(color)
        model_names.append(model_name)
        values = values.flatten()
        values = values[np.isfinite(values)]

        if hdi_prob != 1:
            hdi_ = hdi(values, hdi_prob, multimodal=False)
        else:
            hdi_ = min(values), max(values)

        if ridgeplot_kind == "auto":
            kind = "hist" if np.all(np.mod(values, 1) == 0) else "density"
        else:
            kind = ridgeplot_kind

        if kind == "hist":
            bins = get_bins(values)
            _, density, x = histogram(values, bins=bins)
            # Use the left bin edges so x and density have the same length.
            x = x[:-1]
        else:
            x, density = kde(values)

        density_q = density.cumsum() / density.sum()

        xvals.append(x)
        pdfs.append(density)
        pdfs_q.append(density_q)
        hdi_vals.append(hdi_)

    # A single scale across all rows keeps ridge heights comparable.
    scaling = max(np.max(pdf) for pdf in pdfs)
    for y, x, hdi_val, pdf, pdf_q, color, model_name in zip(
        yvals, xvals, hdi_vals, pdfs, pdfs_q, colors, model_names
    ):
        yield x, y, mult * pdf / scaling + y, hdi_val, pdf_q, color, model_name
|
|
749
|
-
|
|
750
|
-
def ess(self):
    """Yield effective-sample-size data for every row of the variable.

    Yields
    ------
    tuple
        ``(y, ess, color, model_name)``, where ``ess`` is ``_ess`` applied to
        the row's values as returned by ``labels_ticks_and_vals``.
    """
    _, positions, samples, row_colors, row_models = self.labels_ticks_and_vals()
    for position, sample, row_color, row_model in zip(
        positions, samples, row_colors, row_models
    ):
        yield position, _ess(sample), row_color, row_model
|
|
755
|
-
|
|
756
|
-
def r_hat(self):
    """Yield R-hat data for every row of the variable.

    Yields
    ------
    tuple
        ``(y, rhat, color, model_name)``. ``rhat`` is ``None`` unless the row's
        values form a 2-D array with at least two entries along the first axis
        (presumably chains). Otherwise ``rhat`` is ``_rhat`` of those values.
    """
    _, positions, samples, row_colors, row_models = self.labels_ticks_and_vals()
    for position, sample, row_color, row_model in zip(
        positions, samples, row_colors, row_models
    ):
        computable = sample.ndim == 2 and sample.shape[0] >= 2
        yield position, (_rhat(sample) if computable else None), row_color, row_model
|
|
764
|
-
|
|
765
|
-
def y_max(self):
    """Return the top y coordinate needed to draw this variable.

    This is the largest row position produced by ``self.iterator()`` plus two
    group offsets of padding. Combined data gets one further group offset.
    """
    top = max(row[0] for row in self.iterator())
    if self.combined:
        top = top + self.group_offset
    return top + 2 * self.group_offset
|