arviz 0.16.1__py3-none-any.whl → 0.17.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +1 -1
- arviz/data/inference_data.py +34 -7
- arviz/data/io_beanmachine.py +6 -1
- arviz/data/io_cmdstanpy.py +439 -50
- arviz/data/io_pyjags.py +5 -2
- arviz/data/io_pystan.py +1 -2
- arviz/labels.py +2 -0
- arviz/plots/backends/bokeh/bpvplot.py +7 -2
- arviz/plots/backends/bokeh/compareplot.py +7 -4
- arviz/plots/backends/bokeh/densityplot.py +0 -1
- arviz/plots/backends/bokeh/distplot.py +0 -2
- arviz/plots/backends/bokeh/forestplot.py +3 -5
- arviz/plots/backends/bokeh/kdeplot.py +0 -2
- arviz/plots/backends/bokeh/pairplot.py +0 -4
- arviz/plots/backends/matplotlib/bfplot.py +0 -1
- arviz/plots/backends/matplotlib/bpvplot.py +3 -3
- arviz/plots/backends/matplotlib/compareplot.py +1 -1
- arviz/plots/backends/matplotlib/dotplot.py +1 -1
- arviz/plots/backends/matplotlib/forestplot.py +2 -4
- arviz/plots/backends/matplotlib/kdeplot.py +0 -1
- arviz/plots/backends/matplotlib/khatplot.py +0 -1
- arviz/plots/backends/matplotlib/lmplot.py +4 -5
- arviz/plots/backends/matplotlib/pairplot.py +0 -1
- arviz/plots/backends/matplotlib/ppcplot.py +8 -5
- arviz/plots/backends/matplotlib/traceplot.py +1 -2
- arviz/plots/bfplot.py +7 -6
- arviz/plots/bpvplot.py +7 -2
- arviz/plots/compareplot.py +2 -2
- arviz/plots/ecdfplot.py +37 -112
- arviz/plots/elpdplot.py +1 -1
- arviz/plots/essplot.py +2 -2
- arviz/plots/kdeplot.py +0 -1
- arviz/plots/pairplot.py +1 -1
- arviz/plots/plot_utils.py +0 -1
- arviz/plots/ppcplot.py +51 -45
- arviz/plots/separationplot.py +0 -1
- arviz/stats/__init__.py +2 -0
- arviz/stats/density_utils.py +2 -2
- arviz/stats/diagnostics.py +2 -3
- arviz/stats/ecdf_utils.py +165 -0
- arviz/stats/stats.py +241 -38
- arviz/stats/stats_utils.py +36 -7
- arviz/tests/base_tests/test_data.py +73 -5
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1
- arviz/tests/base_tests/test_plots_matplotlib.py +24 -1
- arviz/tests/base_tests/test_stats.py +43 -1
- arviz/tests/base_tests/test_stats_ecdf_utils.py +153 -0
- arviz/tests/base_tests/test_stats_utils.py +3 -3
- arviz/tests/external_tests/test_data_beanmachine.py +2 -0
- arviz/tests/external_tests/test_data_numpyro.py +3 -3
- arviz/tests/external_tests/test_data_pyjags.py +3 -1
- arviz/tests/external_tests/test_data_pyro.py +3 -3
- arviz/tests/helpers.py +8 -8
- arviz/utils.py +15 -7
- arviz/wrappers/wrap_pymc.py +1 -1
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/METADATA +16 -15
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/RECORD +60 -58
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/WHEEL +1 -1
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/LICENSE +0 -0
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/top_level.txt +0 -0

arviz/plots/backends/bokeh/compareplot.py CHANGED

@@ -69,13 +69,14 @@ def plot_compare(
             err_ys.append((y, y))

         # plot them
-        dif_tri = ax.
+        dif_tri = ax.scatter(
             comp_df[information_criterion].iloc[1:],
             yticks_pos[1::2],
             line_color=plot_kwargs.get("color_dse", "grey"),
             fill_color=plot_kwargs.get("color_dse", "grey"),
             line_width=2,
             size=6,
+            marker="triangle",
         )
         dif_line = ax.multi_line(err_xs, err_ys, line_color=plot_kwargs.get("color_dse", "grey"))

@@ -85,13 +86,14 @@ def plot_compare(
     ax.yaxis.ticker = yticks_pos[::2]
     ax.yaxis.major_label_overrides = dict(zip(yticks_pos[::2], yticks_labels))

-    elpd_circ = ax.
+    elpd_circ = ax.scatter(
         comp_df[information_criterion],
         yticks_pos[::2],
         line_color=plot_kwargs.get("color_ic", "black"),
         fill_color=None,
         line_width=2,
         size=6,
+        marker="circle",
     )
     elpd_label = [elpd_circ]

@@ -110,7 +112,7 @@ def plot_compare(

     labels.append(("ELPD", elpd_label))

-    scale = comp_df["scale"][0]
+    scale = comp_df["scale"].iloc[0]

     if insample_dev:
         p_ic = comp_df[f"p_{information_criterion.split('_')[1]}"]

@@ -120,13 +122,14 @@ def plot_compare(
             correction = -p_ic
         elif scale == "deviance":
             correction = -(2 * p_ic)
-        insample_circ = ax.
+        insample_circ = ax.scatter(
             comp_df[information_criterion] + correction,
             yticks_pos[::2],
             line_color=plot_kwargs.get("color_insample_dev", "black"),
             fill_color=plot_kwargs.get("color_insample_dev", "black"),
             line_width=2,
             size=6,
+            marker="circle",
        )
        labels.append(("In-sample ELPD", [insample_circ]))
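The glyph calls above follow bokeh's consolidation of the per-shape helpers (figure.triangle, figure.circle, ...) into a single figure.scatter call that takes a marker argument; the original call sites are truncated in this diff view, but the added marker= keywords show the migration. The scale lookup also switches from positional Series indexing to the explicit .iloc[0]. A minimal, self-contained sketch of the new glyph pattern (the figure, data and colors here are made up for illustration):

    import numpy as np
    from bokeh.plotting import figure, show

    p = figure(width=400, height=300)
    x = np.arange(5)

    # one scatter() call per marker shape instead of p.triangle()/p.circle()
    p.scatter(x, x**2, marker="triangle", size=6, line_color="grey", fill_color="grey", line_width=2)
    p.scatter(x, x**2 + 1, marker="circle", size=6, line_color="black", fill_color=None, line_width=2)

    show(p)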
@@ -145,12 +145,10 @@ def _histplot_bokeh_op(values, values2, rotated, ax, hist_kwargs, is_circular):
|
|
|
145
145
|
edges = edges.astype(float) - 0.5
|
|
146
146
|
|
|
147
147
|
if is_circular:
|
|
148
|
-
|
|
149
148
|
if is_circular == "degrees":
|
|
150
149
|
edges = np.deg2rad(edges)
|
|
151
150
|
labels = ["0°", "45°", "90°", "135°", "180°", "225°", "270°", "315°"]
|
|
152
151
|
else:
|
|
153
|
-
|
|
154
152
|
labels = [
|
|
155
153
|
r"0",
|
|
156
154
|
r"π/4",
|
|

arviz/plots/backends/bokeh/forestplot.py CHANGED

@@ -15,7 +15,6 @@ from ....rcparams import rcParams
 from ....stats import hdi
 from ....stats.density_utils import get_bins, histogram, kde
 from ....stats.diagnostics import _ess, _rhat
-from ....utils import conditional_jit
 from ...plot_utils import _scale_fig_size
 from .. import show_layout
 from . import backend_kwarg_defaults

@@ -277,7 +276,6 @@ class PlotHandler:
         """Collect labels and ticks from plotters."""
         val = self.plotters.values()

-        @conditional_jit(forceobj=True, nopython=False)
         def label_idxs():
             labels, idxs = [], []
             for plotter in val:

@@ -299,7 +297,7 @@ class PlotHandler:
     def legend(self, ax, plotted):
         """Add interactive legend with colorcoded model info."""
         legend_it = []
-        for
+        for model_name, glyphs in plotted.items():
             legend_it.append((model_name, glyphs))

         legend = Legend(items=legend_it, orientation="vertical", location="top_left")

@@ -640,7 +638,7 @@ class VarHandler:
             grouped_data = [[(0, datum)] for datum in self.data]
             skip_dims = self.combine_dims.union({"chain"})
         else:
-            grouped_data = [datum.groupby("chain") for datum in self.data]
+            grouped_data = [datum.groupby("chain", squeeze=False) for datum in self.data]
             skip_dims = self.combine_dims

         label_dict = OrderedDict()

@@ -648,7 +646,7 @@ class VarHandler:
         for name, grouped_datum in zip(self.model_names, grouped_data):
             for _, sub_data in grouped_datum:
                 datum_iter = xarray_var_iter(
-                    sub_data,
+                    sub_data.squeeze(),
                     var_names=[self.var_name],
                     skip_dims=skip_dims,
                     reverse_selections=True,
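In both forest-plot backends the per-chain grouping now passes squeeze=False to xarray's groupby, so each group keeps an explicit length-one chain dimension that is then dropped with .squeeze() just before iterating variables; this avoids relying on groupby's implicit squeezing, which xarray has been moving away from. A small sketch of the behaviour with a toy dataset (the variable and dimension names are illustrative only):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset({"theta": (("chain", "draw"), np.random.rand(4, 10))})

    for chain_label, sub in ds.groupby("chain", squeeze=False):
        # each group keeps a chain dimension of length 1 ...
        assert sub.sizes["chain"] == 1
        # ... which is removed explicitly before the per-variable iteration
        assert "chain" not in sub.squeeze().dims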

arviz/plots/backends/bokeh/kdeplot.py CHANGED

@@ -165,7 +165,6 @@ def plot_kde(
     x_x, y_y = np.mgrid[xmin:xmax:g_s, ymin:ymax:g_s]

     if contour:
-
         scaled_density, *scaled_density_args = _scale_axis(density)

         contourpy_kwargs = _init_kwargs_dict(contour_kwargs.pop("contourpy_kwargs", {}))

@@ -224,7 +223,6 @@ def plot_kde(
         ax.ygrid.grid_line_color = None

     else:
-
         cmap = pcolormesh_kwargs.pop("cmap", "viridis")
         if isinstance(cmap, str):
             cmap = get_cmap(cmap)

arviz/plots/backends/bokeh/pairplot.py CHANGED

@@ -241,11 +241,9 @@ def plot_pair(

     # pylint: disable=too-many-nested-blocks
     for i in range(0, numvars - marginals_offset):
-
         var1 = flat_var_names[i] if tmp_flat_var_names is None else tmp_flat_var_names[i]

         for j in range(0, numvars - marginals_offset):
-
             var2 = (
                 flat_var_names[j + marginals_offset]
                 if tmp_flat_var_names is None

@@ -268,7 +266,6 @@ def plot_pair(
                 ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset]

             elif j + marginals_offset > i:
-
                 if "scatter" in kind:
                     if divergences:
                         ax[j, i].circle(var1, var2, source=source, view=source_nondiv)

@@ -328,7 +325,6 @@ def plot_pair(
                     ax[j, i].add_layout(ax_vline)

                 if marginals:
-
                     ax[j - 1, i].add_layout(ax_vline)

                 pe_last = calculate_point_estimate(point_estimate, plotters[-1][-1])

arviz/plots/backends/matplotlib/bpvplot.py CHANGED

@@ -86,6 +86,9 @@ def plot_bpv(
     obs_vals = obs_vals.flatten()
     pp_vals = pp_vals.reshape(total_pp_samples, -1)

+    if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
+        obs_vals, pp_vals = smooth_data(obs_vals, pp_vals)
+
     if kind == "p_value":
         tstat_pit = np.mean(pp_vals <= obs_vals, axis=-1)
         x_s, tstat_pit_dens = kde(tstat_pit)

@@ -110,9 +113,6 @@ def plot_bpv(
         ax_i.plot(x_ss, u_dens, linewidth=linewidth, **plot_ref_kwargs)

     elif kind == "u_value":
-        if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
-            obs_vals, pp_vals = smooth_data(obs_vals, pp_vals)
-
         tstat_pit = np.mean(pp_vals <= obs_vals, axis=0)
         x_s, tstat_pit_dens = kde(tstat_pit)
         ax_i.plot(x_s, tstat_pit_dens, color=color)
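The smoothing of discrete data is hoisted out of the u_value branch so that integer-valued observations are smoothed before either statistic is computed. The trigger is simply the numpy dtype kind code; a tiny illustration of that check (the smoothing itself is an arviz internal and is omitted here):

    import numpy as np

    obs_vals = np.array([0, 1, 3, 2])   # integer dtype, kind == "i"
    pp_vals = np.random.rand(100, 4)    # float dtype, kind == "f"

    needs_smoothing = obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i"
    print(needs_smoothing)  # True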

arviz/plots/backends/matplotlib/dotplot.py CHANGED

@@ -97,7 +97,7 @@ def plot_dot(
     stack_locs, stack_count = wilkinson_algorithm(values, binwidth)
     x, y = layout_stacks(stack_locs, stack_count, binwidth, stackratio, rotated)

-    for
+    for x_i, y_i in zip(x, y):
         dot = plt.Circle((x_i, y_i), dotsize * binwidth / 2, **plot_kwargs)
         ax.add_patch(dot)


arviz/plots/backends/matplotlib/forestplot.py CHANGED

@@ -11,7 +11,6 @@ from ....stats import hdi
 from ....stats.density_utils import get_bins, histogram, kde
 from ....stats.diagnostics import _ess, _rhat
 from ....sel_utils import xarray_var_iter
-from ....utils import conditional_jit
 from ...plot_utils import _scale_fig_size
 from . import backend_kwarg_defaults, backend_show

@@ -236,7 +235,6 @@ class PlotHandler:
         """Collect labels and ticks from plotters."""
         val = self.plotters.values()

-        @conditional_jit(forceobj=True, nopython=False)
         def label_idxs():
             labels, idxs = [], []
             for plotter in val:

@@ -536,7 +534,7 @@ class VarHandler:
             grouped_data = [[(0, datum)] for datum in self.data]
             skip_dims = self.combine_dims.union({"chain"})
         else:
-            grouped_data = [datum.groupby("chain") for datum in self.data]
+            grouped_data = [datum.groupby("chain", squeeze=False) for datum in self.data]
             skip_dims = self.combine_dims

         label_dict = OrderedDict()

@@ -544,7 +542,7 @@ class VarHandler:
         for name, grouped_datum in zip(self.model_names, grouped_data):
             for _, sub_data in grouped_datum:
                 datum_iter = xarray_var_iter(
-                    sub_data,
+                    sub_data.squeeze(),
                     var_names=[self.var_name],
                     skip_dims=skip_dims,
                     reverse_selections=True,

arviz/plots/backends/matplotlib/lmplot.py CHANGED

@@ -50,7 +50,6 @@ def plot_lm(
     _, axes = create_axes_grid(length_plotters, rows, cols, backend_kwargs=backend_kwargs)

     for i, ax_i in enumerate(np.ravel(axes)[:length_plotters]):
-
         # All the kwargs are defined here beforehand
         y_kwargs = matplotlib_kwarg_dealiaser(y_kwargs, "plot")
         y_kwargs.setdefault("color", "C3")

@@ -68,22 +67,22 @@ def plot_lm(
         y_hat_plot_kwargs.setdefault("linewidth", 0)

         y_hat_fill_kwargs = matplotlib_kwarg_dealiaser(y_hat_fill_kwargs, "fill_between")
-        y_hat_fill_kwargs.setdefault("color", "
+        y_hat_fill_kwargs.setdefault("color", "C3")

         y_model_plot_kwargs = matplotlib_kwarg_dealiaser(y_model_plot_kwargs, "plot")
-        y_model_plot_kwargs.setdefault("color", "
+        y_model_plot_kwargs.setdefault("color", "C6")
         y_model_plot_kwargs.setdefault("alpha", 0.5)
         y_model_plot_kwargs.setdefault("linewidth", 0.5)
         y_model_plot_kwargs.setdefault("zorder", 9)

         y_model_fill_kwargs = matplotlib_kwarg_dealiaser(y_model_fill_kwargs, "fill_between")
-        y_model_fill_kwargs.setdefault("color", "
+        y_model_fill_kwargs.setdefault("color", "C0")
         y_model_fill_kwargs.setdefault("linewidth", 0.5)
         y_model_fill_kwargs.setdefault("zorder", 9)
         y_model_fill_kwargs.setdefault("alpha", 0.5)

         y_model_mean_kwargs = matplotlib_kwarg_dealiaser(y_model_mean_kwargs, "plot")
-        y_model_mean_kwargs.setdefault("color", "
+        y_model_mean_kwargs.setdefault("color", "C6")
         y_model_mean_kwargs.setdefault("linewidth", 0.8)
         y_model_mean_kwargs.setdefault("zorder", 11)

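The new defaults are matplotlib property-cycle aliases: the strings "C0" through "C9" refer to the colors of the currently active style's cycle rather than to fixed RGB values (the literals being replaced are truncated in this diff view). For example:

    import numpy as np
    import matplotlib.pyplot as plt

    x = np.linspace(0, 1, 50)
    plt.plot(x, x, color="C3")                               # 4th color of the active cycle
    plt.fill_between(x, x, x + 0.1, color="C0", alpha=0.5)   # 1st color of the active cycle
    plt.show()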

arviz/plots/backends/matplotlib/ppcplot.py CHANGED

@@ -371,8 +371,6 @@ def plot_ppc(
         if legend:
             if i == 0:
                 ax_i.legend(fontsize=xt_labelsize * 0.75)
-            else:
-                ax_i.legend([])

     if backend_show(show):
         plt.show()

@@ -414,15 +412,20 @@ def _set_animation(

     else:
         vals = pp_sampled_vals[0]
-
+        bins = get_bins(vals)
+        _, y_vals, x_vals = histogram(vals, bins=bins)
         (line,) = ax.plot(x_vals[:-1], y_vals, **plot_kwargs)

-        max_max = max(
+        max_max = max(
+            max(histogram(pp_sampled_vals[i], bins=get_bins(pp_sampled_vals[i]))[1])
+            for i in range(length)
+        )

         ax.set_ylim(0, max_max)

         def animate(i):
-
+            pp_vals = pp_sampled_vals[i]
+            _, y_vals, x_vals = histogram(pp_vals, bins=get_bins(pp_vals))
             line.set_data(x_vals[:-1], y_vals)
             return (line,)

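The animation branch now derives the histogram for the first frame and for every subsequent frame directly from the sampled values, and fixes the y-limit to the tallest bin across all frames. get_bins and histogram are arviz internals, so the sketch below substitutes numpy.histogram to show the same recompute-per-frame pattern with matplotlib's FuncAnimation:

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.animation import FuncAnimation

    rng = np.random.default_rng(0)
    pp_sampled_vals = [rng.poisson(5, size=200) for _ in range(20)]  # toy predictive draws

    fig, ax = plt.subplots()
    counts, edges = np.histogram(pp_sampled_vals[0], bins="auto")
    (line,) = ax.plot(edges[:-1], counts, drawstyle="steps-post")

    # fix the y-limit to the tallest bin over all frames so the axes do not jump
    max_max = max(np.histogram(vals, bins="auto")[0].max() for vals in pp_sampled_vals)
    ax.set_ylim(0, max_max)

    def animate(i):
        frame_counts, frame_edges = np.histogram(pp_sampled_vals[i], bins="auto")
        line.set_data(frame_edges[:-1], frame_counts)
        return (line,)

    anim = FuncAnimation(fig, animate, frames=len(pp_sampled_vals), blit=True)
    plt.show()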

arviz/plots/backends/matplotlib/traceplot.py CHANGED

@@ -430,7 +430,7 @@ def plot_trace(
         Line2D(
             [], [], label=chain_id, **dealiase_sel_kwargs(legend_kwargs, chain_prop, chain_id)
         )
-        for chain_id in range(data.
+        for chain_id in range(data.sizes["chain"])
     ]
     if combined:
         handles.insert(

@@ -470,7 +470,6 @@ def _plot_chains_mpl(
     circ_var_units,
     circ_units_trace,
 ):
-
     if not circular:
         circ_var_units = False

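This release consistently reads dimension lengths through .sizes (here for the number of chains, and below in plot_ess for draws and chains). .sizes is a mapping from dimension name to length on both Dataset and DataArray, whereas DataArray.dims is only a tuple of names, so .sizes is the accessor that works uniformly:

    import numpy as np
    import xarray as xr

    data = xr.Dataset({"x": (("chain", "draw"), np.zeros((4, 100)))})

    data.sizes["chain"]        # 4 -- same access on Dataset and DataArray
    data["x"].sizes["chain"]   # 4
    data["x"].dims             # ("chain", "draw") -- just a tuple of names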

arviz/plots/bfplot.py CHANGED

@@ -38,7 +38,7 @@ def plot_bf(
     algorithm presented in [1]_.

     Parameters
-
+    ----------
     idata : InferenceData
         Any object that can be converted to an :class:`arviz.InferenceData` object
         Refer to documentation of :func:`arviz.convert_to_dataset` for details.

@@ -52,16 +52,16 @@ def plot_bf(
         Tuple of valid Matplotlib colors. First element for the prior, second for the posterior.
     figsize : (float, float), optional
         Figure size. If `None` it will be defined automatically.
-    textsize: float, optional
+    textsize : float, optional
         Text size scaling factor for labels, titles and lines. If `None` it will be auto
         scaled based on `figsize`.
-    plot_kwargs :
+    plot_kwargs : dict, optional
         Additional keywords passed to :func:`matplotlib.pyplot.plot`.
-    hist_kwargs :
+    hist_kwargs : dict, optional
         Additional keywords passed to :func:`arviz.plot_dist`. Only works for discrete variables.
     ax : axes, optional
         :class:`matplotlib.axes.Axes` or :class:`bokeh.plotting.Figure`.
-    backend :{"matplotlib", "bokeh"}, default "matplotlib"
+    backend : {"matplotlib", "bokeh"}, default "matplotlib"
         Select plotting backend.
     backend_kwargs : dict, optional
         These are kwargs specific to the backend being used, passed to

@@ -78,7 +78,7 @@ def plot_bf(
     References
     ----------
     .. [1] Heck, D., 2019. A caveat on the Savage-Dickey density ratio:
-
+       The case of computing Bayes factors for regression parameters.

     Examples
     --------

@@ -92,6 +92,7 @@ def plot_bf(
     >>> idata = az.from_dict(posterior={"a":np.random.normal(1, 0.5, 5000)},
     ... prior={"a":np.random.normal(0, 1, 5000)})
     >>> az.plot_bf(idata, var_name="a", ref_val=0)
+
     """
     posterior = extract(idata, var_names=var_name).values


arviz/plots/bpvplot.py CHANGED

@@ -162,6 +162,11 @@ def plot_bpv(
     ----------
     * Gelman et al. (2013) see http://www.stat.columbia.edu/~gelman/book/ pages 151-153 for details

+    Notes
+    -----
+    Discrete data is smoothed before computing either p-values or u-values using the
+    function :func:`~arviz.smooth_data`
+
     Examples
     --------
     Plot Bayesian p_values.

@@ -225,11 +230,11 @@ def plot_bpv(

     if flatten_pp is None:
         if flatten is None:
-            flatten_pp = list(predictive_dataset.dims
+            flatten_pp = list(predictive_dataset.dims)
         else:
             flatten_pp = flatten
     if flatten is None:
-        flatten = list(observed.dims
+        flatten = list(observed.dims)

     if coords is None:
         coords = {}
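The new Notes entry documents the behaviour added in the matplotlib backend above: integer-valued data are smoothed before either statistic is computed, so both kinds accept discrete observations directly. A minimal usage sketch with made-up discrete data (purely illustrative):

    import numpy as np
    import arviz as az

    rng = np.random.default_rng(0)
    idata = az.from_dict(
        observed_data={"y": rng.poisson(5, size=100)},
        posterior_predictive={"y": rng.poisson(5, size=(4, 500, 100))},
    )

    az.plot_bpv(idata, kind="p_value")   # discrete y is smoothed internally
    az.plot_bpv(idata, kind="u_value")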

arviz/plots/compareplot.py CHANGED

@@ -90,10 +90,10 @@ def plot_compare(
     References
     ----------
     .. [1] Vehtari et al. (2016). Practical Bayesian model evaluation using leave-one-out
-
+       cross-validation and WAIC https://arxiv.org/abs/1507.04544

     .. [2] McElreath R. (2022). Statistical Rethinking A Bayesian Course with Examples in
-
+       R and Stan, Second edition, CRC Press.

     Examples
     --------

arviz/plots/ecdfplot.py CHANGED

@@ -1,8 +1,9 @@
 """Plot ecdf or ecdf-difference plot with confidence bands."""
 import numpy as np
-from scipy.stats import uniform
+from scipy.stats import uniform

 from ..rcparams import rcParams
+from ..stats.ecdf_utils import compute_ecdf, ecdf_confidence_band, _get_ecdf_points
 from .plot_utils import get_plotting_function


@@ -26,7 +27,7 @@ def plot_ecdf(
     show=None,
     backend=None,
     backend_kwargs=None,
-    **kwargs
+    **kwargs,
 ):
     r"""Plot ECDF or ECDF-Difference Plot with Confidence bands.

@@ -48,6 +49,7 @@ def plot_ecdf(
         Values to compare to the original sample.
     cdf : callable, optional
         Cumulative distribution function of the distribution to compare the original sample.
+        The function must take as input a numpy array of draws from the distribution.
     difference : bool, default False
         If True then plot ECDF-difference plot otherwise ECDF plot.
     pit : bool, default False

@@ -180,75 +182,47 @@ def plot_ecdf(
     values = np.ravel(values)
     values.sort()

-
-
-
-
-    ## Basically here we generate the evaluation points(x) and find the PIT values.
-    ## z is the evaluation point for our uniform distribution in compute_gamma()
-    if pit:
-        x = np.linspace(1 / npoints, 1, npoints)
-        z = x
-        ## Finding PIT for our sample
-        probs = cdf(values) if cdf else compute_ecdf(values2, values) / len(values2)
-    else:
-        ## If not PIT use sample for plots and for evaluation points(x) use equally spaced
-        ## points between minimum and maximum of sample
-        ## For z we have used cdf(x)
-        x = np.linspace(values[0], values[-1], npoints)
-        z = cdf(x) if cdf else compute_ecdf(values2, x)
-        probs = values
-
-    n = len(values) # number of samples
-    ## Computing gamma
-    gamma = fpr if pointwise else compute_gamma(n, z, npoints, num_trials, fpr)
-    ## Using gamma to get the confidence intervals
-    lower, higher = get_lims(gamma, n, z)
-
-    ## This block is for whether to plot ECDF or ECDF-difference
-    if not difference:
-        ## We store the coordinates of our ecdf in x_coord, y_coord
-        x_coord, y_coord = get_ecdf_points(x, probs, difference)
+    if pit:
+        eval_points = np.linspace(1 / npoints, 1, npoints)
+        if cdf:
+            sample = cdf(values)
         else:
-
-
-
-            y_coord[i] = y_coord[i] - (
-                x_i if pit else cdf(x_i) if cdf else compute_ecdf(values2, x_i)
-            )
-
-        ## Similarly we subtract from the upper and lower bounds
-        if pit:
-            lower = lower - x
-            higher = higher - x
-        else:
-            lower = lower - (cdf(x) if cdf else compute_ecdf(values2, x))
-            higher = higher - (cdf(x) if cdf else compute_ecdf(values2, x))
-
+            sample = compute_ecdf(values2, values) / len(values2)
+        cdf_at_eval_points = eval_points
+        rvs = uniform(0, 1).rvs
     else:
-
-
-
+        eval_points = np.linspace(values[0], values[-1], npoints)
+        sample = values
+        if confidence_bands or difference:
+            if cdf:
+                cdf_at_eval_points = cdf(eval_points)
+            else:
+                cdf_at_eval_points = compute_ecdf(values2, eval_points)
         else:
-
-
+            cdf_at_eval_points = np.zeros_like(eval_points)
+        rvs = None
+
+    x_coord, y_coord = _get_ecdf_points(sample, eval_points, difference)

+    if difference:
+        y_coord -= cdf_at_eval_points
+
+    if confidence_bands:
+        ndraws = len(values)
+        band_kwargs = {"prob": 1 - fpr, "num_trials": num_trials, "rvs": rvs, "random_state": None}
+        band_kwargs["method"] = "pointwise" if pointwise else "simulated"
+        lower, higher = ecdf_confidence_band(ndraws, eval_points, cdf_at_eval_points, **band_kwargs)
+
+        if difference:
+            lower -= cdf_at_eval_points
+            higher -= cdf_at_eval_points
+    else:
         lower, higher = None, None
-    ## This block is for whether to plot ECDF or ECDF-difference
-    if not difference:
-        x_coord, y_coord = get_ecdf_points(x, probs, difference)
-    else:
-        ## Here we subtract the ecdf value as here we are plotting the ECDF-difference
-        x_coord, y_coord = get_ecdf_points(x, probs, difference)
-        for i, x_i in enumerate(x):
-            y_coord[i] = y_coord[i] - (
-                x_i if pit else cdf(x_i) if cdf else compute_ecdf(values2, x_i)
-            )

     ecdf_plot_args = dict(
         x_coord=x_coord,
         y_coord=y_coord,
-        x_bands=
+        x_bands=eval_points,
         lower=lower,
         higher=higher,
         confidence_bands=confidence_bands,

@@ -260,7 +234,7 @@ def plot_ecdf(
         ax=ax,
         show=show,
         backend_kwargs=backend_kwargs,
-        **kwargs
+        **kwargs,
     )

     if backend is None:

@@ -271,52 +245,3 @@ def plot_ecdf(
     ax = plot(**ecdf_plot_args)

     return ax
-
-
-def compute_ecdf(sample, z):
-    """Compute ECDF.
-
-    This function computes the ecdf value at the evaluation point
-    or a sorted set of evaluation points.
-    """
-    return np.searchsorted(sample, z, side="right") / len(sample)
-
-
-def get_ecdf_points(x, probs, difference):
-    """Compute the coordinates for the ecdf points using compute_ecdf."""
-    y = compute_ecdf(probs, x)
-
-    if not difference:
-        x = np.insert(x, 0, x[0])
-        y = np.insert(y, 0, 0)
-    return x, y
-
-
-def compute_gamma(n, z, npoints=None, num_trials=1000, fpr=0.05):
-    """Compute gamma for confidence interval calculation.
-
-    This function simulates an adjusted value of gamma to account for multiplicity
-    when forming an 1-fpr level confidence envelope for the ECDF of a sample.
-    """
-    if npoints is None:
-        npoints = n
-    gamma = []
-    for _ in range(num_trials):
-        unif_samples = uniform.rvs(0, 1, n)
-        unif_samples = np.sort(unif_samples)
-        gamma_m = 1000
-        ## Can compute ecdf for all the z together or one at a time.
-        f_z = compute_ecdf(unif_samples, z)
-        f_z = compute_ecdf(unif_samples, z)
-        gamma_m = 2 * min(
-            np.amin(binom.cdf(n * f_z, n, z)), np.amin(1 - binom.cdf(n * f_z - 1, n, z))
-        )
-        gamma.append(gamma_m)
-    return np.quantile(gamma, fpr)
-
-
-def get_lims(gamma, n, z):
-    """Compute the simultaneous 1 - fpr level confidence bands."""
-    lower = binom.ppf(gamma / 2, n, z)
-    upper = binom.ppf(1 - gamma / 2, n, z)
-    return lower / n, upper / n
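The helper functions removed here (compute_ecdf, get_ecdf_points, compute_gamma, get_lims) are superseded by the new arviz/stats/ecdf_utils.py module listed at the top of this diff, which plot_ecdf now imports for the ECDF points and for the pointwise or simulated confidence bands. A short usage sketch of the public entry point, with a reference cdf as the updated docstring describes (a callable over a numpy array of draws); the data here are made up for illustration:

    import numpy as np
    from scipy.stats import norm
    import arviz as az

    rng = np.random.default_rng(0)
    sample = rng.normal(loc=0.1, scale=1.0, size=1_000)

    # ECDF-difference plot against a standard normal reference, with confidence bands
    az.plot_ecdf(sample, cdf=norm(0, 1).cdf, difference=True, confidence_bands=True)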

arviz/plots/elpdplot.py CHANGED

@@ -98,7 +98,7 @@ def plot_elpd(
     References
     ----------
     .. [1] Vehtari et al. (2016). Practical Bayesian model evaluation using leave-one-out
-
+       cross-validation and WAIC https://arxiv.org/abs/1507.04544

     Examples
     --------

arviz/plots/essplot.py CHANGED

@@ -202,8 +202,8 @@ def plot_ess(

     data = get_coords(convert_to_dataset(idata, group="posterior"), coords)
     var_names = _var_names(var_names, data, filter_vars)
-    n_draws = data.
-    n_samples = n_draws * data.
+    n_draws = data.sizes["draw"]
+    n_samples = n_draws * data.sizes["chain"]

     ess_tail_dataset = None
     mean_ess = None