arviz 0.20.0__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +8 -3
- arviz/data/base.py +2 -2
- arviz/data/inference_data.py +57 -26
- arviz/data/io_datatree.py +2 -2
- arviz/data/io_numpyro.py +112 -4
- arviz/plots/autocorrplot.py +12 -2
- arviz/plots/backends/__init__.py +8 -7
- arviz/plots/backends/bokeh/bpvplot.py +4 -3
- arviz/plots/backends/bokeh/densityplot.py +5 -1
- arviz/plots/backends/bokeh/dotplot.py +5 -2
- arviz/plots/backends/bokeh/essplot.py +4 -2
- arviz/plots/backends/bokeh/forestplot.py +11 -4
- arviz/plots/backends/bokeh/hdiplot.py +7 -6
- arviz/plots/backends/bokeh/khatplot.py +4 -2
- arviz/plots/backends/bokeh/lmplot.py +28 -6
- arviz/plots/backends/bokeh/mcseplot.py +2 -2
- arviz/plots/backends/bokeh/pairplot.py +27 -52
- arviz/plots/backends/bokeh/ppcplot.py +2 -1
- arviz/plots/backends/bokeh/rankplot.py +2 -1
- arviz/plots/backends/bokeh/traceplot.py +2 -1
- arviz/plots/backends/bokeh/violinplot.py +2 -1
- arviz/plots/backends/matplotlib/bpvplot.py +2 -1
- arviz/plots/backends/matplotlib/khatplot.py +8 -1
- arviz/plots/backends/matplotlib/lmplot.py +13 -7
- arviz/plots/backends/matplotlib/pairplot.py +14 -22
- arviz/plots/bfplot.py +9 -26
- arviz/plots/bpvplot.py +10 -1
- arviz/plots/hdiplot.py +5 -0
- arviz/plots/kdeplot.py +4 -4
- arviz/plots/lmplot.py +41 -14
- arviz/plots/pairplot.py +10 -3
- arviz/plots/plot_utils.py +5 -3
- arviz/preview.py +36 -5
- arviz/stats/__init__.py +1 -0
- arviz/stats/density_utils.py +1 -1
- arviz/stats/diagnostics.py +18 -14
- arviz/stats/stats.py +105 -7
- arviz/tests/base_tests/test_data.py +31 -11
- arviz/tests/base_tests/test_diagnostics.py +5 -4
- arviz/tests/base_tests/test_plots_bokeh.py +60 -2
- arviz/tests/base_tests/test_plots_matplotlib.py +103 -11
- arviz/tests/base_tests/test_stats.py +53 -1
- arviz/tests/external_tests/test_data_numpyro.py +130 -3
- arviz/utils.py +4 -0
- arviz/wrappers/base.py +1 -1
- arviz/wrappers/wrap_stan.py +1 -1
- {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/METADATA +7 -7
- {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/RECORD +51 -51
- {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/WHEEL +1 -1
- {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/LICENSE +0 -0
- {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/top_level.txt +0 -0
|
@@ -94,9 +94,10 @@ def plot_khat(
|
|
|
94
94
|
|
|
95
95
|
if not isinstance(rgba_c, str) and isinstance(rgba_c, Iterable):
|
|
96
96
|
for idx, (alpha, rgba_c_) in enumerate(zip(alphas, rgba_c)):
|
|
97
|
-
ax.
|
|
97
|
+
ax.scatter(
|
|
98
98
|
xdata[idx],
|
|
99
99
|
khats[idx],
|
|
100
|
+
marker="cross",
|
|
100
101
|
line_color=rgba_c_,
|
|
101
102
|
fill_color=rgba_c_,
|
|
102
103
|
line_alpha=alpha,
|
|
@@ -104,9 +105,10 @@ def plot_khat(
|
|
|
104
105
|
size=10,
|
|
105
106
|
)
|
|
106
107
|
else:
|
|
107
|
-
ax.
|
|
108
|
+
ax.scatter(
|
|
108
109
|
xdata,
|
|
109
110
|
khats,
|
|
111
|
+
marker="cross",
|
|
110
112
|
line_color=rgba_c,
|
|
111
113
|
fill_color=rgba_c,
|
|
112
114
|
size=10,
|
|
@@ -51,18 +51,30 @@ def plot_lm(
|
|
|
51
51
|
|
|
52
52
|
if y_kwargs is None:
|
|
53
53
|
y_kwargs = {}
|
|
54
|
+
else:
|
|
55
|
+
y_kwargs = y_kwargs.copy()
|
|
56
|
+
y_kwargs.setdefault("marker", "circle")
|
|
54
57
|
y_kwargs.setdefault("fill_color", "red")
|
|
55
58
|
y_kwargs.setdefault("line_width", 0)
|
|
56
59
|
y_kwargs.setdefault("size", 3)
|
|
57
60
|
|
|
58
61
|
if y_hat_plot_kwargs is None:
|
|
59
62
|
y_hat_plot_kwargs = {}
|
|
63
|
+
else:
|
|
64
|
+
y_hat_plot_kwargs = y_hat_plot_kwargs.copy()
|
|
65
|
+
y_hat_plot_kwargs.setdefault("marker", "circle")
|
|
60
66
|
y_hat_plot_kwargs.setdefault("fill_color", "orange")
|
|
61
67
|
y_hat_plot_kwargs.setdefault("line_width", 0)
|
|
62
68
|
|
|
63
69
|
if y_hat_fill_kwargs is None:
|
|
64
70
|
y_hat_fill_kwargs = {}
|
|
65
|
-
|
|
71
|
+
else:
|
|
72
|
+
y_hat_fill_kwargs = y_hat_fill_kwargs.copy()
|
|
73
|
+
# Convert matplotlib color to bokeh fill_color if needed
|
|
74
|
+
if "color" in y_hat_fill_kwargs and "fill_color" not in y_hat_fill_kwargs:
|
|
75
|
+
y_hat_fill_kwargs["fill_color"] = y_hat_fill_kwargs.pop("color")
|
|
76
|
+
y_hat_fill_kwargs.setdefault("fill_color", "orange")
|
|
77
|
+
y_hat_fill_kwargs.setdefault("fill_alpha", 0.5)
|
|
66
78
|
|
|
67
79
|
if y_model_plot_kwargs is None:
|
|
68
80
|
y_model_plot_kwargs = {}
|
|
@@ -72,8 +84,13 @@ def plot_lm(
|
|
|
72
84
|
|
|
73
85
|
if y_model_fill_kwargs is None:
|
|
74
86
|
y_model_fill_kwargs = {}
|
|
75
|
-
|
|
76
|
-
|
|
87
|
+
else:
|
|
88
|
+
y_model_fill_kwargs = y_model_fill_kwargs.copy()
|
|
89
|
+
# Convert matplotlib color to bokeh fill_color if needed
|
|
90
|
+
if "color" in y_model_fill_kwargs and "fill_color" not in y_model_fill_kwargs:
|
|
91
|
+
y_model_fill_kwargs["fill_color"] = y_model_fill_kwargs.pop("color")
|
|
92
|
+
y_model_fill_kwargs.setdefault("fill_color", "black")
|
|
93
|
+
y_model_fill_kwargs.setdefault("fill_alpha", 0.5)
|
|
77
94
|
|
|
78
95
|
if y_model_mean_kwargs is None:
|
|
79
96
|
y_model_mean_kwargs = {}
|
|
@@ -84,7 +101,7 @@ def plot_lm(
|
|
|
84
101
|
_, _, _, y_plotters = y[i]
|
|
85
102
|
_, _, _, x_plotters = x[i]
|
|
86
103
|
legend_it = []
|
|
87
|
-
observed_legend = ax_i.
|
|
104
|
+
observed_legend = ax_i.scatter(x_plotters, y_plotters, **y_kwargs)
|
|
88
105
|
legend_it.append(("Observed", [observed_legend]))
|
|
89
106
|
|
|
90
107
|
if y_hat is not None:
|
|
@@ -98,14 +115,14 @@ def plot_lm(
|
|
|
98
115
|
x_plotters_jitter = x_plotters + np.random.uniform(
|
|
99
116
|
low=-scale_high, high=scale_high, size=len(x_plotters)
|
|
100
117
|
)
|
|
101
|
-
posterior_circle = ax_i.
|
|
118
|
+
posterior_circle = ax_i.scatter(
|
|
102
119
|
x_plotters_jitter,
|
|
103
120
|
y_hat_plotters[..., j],
|
|
104
121
|
alpha=0.2,
|
|
105
122
|
**y_hat_plot_kwargs,
|
|
106
123
|
)
|
|
107
124
|
else:
|
|
108
|
-
posterior_circle = ax_i.
|
|
125
|
+
posterior_circle = ax_i.scatter(
|
|
109
126
|
x_plotters, y_hat_plotters[..., j], alpha=0.2, **y_hat_plot_kwargs
|
|
110
127
|
)
|
|
111
128
|
posterior_legend.append(posterior_circle)
|
|
@@ -143,6 +160,11 @@ def plot_lm(
|
|
|
143
160
|
)
|
|
144
161
|
|
|
145
162
|
y_model_mean = np.mean(y_model_plotters, axis=(0, 1))
|
|
163
|
+
# Plot mean line across all x values instead of just edges
|
|
164
|
+
mean_legend = ax_i.line(x_plotters, y_model_mean, **y_model_mean_kwargs)
|
|
165
|
+
legend_it.append(("Mean", [mean_legend]))
|
|
166
|
+
continue # Skip the edge plotting since we plotted full line
|
|
167
|
+
|
|
146
168
|
x_plotters_edge = [min(x_plotters), max(x_plotters)]
|
|
147
169
|
y_model_mean_edge = [min(y_model_mean), max(y_model_mean)]
|
|
148
170
|
mean_legend = ax_i.line(x_plotters_edge, y_model_mean_edge, **y_model_mean_kwargs)
|
|
@@ -71,13 +71,13 @@ def plot_mcse(
|
|
|
71
71
|
values = data[var_name].sel(**selection).values.flatten()
|
|
72
72
|
if errorbar:
|
|
73
73
|
quantile_values = _quantile(values, probs)
|
|
74
|
-
ax_.
|
|
74
|
+
ax_.scatter(probs, quantile_values, marker="dash")
|
|
75
75
|
ax_.multi_line(
|
|
76
76
|
list(zip(probs, probs)),
|
|
77
77
|
[(quant - err, quant + err) for quant, err in zip(quantile_values, x)],
|
|
78
78
|
)
|
|
79
79
|
else:
|
|
80
|
-
ax_.
|
|
80
|
+
ax_.scatter(probs, x, marker="circle")
|
|
81
81
|
if extra_methods:
|
|
82
82
|
mean_mcse_i = mean_mcse[var_name].sel(**selection).values.item()
|
|
83
83
|
sd_mcse_i = sd_mcse[var_name].sel(**selection).values.item()
|
|
@@ -37,6 +37,8 @@ def plot_pair(
|
|
|
37
37
|
diverging_mask,
|
|
38
38
|
divergences_kwargs,
|
|
39
39
|
flat_var_names,
|
|
40
|
+
flat_ref_slices,
|
|
41
|
+
flat_var_labels,
|
|
40
42
|
backend_kwargs,
|
|
41
43
|
marginal_kwargs,
|
|
42
44
|
show,
|
|
@@ -72,61 +74,25 @@ def plot_pair(
|
|
|
72
74
|
kde_kwargs["contour_kwargs"].setdefault("line_alpha", 1)
|
|
73
75
|
|
|
74
76
|
if reference_values:
|
|
75
|
-
|
|
76
|
-
label = []
|
|
77
|
-
for variable in list(reference_values.keys()):
|
|
78
|
-
if " " in variable:
|
|
79
|
-
variable_copy = variable.replace(" ", "\n", 1)
|
|
80
|
-
else:
|
|
81
|
-
variable_copy = variable
|
|
82
|
-
|
|
83
|
-
label.append(variable_copy)
|
|
84
|
-
reference_values_copy[variable_copy] = reference_values[variable]
|
|
85
|
-
|
|
86
|
-
difference = set(flat_var_names).difference(set(label))
|
|
87
|
-
|
|
88
|
-
if difference:
|
|
89
|
-
warn = [diff.replace("\n", " ", 1) for diff in difference]
|
|
90
|
-
warnings.warn(
|
|
91
|
-
"Argument reference_values does not include reference value for: {}".format(
|
|
92
|
-
", ".join(warn)
|
|
93
|
-
),
|
|
94
|
-
UserWarning,
|
|
95
|
-
)
|
|
96
|
-
|
|
97
|
-
if reference_values:
|
|
98
|
-
reference_values_copy = {}
|
|
99
|
-
label = []
|
|
100
|
-
for variable in list(reference_values.keys()):
|
|
101
|
-
if " " in variable:
|
|
102
|
-
variable_copy = variable.replace(" ", "\n", 1)
|
|
103
|
-
else:
|
|
104
|
-
variable_copy = variable
|
|
105
|
-
|
|
106
|
-
label.append(variable_copy)
|
|
107
|
-
reference_values_copy[variable_copy] = reference_values[variable]
|
|
108
|
-
|
|
109
|
-
difference = set(flat_var_names).difference(set(label))
|
|
110
|
-
|
|
111
|
-
for dif in difference:
|
|
112
|
-
reference_values_copy[dif] = None
|
|
77
|
+
difference = set(flat_var_names).difference(set(reference_values.keys()))
|
|
113
78
|
|
|
114
79
|
if difference:
|
|
115
|
-
warn = [dif.replace("\n", " ", 1) for dif in difference]
|
|
116
80
|
warnings.warn(
|
|
117
81
|
"Argument reference_values does not include reference value for: {}".format(
|
|
118
|
-
", ".join(
|
|
82
|
+
", ".join(difference)
|
|
119
83
|
),
|
|
120
84
|
UserWarning,
|
|
121
85
|
)
|
|
122
86
|
|
|
123
87
|
reference_values_kwargs = _init_kwargs_dict(reference_values_kwargs)
|
|
88
|
+
reference_values_kwargs.setdefault("marker", "circle")
|
|
124
89
|
reference_values_kwargs.setdefault("line_color", "black")
|
|
125
90
|
reference_values_kwargs.setdefault("fill_color", vectorized_to_hex("C2"))
|
|
126
91
|
reference_values_kwargs.setdefault("line_width", 1)
|
|
127
92
|
reference_values_kwargs.setdefault("size", 10)
|
|
128
93
|
|
|
129
94
|
divergences_kwargs = _init_kwargs_dict(divergences_kwargs)
|
|
95
|
+
divergences_kwargs.setdefault("marker", "circle")
|
|
130
96
|
divergences_kwargs.setdefault("line_color", "black")
|
|
131
97
|
divergences_kwargs.setdefault("fill_color", vectorized_to_hex("C1"))
|
|
132
98
|
divergences_kwargs.setdefault("line_width", 1)
|
|
@@ -155,6 +121,7 @@ def plot_pair(
|
|
|
155
121
|
)
|
|
156
122
|
|
|
157
123
|
point_estimate_marker_kwargs = _init_kwargs_dict(point_estimate_marker_kwargs)
|
|
124
|
+
point_estimate_marker_kwargs.setdefault("marker", "square")
|
|
158
125
|
point_estimate_marker_kwargs.setdefault("size", markersize)
|
|
159
126
|
point_estimate_marker_kwargs.setdefault("color", "black")
|
|
160
127
|
point_estimate_kwargs.setdefault("line_color", "black")
|
|
@@ -259,15 +226,17 @@ def plot_pair(
|
|
|
259
226
|
**marginal_kwargs,
|
|
260
227
|
)
|
|
261
228
|
|
|
262
|
-
ax[j, i].xaxis.axis_label =
|
|
263
|
-
ax[j, i].yaxis.axis_label =
|
|
229
|
+
ax[j, i].xaxis.axis_label = flat_var_labels[i]
|
|
230
|
+
ax[j, i].yaxis.axis_label = flat_var_labels[j + marginals_offset]
|
|
264
231
|
|
|
265
232
|
elif j + marginals_offset > i:
|
|
266
233
|
if "scatter" in kind:
|
|
267
234
|
if divergences:
|
|
268
|
-
ax[j, i].
|
|
235
|
+
ax[j, i].scatter(
|
|
236
|
+
var1, var2, marker="circle", source=source, view=source_nondiv
|
|
237
|
+
)
|
|
269
238
|
else:
|
|
270
|
-
ax[j, i].
|
|
239
|
+
ax[j, i].scatter(var1, var2, marker="circle", source=source)
|
|
271
240
|
|
|
272
241
|
if "kde" in kind:
|
|
273
242
|
var1_kde = plotters[i][-1].flatten()
|
|
@@ -293,7 +262,7 @@ def plot_pair(
|
|
|
293
262
|
)
|
|
294
263
|
|
|
295
264
|
if divergences:
|
|
296
|
-
ax[j, i].
|
|
265
|
+
ax[j, i].scatter(
|
|
297
266
|
var1,
|
|
298
267
|
var2,
|
|
299
268
|
source=source,
|
|
@@ -306,7 +275,7 @@ def plot_pair(
|
|
|
306
275
|
var2_pe = plotters[j][-1].flatten()
|
|
307
276
|
pe_x = calculate_point_estimate(point_estimate, var1_pe)
|
|
308
277
|
pe_y = calculate_point_estimate(point_estimate, var2_pe)
|
|
309
|
-
ax[j, i].
|
|
278
|
+
ax[j, i].scatter(pe_x, pe_y, **point_estimate_marker_kwargs)
|
|
310
279
|
|
|
311
280
|
ax_hline = Span(
|
|
312
281
|
location=pe_y,
|
|
@@ -341,12 +310,18 @@ def plot_pair(
|
|
|
341
310
|
ax[-1, -1].add_layout(ax_pe_hline)
|
|
342
311
|
|
|
343
312
|
if reference_values:
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
if
|
|
347
|
-
ax[j, i].
|
|
348
|
-
|
|
349
|
-
|
|
313
|
+
x_name = flat_var_names[j + marginals_offset]
|
|
314
|
+
y_name = flat_var_names[i]
|
|
315
|
+
if (x_name not in difference) and (y_name not in difference):
|
|
316
|
+
ax[j, i].scatter(
|
|
317
|
+
np.array(reference_values[y_name])[flat_ref_slices[i]],
|
|
318
|
+
np.array(reference_values[x_name])[
|
|
319
|
+
flat_ref_slices[j + marginals_offset]
|
|
320
|
+
],
|
|
321
|
+
**reference_values_kwargs,
|
|
322
|
+
)
|
|
323
|
+
ax[j, i].xaxis.axis_label = flat_var_labels[i]
|
|
324
|
+
ax[j, i].yaxis.axis_label = flat_var_labels[j + marginals_offset]
|
|
350
325
|
|
|
351
326
|
show_layout(ax, show)
|
|
352
327
|
|
|
@@ -313,9 +313,10 @@ def plot_ppc(
|
|
|
313
313
|
obs_yvals += np.random.uniform(
|
|
314
314
|
low=scale_low, high=scale_high, size=len(obs_vals)
|
|
315
315
|
)
|
|
316
|
-
glyph = ax_i.
|
|
316
|
+
glyph = ax_i.scatter(
|
|
317
317
|
obs_vals,
|
|
318
318
|
obs_yvals,
|
|
319
|
+
marker="circle",
|
|
319
320
|
line_color=colors[1],
|
|
320
321
|
fill_color=colors[1],
|
|
321
322
|
size=markersize,
|
|
@@ -49,6 +49,7 @@ def plot_rank(
|
|
|
49
49
|
|
|
50
50
|
if marker_vlines_kwargs is None:
|
|
51
51
|
marker_vlines_kwargs = {}
|
|
52
|
+
marker_vlines_kwargs.setdefault("marker", "circle")
|
|
52
53
|
|
|
53
54
|
if backend_kwargs is None:
|
|
54
55
|
backend_kwargs = {}
|
|
@@ -109,7 +110,7 @@ def plot_rank(
|
|
|
109
110
|
elif kind == "vlines":
|
|
110
111
|
ymin = np.full(len(all_counts), all_counts.mean())
|
|
111
112
|
for idx, counts in enumerate(all_counts):
|
|
112
|
-
ax.
|
|
113
|
+
ax.scatter(
|
|
113
114
|
bin_ary,
|
|
114
115
|
counts,
|
|
115
116
|
fill_color=colors[idx],
|
|
@@ -385,9 +385,10 @@ def _plot_chains_bokeh(
|
|
|
385
385
|
**dealiase_sel_kwargs(trace_kwargs, chain_prop, chain_idx),
|
|
386
386
|
)
|
|
387
387
|
if marker:
|
|
388
|
-
ax_trace.
|
|
388
|
+
ax_trace.scatter(
|
|
389
389
|
x=x_name,
|
|
390
390
|
y=y_name,
|
|
391
|
+
marker="circle",
|
|
391
392
|
source=cds,
|
|
392
393
|
radius=0.30,
|
|
393
394
|
alpha=0.5,
|
|
@@ -81,9 +81,10 @@ def plot_violin(
|
|
|
81
81
|
[0, 0], per[:2], line_width=linewidth * 3, line_color="black", line_cap="round"
|
|
82
82
|
)
|
|
83
83
|
ax_.line([0, 0], hdi_probs, line_width=linewidth, line_color="black", line_cap="round")
|
|
84
|
-
ax_.
|
|
84
|
+
ax_.scatter(
|
|
85
85
|
0,
|
|
86
86
|
per[-1],
|
|
87
|
+
marker="circle",
|
|
87
88
|
line_color="white",
|
|
88
89
|
fill_color="white",
|
|
89
90
|
size=linewidth * 1.5,
|
|
@@ -38,6 +38,7 @@ def plot_bpv(
|
|
|
38
38
|
plot_ref_kwargs,
|
|
39
39
|
backend_kwargs,
|
|
40
40
|
show,
|
|
41
|
+
smoothing,
|
|
41
42
|
):
|
|
42
43
|
"""Matplotlib bpv plot."""
|
|
43
44
|
if backend_kwargs is None:
|
|
@@ -87,7 +88,7 @@ def plot_bpv(
|
|
|
87
88
|
obs_vals = obs_vals.flatten()
|
|
88
89
|
pp_vals = pp_vals.reshape(total_pp_samples, -1)
|
|
89
90
|
|
|
90
|
-
if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
|
|
91
|
+
if (obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i") and smoothing is True:
|
|
91
92
|
obs_vals, pp_vals = smooth_data(obs_vals, pp_vals)
|
|
92
93
|
|
|
93
94
|
if kind == "p_value":
|
|
@@ -7,6 +7,7 @@ from matplotlib import cm
|
|
|
7
7
|
import matplotlib.pyplot as plt
|
|
8
8
|
import numpy as np
|
|
9
9
|
from matplotlib.colors import to_rgba_array
|
|
10
|
+
from packaging import version
|
|
10
11
|
|
|
11
12
|
from ....stats.density_utils import histogram
|
|
12
13
|
from ...plot_utils import _scale_fig_size, color_from_dim, set_xticklabels, vectorized_to_hex
|
|
@@ -39,7 +40,13 @@ def plot_khat(
|
|
|
39
40
|
show,
|
|
40
41
|
):
|
|
41
42
|
"""Matplotlib khat plot."""
|
|
42
|
-
if
|
|
43
|
+
if version.parse(mpl.__version__) >= version.parse("3.9.0.dev0"):
|
|
44
|
+
interactive_backends = mpl.backends.backend_registry.list_builtin(
|
|
45
|
+
mpl.backends.BackendFilter.INTERACTIVE
|
|
46
|
+
)
|
|
47
|
+
else:
|
|
48
|
+
interactive_backends = mpl.rcsetup.interactive_bk
|
|
49
|
+
if hover_label and mpl.get_backend() not in interactive_backends:
|
|
43
50
|
hover_label = False
|
|
44
51
|
warnings.warn(
|
|
45
52
|
"hover labels are only available with interactive backends. To switch to an "
|
|
@@ -115,12 +115,18 @@ def plot_lm(
|
|
|
115
115
|
|
|
116
116
|
if y_model is not None:
|
|
117
117
|
_, _, _, y_model_plotters = y_model[i]
|
|
118
|
+
|
|
118
119
|
if kind_model == "lines":
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
120
|
+
# y_model_plotters should be (points, samples)
|
|
121
|
+
y_points = y_model_plotters.shape[0]
|
|
122
|
+
if x_plotters.shape[0] == y_points:
|
|
123
|
+
for j in range(num_samples):
|
|
124
|
+
ax_i.plot(x_plotters, y_model_plotters[:, j], **y_model_plot_kwargs)
|
|
125
|
+
|
|
126
|
+
ax_i.plot([], **y_model_plot_kwargs, label="Uncertainty in mean")
|
|
127
|
+
y_model_mean = np.mean(y_model_plotters, axis=1)
|
|
128
|
+
ax_i.plot(x_plotters, y_model_mean, **y_model_mean_kwargs, label="Mean")
|
|
122
129
|
|
|
123
|
-
y_model_mean = np.mean(y_model_plotters, axis=1)
|
|
124
130
|
else:
|
|
125
131
|
plot_hdi(
|
|
126
132
|
x_plotters,
|
|
@@ -128,10 +134,10 @@ def plot_lm(
|
|
|
128
134
|
fill_kwargs=y_model_fill_kwargs,
|
|
129
135
|
ax=ax_i,
|
|
130
136
|
)
|
|
131
|
-
ax_i.plot([], color=y_model_fill_kwargs["color"], label="Uncertainty in mean")
|
|
132
137
|
|
|
133
|
-
|
|
134
|
-
|
|
138
|
+
ax_i.plot([], color=y_model_fill_kwargs["color"], label="Uncertainty in mean")
|
|
139
|
+
y_model_mean = np.mean(y_model_plotters, axis=0)
|
|
140
|
+
ax_i.plot(x_plotters, y_model_mean, **y_model_mean_kwargs, label="Mean")
|
|
135
141
|
|
|
136
142
|
if legend:
|
|
137
143
|
ax_i.legend(fontsize=xt_labelsize, loc="upper left")
|
|
@@ -30,6 +30,8 @@ def plot_pair(
|
|
|
30
30
|
diverging_mask,
|
|
31
31
|
divergences_kwargs,
|
|
32
32
|
flat_var_names,
|
|
33
|
+
flat_ref_slices,
|
|
34
|
+
flat_var_labels,
|
|
33
35
|
backend_kwargs,
|
|
34
36
|
marginal_kwargs,
|
|
35
37
|
show,
|
|
@@ -77,24 +79,12 @@ def plot_pair(
|
|
|
77
79
|
kde_kwargs["contour_kwargs"].setdefault("colors", "k")
|
|
78
80
|
|
|
79
81
|
if reference_values:
|
|
80
|
-
|
|
81
|
-
label = []
|
|
82
|
-
for variable in list(reference_values.keys()):
|
|
83
|
-
if " " in variable:
|
|
84
|
-
variable_copy = variable.replace(" ", "\n", 1)
|
|
85
|
-
else:
|
|
86
|
-
variable_copy = variable
|
|
87
|
-
|
|
88
|
-
label.append(variable_copy)
|
|
89
|
-
reference_values_copy[variable_copy] = reference_values[variable]
|
|
90
|
-
|
|
91
|
-
difference = set(flat_var_names).difference(set(label))
|
|
82
|
+
difference = set(flat_var_names).difference(set(reference_values.keys()))
|
|
92
83
|
|
|
93
84
|
if difference:
|
|
94
|
-
warn = [diff.replace("\n", " ", 1) for diff in difference]
|
|
95
85
|
warnings.warn(
|
|
96
86
|
"Argument reference_values does not include reference value for: {}".format(
|
|
97
|
-
", ".join(
|
|
87
|
+
", ".join(difference)
|
|
98
88
|
),
|
|
99
89
|
UserWarning,
|
|
100
90
|
)
|
|
@@ -211,12 +201,12 @@ def plot_pair(
|
|
|
211
201
|
|
|
212
202
|
if reference_values:
|
|
213
203
|
ax.plot(
|
|
214
|
-
|
|
215
|
-
|
|
204
|
+
np.array(reference_values[flat_var_names[0]])[flat_ref_slices[0]],
|
|
205
|
+
np.array(reference_values[flat_var_names[1]])[flat_ref_slices[1]],
|
|
216
206
|
**reference_values_kwargs,
|
|
217
207
|
)
|
|
218
|
-
ax.set_xlabel(f"{
|
|
219
|
-
ax.set_ylabel(f"{
|
|
208
|
+
ax.set_xlabel(f"{flat_var_labels[0]}", fontsize=ax_labelsize, wrap=True)
|
|
209
|
+
ax.set_ylabel(f"{flat_var_labels[1]}", fontsize=ax_labelsize, wrap=True)
|
|
220
210
|
ax.tick_params(labelsize=xt_labelsize)
|
|
221
211
|
|
|
222
212
|
else:
|
|
@@ -336,20 +326,22 @@ def plot_pair(
|
|
|
336
326
|
y_name = flat_var_names[j + not_marginals]
|
|
337
327
|
if (x_name not in difference) and (y_name not in difference):
|
|
338
328
|
ax[j, i].plot(
|
|
339
|
-
|
|
340
|
-
|
|
329
|
+
np.array(reference_values[x_name])[flat_ref_slices[i]],
|
|
330
|
+
np.array(reference_values[y_name])[
|
|
331
|
+
flat_ref_slices[j + not_marginals]
|
|
332
|
+
],
|
|
341
333
|
**reference_values_kwargs,
|
|
342
334
|
)
|
|
343
335
|
|
|
344
336
|
if j != vars_to_plot - 1:
|
|
345
337
|
plt.setp(ax[j, i].get_xticklabels(), visible=False)
|
|
346
338
|
else:
|
|
347
|
-
ax[j, i].set_xlabel(f"{
|
|
339
|
+
ax[j, i].set_xlabel(f"{flat_var_labels[i]}", fontsize=ax_labelsize, wrap=True)
|
|
348
340
|
if i != 0:
|
|
349
341
|
plt.setp(ax[j, i].get_yticklabels(), visible=False)
|
|
350
342
|
else:
|
|
351
343
|
ax[j, i].set_ylabel(
|
|
352
|
-
f"{
|
|
344
|
+
f"{flat_var_labels[j + not_marginals]}",
|
|
353
345
|
fontsize=ax_labelsize,
|
|
354
346
|
wrap=True,
|
|
355
347
|
)
|
arviz/plots/bfplot.py
CHANGED
|
@@ -2,11 +2,9 @@
|
|
|
2
2
|
# pylint: disable=unbalanced-tuple-unpacking
|
|
3
3
|
import logging
|
|
4
4
|
|
|
5
|
-
from numpy import interp
|
|
6
|
-
|
|
7
5
|
from ..data.utils import extract
|
|
8
6
|
from .plot_utils import get_plotting_function
|
|
9
|
-
from ..stats
|
|
7
|
+
from ..stats import bayes_factor
|
|
10
8
|
|
|
11
9
|
_log = logging.getLogger(__name__)
|
|
12
10
|
|
|
@@ -94,32 +92,17 @@ def plot_bf(
|
|
|
94
92
|
>>> az.plot_bf(idata, var_name="a", ref_val=0)
|
|
95
93
|
|
|
96
94
|
"""
|
|
97
|
-
posterior = extract(idata, var_names=var_name).values
|
|
98
|
-
|
|
99
|
-
if ref_val > posterior.max() or ref_val < posterior.min():
|
|
100
|
-
_log.warning(
|
|
101
|
-
"The reference value is outside of the posterior. "
|
|
102
|
-
"This translate into infinite support for H1, which is most likely an overstatement."
|
|
103
|
-
)
|
|
104
|
-
|
|
105
|
-
if posterior.ndim > 1:
|
|
106
|
-
_log.warning("Posterior distribution has {posterior.ndim} dimensions")
|
|
107
95
|
|
|
108
96
|
if prior is None:
|
|
109
97
|
prior = extract(idata, var_names=var_name, group="prior").values
|
|
110
98
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
elif posterior.dtype.kind == "i":
|
|
118
|
-
posterior_at_ref_val = (posterior == ref_val).mean()
|
|
119
|
-
prior_at_ref_val = (prior == ref_val).mean()
|
|
99
|
+
bf, p_at_ref_val = bayes_factor(
|
|
100
|
+
idata, var_name, prior=prior, ref_val=ref_val, return_ref_vals=True
|
|
101
|
+
)
|
|
102
|
+
bf_10 = bf["BF10"]
|
|
103
|
+
bf_01 = bf["BF01"]
|
|
120
104
|
|
|
121
|
-
|
|
122
|
-
bf_01 = 1 / bf_10
|
|
105
|
+
posterior = extract(idata, var_names=var_name)
|
|
123
106
|
|
|
124
107
|
bfplot_kwargs = dict(
|
|
125
108
|
ax=ax,
|
|
@@ -128,8 +111,8 @@ def plot_bf(
|
|
|
128
111
|
prior=prior,
|
|
129
112
|
posterior=posterior,
|
|
130
113
|
ref_val=ref_val,
|
|
131
|
-
prior_at_ref_val=
|
|
132
|
-
posterior_at_ref_val=
|
|
114
|
+
prior_at_ref_val=p_at_ref_val["prior"],
|
|
115
|
+
posterior_at_ref_val=p_at_ref_val["posterior"],
|
|
133
116
|
var_name=var_name,
|
|
134
117
|
colors=colors,
|
|
135
118
|
figsize=figsize,
|
arviz/plots/bpvplot.py
CHANGED
|
@@ -16,6 +16,7 @@ def plot_bpv(
|
|
|
16
16
|
bpv=True,
|
|
17
17
|
plot_mean=True,
|
|
18
18
|
reference="analytical",
|
|
19
|
+
smoothing=None,
|
|
19
20
|
mse=False,
|
|
20
21
|
n_ref=100,
|
|
21
22
|
hdi_prob=0.94,
|
|
@@ -72,6 +73,9 @@ def plot_bpv(
|
|
|
72
73
|
reference : {"analytical", "samples", None}, default "analytical"
|
|
73
74
|
How to compute the distributions used as reference for ``kind=u_values``
|
|
74
75
|
or ``kind=p_values``. Use `None` to not plot any reference.
|
|
76
|
+
smoothing : bool, optional
|
|
77
|
+
If True and the data has integer dtype, smooth the data before computing the p-values,
|
|
78
|
+
u-values or tstat. By default, True when `kind` is "u_value" and False otherwise.
|
|
75
79
|
mse : bool, default False
|
|
76
80
|
Show scaled mean square error between uniform distribution and marginal p_value
|
|
77
81
|
distribution.
|
|
@@ -166,7 +170,8 @@ def plot_bpv(
|
|
|
166
170
|
Notes
|
|
167
171
|
-----
|
|
168
172
|
Discrete data is smoothed before computing either p-values or u-values using the
|
|
169
|
-
function :func:`~arviz.smooth_data`
|
|
173
|
+
function :func:`~arviz.smooth_data` if the data is integer type
|
|
174
|
+
and the smoothing parameter is True.
|
|
170
175
|
|
|
171
176
|
Examples
|
|
172
177
|
--------
|
|
@@ -206,6 +211,9 @@ def plot_bpv(
|
|
|
206
211
|
elif not 1 >= hdi_prob > 0:
|
|
207
212
|
raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
|
|
208
213
|
|
|
214
|
+
if smoothing is None:
|
|
215
|
+
smoothing = kind.lower() == "u_value"
|
|
216
|
+
|
|
209
217
|
if data_pairs is None:
|
|
210
218
|
data_pairs = {}
|
|
211
219
|
|
|
@@ -291,6 +299,7 @@ def plot_bpv(
|
|
|
291
299
|
plot_ref_kwargs=plot_ref_kwargs,
|
|
292
300
|
backend_kwargs=backend_kwargs,
|
|
293
301
|
show=show,
|
|
302
|
+
smoothing=smoothing,
|
|
294
303
|
)
|
|
295
304
|
|
|
296
305
|
# TODO: Add backend kwargs
|
arviz/plots/hdiplot.py
CHANGED
|
@@ -136,6 +136,11 @@ def plot_hdi(
|
|
|
136
136
|
x = np.asarray(x)
|
|
137
137
|
x_shape = x.shape
|
|
138
138
|
|
|
139
|
+
if isinstance(x[0], str):
|
|
140
|
+
raise NotImplementedError(
|
|
141
|
+
"The `arviz.plot_hdi()` function does not support categorical data. "
|
|
142
|
+
"Consider using `arviz.plot_forest()`."
|
|
143
|
+
)
|
|
139
144
|
if y is None and hdi_data is None:
|
|
140
145
|
raise ValueError("One of {y, hdi_data} is required")
|
|
141
146
|
if hdi_data is not None and y is not None:
|
arviz/plots/kdeplot.py
CHANGED
|
@@ -255,6 +255,10 @@ def plot_kde(
|
|
|
255
255
|
"or plot_pair instead of plot_kde"
|
|
256
256
|
)
|
|
257
257
|
|
|
258
|
+
if backend is None:
|
|
259
|
+
backend = rcParams["plot.backend"]
|
|
260
|
+
backend = backend.lower()
|
|
261
|
+
|
|
258
262
|
if values2 is None:
|
|
259
263
|
if bw == "default":
|
|
260
264
|
bw = "taylor" if is_circular else "experimental"
|
|
@@ -346,10 +350,6 @@ def plot_kde(
|
|
|
346
350
|
**kwargs,
|
|
347
351
|
)
|
|
348
352
|
|
|
349
|
-
if backend is None:
|
|
350
|
-
backend = rcParams["plot.backend"]
|
|
351
|
-
backend = backend.lower()
|
|
352
|
-
|
|
353
353
|
# TODO: Add backend kwargs
|
|
354
354
|
plot = get_plotting_function("plot_kde", "kdeplot", backend)
|
|
355
355
|
ax = plot(**kde_plot_args)
|