arviz 0.20.0__py3-none-any.whl → 0.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51):
  1. arviz/__init__.py +8 -3
  2. arviz/data/base.py +2 -2
  3. arviz/data/inference_data.py +57 -26
  4. arviz/data/io_datatree.py +2 -2
  5. arviz/data/io_numpyro.py +112 -4
  6. arviz/plots/autocorrplot.py +12 -2
  7. arviz/plots/backends/__init__.py +8 -7
  8. arviz/plots/backends/bokeh/bpvplot.py +4 -3
  9. arviz/plots/backends/bokeh/densityplot.py +5 -1
  10. arviz/plots/backends/bokeh/dotplot.py +5 -2
  11. arviz/plots/backends/bokeh/essplot.py +4 -2
  12. arviz/plots/backends/bokeh/forestplot.py +11 -4
  13. arviz/plots/backends/bokeh/hdiplot.py +7 -6
  14. arviz/plots/backends/bokeh/khatplot.py +4 -2
  15. arviz/plots/backends/bokeh/lmplot.py +28 -6
  16. arviz/plots/backends/bokeh/mcseplot.py +2 -2
  17. arviz/plots/backends/bokeh/pairplot.py +27 -52
  18. arviz/plots/backends/bokeh/ppcplot.py +2 -1
  19. arviz/plots/backends/bokeh/rankplot.py +2 -1
  20. arviz/plots/backends/bokeh/traceplot.py +2 -1
  21. arviz/plots/backends/bokeh/violinplot.py +2 -1
  22. arviz/plots/backends/matplotlib/bpvplot.py +2 -1
  23. arviz/plots/backends/matplotlib/khatplot.py +8 -1
  24. arviz/plots/backends/matplotlib/lmplot.py +13 -7
  25. arviz/plots/backends/matplotlib/pairplot.py +14 -22
  26. arviz/plots/bfplot.py +9 -26
  27. arviz/plots/bpvplot.py +10 -1
  28. arviz/plots/hdiplot.py +5 -0
  29. arviz/plots/kdeplot.py +4 -4
  30. arviz/plots/lmplot.py +41 -14
  31. arviz/plots/pairplot.py +10 -3
  32. arviz/plots/plot_utils.py +5 -3
  33. arviz/preview.py +36 -5
  34. arviz/stats/__init__.py +1 -0
  35. arviz/stats/density_utils.py +1 -1
  36. arviz/stats/diagnostics.py +18 -14
  37. arviz/stats/stats.py +105 -7
  38. arviz/tests/base_tests/test_data.py +31 -11
  39. arviz/tests/base_tests/test_diagnostics.py +5 -4
  40. arviz/tests/base_tests/test_plots_bokeh.py +60 -2
  41. arviz/tests/base_tests/test_plots_matplotlib.py +103 -11
  42. arviz/tests/base_tests/test_stats.py +53 -1
  43. arviz/tests/external_tests/test_data_numpyro.py +130 -3
  44. arviz/utils.py +4 -0
  45. arviz/wrappers/base.py +1 -1
  46. arviz/wrappers/wrap_stan.py +1 -1
  47. {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/METADATA +7 -7
  48. {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/RECORD +51 -51
  49. {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/WHEEL +1 -1
  50. {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/LICENSE +0 -0
  51. {arviz-0.20.0.dist-info → arviz-0.22.0.dist-info}/top_level.txt +0 -0
@@ -94,9 +94,10 @@ def plot_khat(
94
94
 
95
95
  if not isinstance(rgba_c, str) and isinstance(rgba_c, Iterable):
96
96
  for idx, (alpha, rgba_c_) in enumerate(zip(alphas, rgba_c)):
97
- ax.cross(
97
+ ax.scatter(
98
98
  xdata[idx],
99
99
  khats[idx],
100
+ marker="cross",
100
101
  line_color=rgba_c_,
101
102
  fill_color=rgba_c_,
102
103
  line_alpha=alpha,
@@ -104,9 +105,10 @@ def plot_khat(
104
105
  size=10,
105
106
  )
106
107
  else:
107
- ax.cross(
108
+ ax.scatter(
108
109
  xdata,
109
110
  khats,
111
+ marker="cross",
110
112
  line_color=rgba_c,
111
113
  fill_color=rgba_c,
112
114
  size=10,
@@ -51,18 +51,30 @@ def plot_lm(
51
51
 
52
52
  if y_kwargs is None:
53
53
  y_kwargs = {}
54
+ else:
55
+ y_kwargs = y_kwargs.copy()
56
+ y_kwargs.setdefault("marker", "circle")
54
57
  y_kwargs.setdefault("fill_color", "red")
55
58
  y_kwargs.setdefault("line_width", 0)
56
59
  y_kwargs.setdefault("size", 3)
57
60
 
58
61
  if y_hat_plot_kwargs is None:
59
62
  y_hat_plot_kwargs = {}
63
+ else:
64
+ y_hat_plot_kwargs = y_hat_plot_kwargs.copy()
65
+ y_hat_plot_kwargs.setdefault("marker", "circle")
60
66
  y_hat_plot_kwargs.setdefault("fill_color", "orange")
61
67
  y_hat_plot_kwargs.setdefault("line_width", 0)
62
68
 
63
69
  if y_hat_fill_kwargs is None:
64
70
  y_hat_fill_kwargs = {}
65
- y_hat_fill_kwargs.setdefault("color", "orange")
71
+ else:
72
+ y_hat_fill_kwargs = y_hat_fill_kwargs.copy()
73
+ # Convert matplotlib color to bokeh fill_color if needed
74
+ if "color" in y_hat_fill_kwargs and "fill_color" not in y_hat_fill_kwargs:
75
+ y_hat_fill_kwargs["fill_color"] = y_hat_fill_kwargs.pop("color")
76
+ y_hat_fill_kwargs.setdefault("fill_color", "orange")
77
+ y_hat_fill_kwargs.setdefault("fill_alpha", 0.5)
66
78
 
67
79
  if y_model_plot_kwargs is None:
68
80
  y_model_plot_kwargs = {}
@@ -72,8 +84,13 @@ def plot_lm(
72
84
 
73
85
  if y_model_fill_kwargs is None:
74
86
  y_model_fill_kwargs = {}
75
- y_model_fill_kwargs.setdefault("color", "black")
76
- y_model_fill_kwargs.setdefault("alpha", 0.5)
87
+ else:
88
+ y_model_fill_kwargs = y_model_fill_kwargs.copy()
89
+ # Convert matplotlib color to bokeh fill_color if needed
90
+ if "color" in y_model_fill_kwargs and "fill_color" not in y_model_fill_kwargs:
91
+ y_model_fill_kwargs["fill_color"] = y_model_fill_kwargs.pop("color")
92
+ y_model_fill_kwargs.setdefault("fill_color", "black")
93
+ y_model_fill_kwargs.setdefault("fill_alpha", 0.5)
77
94
 
78
95
  if y_model_mean_kwargs is None:
79
96
  y_model_mean_kwargs = {}
@@ -84,7 +101,7 @@ def plot_lm(
84
101
  _, _, _, y_plotters = y[i]
85
102
  _, _, _, x_plotters = x[i]
86
103
  legend_it = []
87
- observed_legend = ax_i.circle(x_plotters, y_plotters, **y_kwargs)
104
+ observed_legend = ax_i.scatter(x_plotters, y_plotters, **y_kwargs)
88
105
  legend_it.append(("Observed", [observed_legend]))
89
106
 
90
107
  if y_hat is not None:
@@ -98,14 +115,14 @@ def plot_lm(
98
115
  x_plotters_jitter = x_plotters + np.random.uniform(
99
116
  low=-scale_high, high=scale_high, size=len(x_plotters)
100
117
  )
101
- posterior_circle = ax_i.circle(
118
+ posterior_circle = ax_i.scatter(
102
119
  x_plotters_jitter,
103
120
  y_hat_plotters[..., j],
104
121
  alpha=0.2,
105
122
  **y_hat_plot_kwargs,
106
123
  )
107
124
  else:
108
- posterior_circle = ax_i.circle(
125
+ posterior_circle = ax_i.scatter(
109
126
  x_plotters, y_hat_plotters[..., j], alpha=0.2, **y_hat_plot_kwargs
110
127
  )
111
128
  posterior_legend.append(posterior_circle)
@@ -143,6 +160,11 @@ def plot_lm(
143
160
  )
144
161
 
145
162
  y_model_mean = np.mean(y_model_plotters, axis=(0, 1))
163
+ # Plot mean line across all x values instead of just edges
164
+ mean_legend = ax_i.line(x_plotters, y_model_mean, **y_model_mean_kwargs)
165
+ legend_it.append(("Mean", [mean_legend]))
166
+ continue # Skip the edge plotting since we plotted full line
167
+
146
168
  x_plotters_edge = [min(x_plotters), max(x_plotters)]
147
169
  y_model_mean_edge = [min(y_model_mean), max(y_model_mean)]
148
170
  mean_legend = ax_i.line(x_plotters_edge, y_model_mean_edge, **y_model_mean_kwargs)
@@ -71,13 +71,13 @@ def plot_mcse(
71
71
  values = data[var_name].sel(**selection).values.flatten()
72
72
  if errorbar:
73
73
  quantile_values = _quantile(values, probs)
74
- ax_.dash(probs, quantile_values)
74
+ ax_.scatter(probs, quantile_values, marker="dash")
75
75
  ax_.multi_line(
76
76
  list(zip(probs, probs)),
77
77
  [(quant - err, quant + err) for quant, err in zip(quantile_values, x)],
78
78
  )
79
79
  else:
80
- ax_.circle(probs, x)
80
+ ax_.scatter(probs, x, marker="circle")
81
81
  if extra_methods:
82
82
  mean_mcse_i = mean_mcse[var_name].sel(**selection).values.item()
83
83
  sd_mcse_i = sd_mcse[var_name].sel(**selection).values.item()
@@ -37,6 +37,8 @@ def plot_pair(
37
37
  diverging_mask,
38
38
  divergences_kwargs,
39
39
  flat_var_names,
40
+ flat_ref_slices,
41
+ flat_var_labels,
40
42
  backend_kwargs,
41
43
  marginal_kwargs,
42
44
  show,
@@ -72,61 +74,25 @@ def plot_pair(
72
74
  kde_kwargs["contour_kwargs"].setdefault("line_alpha", 1)
73
75
 
74
76
  if reference_values:
75
- reference_values_copy = {}
76
- label = []
77
- for variable in list(reference_values.keys()):
78
- if " " in variable:
79
- variable_copy = variable.replace(" ", "\n", 1)
80
- else:
81
- variable_copy = variable
82
-
83
- label.append(variable_copy)
84
- reference_values_copy[variable_copy] = reference_values[variable]
85
-
86
- difference = set(flat_var_names).difference(set(label))
87
-
88
- if difference:
89
- warn = [diff.replace("\n", " ", 1) for diff in difference]
90
- warnings.warn(
91
- "Argument reference_values does not include reference value for: {}".format(
92
- ", ".join(warn)
93
- ),
94
- UserWarning,
95
- )
96
-
97
- if reference_values:
98
- reference_values_copy = {}
99
- label = []
100
- for variable in list(reference_values.keys()):
101
- if " " in variable:
102
- variable_copy = variable.replace(" ", "\n", 1)
103
- else:
104
- variable_copy = variable
105
-
106
- label.append(variable_copy)
107
- reference_values_copy[variable_copy] = reference_values[variable]
108
-
109
- difference = set(flat_var_names).difference(set(label))
110
-
111
- for dif in difference:
112
- reference_values_copy[dif] = None
77
+ difference = set(flat_var_names).difference(set(reference_values.keys()))
113
78
 
114
79
  if difference:
115
- warn = [dif.replace("\n", " ", 1) for dif in difference]
116
80
  warnings.warn(
117
81
  "Argument reference_values does not include reference value for: {}".format(
118
- ", ".join(warn)
82
+ ", ".join(difference)
119
83
  ),
120
84
  UserWarning,
121
85
  )
122
86
 
123
87
  reference_values_kwargs = _init_kwargs_dict(reference_values_kwargs)
88
+ reference_values_kwargs.setdefault("marker", "circle")
124
89
  reference_values_kwargs.setdefault("line_color", "black")
125
90
  reference_values_kwargs.setdefault("fill_color", vectorized_to_hex("C2"))
126
91
  reference_values_kwargs.setdefault("line_width", 1)
127
92
  reference_values_kwargs.setdefault("size", 10)
128
93
 
129
94
  divergences_kwargs = _init_kwargs_dict(divergences_kwargs)
95
+ divergences_kwargs.setdefault("marker", "circle")
130
96
  divergences_kwargs.setdefault("line_color", "black")
131
97
  divergences_kwargs.setdefault("fill_color", vectorized_to_hex("C1"))
132
98
  divergences_kwargs.setdefault("line_width", 1)
@@ -155,6 +121,7 @@ def plot_pair(
155
121
  )
156
122
 
157
123
  point_estimate_marker_kwargs = _init_kwargs_dict(point_estimate_marker_kwargs)
124
+ point_estimate_marker_kwargs.setdefault("marker", "square")
158
125
  point_estimate_marker_kwargs.setdefault("size", markersize)
159
126
  point_estimate_marker_kwargs.setdefault("color", "black")
160
127
  point_estimate_kwargs.setdefault("line_color", "black")
@@ -259,15 +226,17 @@ def plot_pair(
259
226
  **marginal_kwargs,
260
227
  )
261
228
 
262
- ax[j, i].xaxis.axis_label = flat_var_names[i]
263
- ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset]
229
+ ax[j, i].xaxis.axis_label = flat_var_labels[i]
230
+ ax[j, i].yaxis.axis_label = flat_var_labels[j + marginals_offset]
264
231
 
265
232
  elif j + marginals_offset > i:
266
233
  if "scatter" in kind:
267
234
  if divergences:
268
- ax[j, i].circle(var1, var2, source=source, view=source_nondiv)
235
+ ax[j, i].scatter(
236
+ var1, var2, marker="circle", source=source, view=source_nondiv
237
+ )
269
238
  else:
270
- ax[j, i].circle(var1, var2, source=source)
239
+ ax[j, i].scatter(var1, var2, marker="circle", source=source)
271
240
 
272
241
  if "kde" in kind:
273
242
  var1_kde = plotters[i][-1].flatten()
@@ -293,7 +262,7 @@ def plot_pair(
293
262
  )
294
263
 
295
264
  if divergences:
296
- ax[j, i].circle(
265
+ ax[j, i].scatter(
297
266
  var1,
298
267
  var2,
299
268
  source=source,
@@ -306,7 +275,7 @@ def plot_pair(
306
275
  var2_pe = plotters[j][-1].flatten()
307
276
  pe_x = calculate_point_estimate(point_estimate, var1_pe)
308
277
  pe_y = calculate_point_estimate(point_estimate, var2_pe)
309
- ax[j, i].square(pe_x, pe_y, **point_estimate_marker_kwargs)
278
+ ax[j, i].scatter(pe_x, pe_y, **point_estimate_marker_kwargs)
310
279
 
311
280
  ax_hline = Span(
312
281
  location=pe_y,
@@ -341,12 +310,18 @@ def plot_pair(
341
310
  ax[-1, -1].add_layout(ax_pe_hline)
342
311
 
343
312
  if reference_values:
344
- x = reference_values_copy[flat_var_names[j + marginals_offset]]
345
- y = reference_values_copy[flat_var_names[i]]
346
- if x and y:
347
- ax[j, i].circle(y, x, **reference_values_kwargs)
348
- ax[j, i].xaxis.axis_label = flat_var_names[i]
349
- ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset]
313
+ x_name = flat_var_names[j + marginals_offset]
314
+ y_name = flat_var_names[i]
315
+ if (x_name not in difference) and (y_name not in difference):
316
+ ax[j, i].scatter(
317
+ np.array(reference_values[y_name])[flat_ref_slices[i]],
318
+ np.array(reference_values[x_name])[
319
+ flat_ref_slices[j + marginals_offset]
320
+ ],
321
+ **reference_values_kwargs,
322
+ )
323
+ ax[j, i].xaxis.axis_label = flat_var_labels[i]
324
+ ax[j, i].yaxis.axis_label = flat_var_labels[j + marginals_offset]
350
325
 
351
326
  show_layout(ax, show)
352
327
 
@@ -313,9 +313,10 @@ def plot_ppc(
313
313
  obs_yvals += np.random.uniform(
314
314
  low=scale_low, high=scale_high, size=len(obs_vals)
315
315
  )
316
- glyph = ax_i.circle(
316
+ glyph = ax_i.scatter(
317
317
  obs_vals,
318
318
  obs_yvals,
319
+ marker="circle",
319
320
  line_color=colors[1],
320
321
  fill_color=colors[1],
321
322
  size=markersize,
@@ -49,6 +49,7 @@ def plot_rank(
49
49
 
50
50
  if marker_vlines_kwargs is None:
51
51
  marker_vlines_kwargs = {}
52
+ marker_vlines_kwargs.setdefault("marker", "circle")
52
53
 
53
54
  if backend_kwargs is None:
54
55
  backend_kwargs = {}
@@ -109,7 +110,7 @@ def plot_rank(
109
110
  elif kind == "vlines":
110
111
  ymin = np.full(len(all_counts), all_counts.mean())
111
112
  for idx, counts in enumerate(all_counts):
112
- ax.circle(
113
+ ax.scatter(
113
114
  bin_ary,
114
115
  counts,
115
116
  fill_color=colors[idx],
@@ -385,9 +385,10 @@ def _plot_chains_bokeh(
385
385
  **dealiase_sel_kwargs(trace_kwargs, chain_prop, chain_idx),
386
386
  )
387
387
  if marker:
388
- ax_trace.circle(
388
+ ax_trace.scatter(
389
389
  x=x_name,
390
390
  y=y_name,
391
+ marker="circle",
391
392
  source=cds,
392
393
  radius=0.30,
393
394
  alpha=0.5,
@@ -81,9 +81,10 @@ def plot_violin(
81
81
  [0, 0], per[:2], line_width=linewidth * 3, line_color="black", line_cap="round"
82
82
  )
83
83
  ax_.line([0, 0], hdi_probs, line_width=linewidth, line_color="black", line_cap="round")
84
- ax_.circle(
84
+ ax_.scatter(
85
85
  0,
86
86
  per[-1],
87
+ marker="circle",
87
88
  line_color="white",
88
89
  fill_color="white",
89
90
  size=linewidth * 1.5,
@@ -38,6 +38,7 @@ def plot_bpv(
38
38
  plot_ref_kwargs,
39
39
  backend_kwargs,
40
40
  show,
41
+ smoothing,
41
42
  ):
42
43
  """Matplotlib bpv plot."""
43
44
  if backend_kwargs is None:
@@ -87,7 +88,7 @@ def plot_bpv(
87
88
  obs_vals = obs_vals.flatten()
88
89
  pp_vals = pp_vals.reshape(total_pp_samples, -1)
89
90
 
90
- if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
91
+ if (obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i") and smoothing is True:
91
92
  obs_vals, pp_vals = smooth_data(obs_vals, pp_vals)
92
93
 
93
94
  if kind == "p_value":
@@ -7,6 +7,7 @@ from matplotlib import cm
7
7
  import matplotlib.pyplot as plt
8
8
  import numpy as np
9
9
  from matplotlib.colors import to_rgba_array
10
+ from packaging import version
10
11
 
11
12
  from ....stats.density_utils import histogram
12
13
  from ...plot_utils import _scale_fig_size, color_from_dim, set_xticklabels, vectorized_to_hex
@@ -39,7 +40,13 @@ def plot_khat(
39
40
  show,
40
41
  ):
41
42
  """Matplotlib khat plot."""
42
- if hover_label and mpl.get_backend() not in mpl.rcsetup.interactive_bk:
43
+ if version.parse(mpl.__version__) >= version.parse("3.9.0.dev0"):
44
+ interactive_backends = mpl.backends.backend_registry.list_builtin(
45
+ mpl.backends.BackendFilter.INTERACTIVE
46
+ )
47
+ else:
48
+ interactive_backends = mpl.rcsetup.interactive_bk
49
+ if hover_label and mpl.get_backend() not in interactive_backends:
43
50
  hover_label = False
44
51
  warnings.warn(
45
52
  "hover labels are only available with interactive backends. To switch to an "
@@ -115,12 +115,18 @@ def plot_lm(
115
115
 
116
116
  if y_model is not None:
117
117
  _, _, _, y_model_plotters = y_model[i]
118
+
118
119
  if kind_model == "lines":
119
- for j in range(num_samples):
120
- ax_i.plot(x_plotters, y_model_plotters[..., j], **y_model_plot_kwargs)
121
- ax_i.plot([], **y_model_plot_kwargs, label="Uncertainty in mean")
120
+ # y_model_plotters should be (points, samples)
121
+ y_points = y_model_plotters.shape[0]
122
+ if x_plotters.shape[0] == y_points:
123
+ for j in range(num_samples):
124
+ ax_i.plot(x_plotters, y_model_plotters[:, j], **y_model_plot_kwargs)
125
+
126
+ ax_i.plot([], **y_model_plot_kwargs, label="Uncertainty in mean")
127
+ y_model_mean = np.mean(y_model_plotters, axis=1)
128
+ ax_i.plot(x_plotters, y_model_mean, **y_model_mean_kwargs, label="Mean")
122
129
 
123
- y_model_mean = np.mean(y_model_plotters, axis=1)
124
130
  else:
125
131
  plot_hdi(
126
132
  x_plotters,
@@ -128,10 +134,10 @@ def plot_lm(
128
134
  fill_kwargs=y_model_fill_kwargs,
129
135
  ax=ax_i,
130
136
  )
131
- ax_i.plot([], color=y_model_fill_kwargs["color"], label="Uncertainty in mean")
132
137
 
133
- y_model_mean = np.mean(y_model_plotters, axis=(0, 1))
134
- ax_i.plot(x_plotters, y_model_mean, **y_model_mean_kwargs, label="Mean")
138
+ ax_i.plot([], color=y_model_fill_kwargs["color"], label="Uncertainty in mean")
139
+ y_model_mean = np.mean(y_model_plotters, axis=0)
140
+ ax_i.plot(x_plotters, y_model_mean, **y_model_mean_kwargs, label="Mean")
135
141
 
136
142
  if legend:
137
143
  ax_i.legend(fontsize=xt_labelsize, loc="upper left")
@@ -30,6 +30,8 @@ def plot_pair(
30
30
  diverging_mask,
31
31
  divergences_kwargs,
32
32
  flat_var_names,
33
+ flat_ref_slices,
34
+ flat_var_labels,
33
35
  backend_kwargs,
34
36
  marginal_kwargs,
35
37
  show,
@@ -77,24 +79,12 @@ def plot_pair(
77
79
  kde_kwargs["contour_kwargs"].setdefault("colors", "k")
78
80
 
79
81
  if reference_values:
80
- reference_values_copy = {}
81
- label = []
82
- for variable in list(reference_values.keys()):
83
- if " " in variable:
84
- variable_copy = variable.replace(" ", "\n", 1)
85
- else:
86
- variable_copy = variable
87
-
88
- label.append(variable_copy)
89
- reference_values_copy[variable_copy] = reference_values[variable]
90
-
91
- difference = set(flat_var_names).difference(set(label))
82
+ difference = set(flat_var_names).difference(set(reference_values.keys()))
92
83
 
93
84
  if difference:
94
- warn = [diff.replace("\n", " ", 1) for diff in difference]
95
85
  warnings.warn(
96
86
  "Argument reference_values does not include reference value for: {}".format(
97
- ", ".join(warn)
87
+ ", ".join(difference)
98
88
  ),
99
89
  UserWarning,
100
90
  )
@@ -211,12 +201,12 @@ def plot_pair(
211
201
 
212
202
  if reference_values:
213
203
  ax.plot(
214
- reference_values_copy[flat_var_names[0]],
215
- reference_values_copy[flat_var_names[1]],
204
+ np.array(reference_values[flat_var_names[0]])[flat_ref_slices[0]],
205
+ np.array(reference_values[flat_var_names[1]])[flat_ref_slices[1]],
216
206
  **reference_values_kwargs,
217
207
  )
218
- ax.set_xlabel(f"{flat_var_names[0]}", fontsize=ax_labelsize, wrap=True)
219
- ax.set_ylabel(f"{flat_var_names[1]}", fontsize=ax_labelsize, wrap=True)
208
+ ax.set_xlabel(f"{flat_var_labels[0]}", fontsize=ax_labelsize, wrap=True)
209
+ ax.set_ylabel(f"{flat_var_labels[1]}", fontsize=ax_labelsize, wrap=True)
220
210
  ax.tick_params(labelsize=xt_labelsize)
221
211
 
222
212
  else:
@@ -336,20 +326,22 @@ def plot_pair(
336
326
  y_name = flat_var_names[j + not_marginals]
337
327
  if (x_name not in difference) and (y_name not in difference):
338
328
  ax[j, i].plot(
339
- reference_values_copy[x_name],
340
- reference_values_copy[y_name],
329
+ np.array(reference_values[x_name])[flat_ref_slices[i]],
330
+ np.array(reference_values[y_name])[
331
+ flat_ref_slices[j + not_marginals]
332
+ ],
341
333
  **reference_values_kwargs,
342
334
  )
343
335
 
344
336
  if j != vars_to_plot - 1:
345
337
  plt.setp(ax[j, i].get_xticklabels(), visible=False)
346
338
  else:
347
- ax[j, i].set_xlabel(f"{flat_var_names[i]}", fontsize=ax_labelsize, wrap=True)
339
+ ax[j, i].set_xlabel(f"{flat_var_labels[i]}", fontsize=ax_labelsize, wrap=True)
348
340
  if i != 0:
349
341
  plt.setp(ax[j, i].get_yticklabels(), visible=False)
350
342
  else:
351
343
  ax[j, i].set_ylabel(
352
- f"{flat_var_names[j + not_marginals]}",
344
+ f"{flat_var_labels[j + not_marginals]}",
353
345
  fontsize=ax_labelsize,
354
346
  wrap=True,
355
347
  )
arviz/plots/bfplot.py CHANGED
@@ -2,11 +2,9 @@
2
2
  # pylint: disable=unbalanced-tuple-unpacking
3
3
  import logging
4
4
 
5
- from numpy import interp
6
-
7
5
  from ..data.utils import extract
8
6
  from .plot_utils import get_plotting_function
9
- from ..stats.density_utils import _kde_linear
7
+ from ..stats import bayes_factor
10
8
 
11
9
  _log = logging.getLogger(__name__)
12
10
 
@@ -94,32 +92,17 @@ def plot_bf(
94
92
  >>> az.plot_bf(idata, var_name="a", ref_val=0)
95
93
 
96
94
  """
97
- posterior = extract(idata, var_names=var_name).values
98
-
99
- if ref_val > posterior.max() or ref_val < posterior.min():
100
- _log.warning(
101
- "The reference value is outside of the posterior. "
102
- "This translate into infinite support for H1, which is most likely an overstatement."
103
- )
104
-
105
- if posterior.ndim > 1:
106
- _log.warning("Posterior distribution has {posterior.ndim} dimensions")
107
95
 
108
96
  if prior is None:
109
97
  prior = extract(idata, var_names=var_name, group="prior").values
110
98
 
111
- if posterior.dtype.kind == "f":
112
- posterior_grid, posterior_pdf = _kde_linear(posterior)
113
- prior_grid, prior_pdf = _kde_linear(prior)
114
- posterior_at_ref_val = interp(ref_val, posterior_grid, posterior_pdf)
115
- prior_at_ref_val = interp(ref_val, prior_grid, prior_pdf)
116
-
117
- elif posterior.dtype.kind == "i":
118
- posterior_at_ref_val = (posterior == ref_val).mean()
119
- prior_at_ref_val = (prior == ref_val).mean()
99
+ bf, p_at_ref_val = bayes_factor(
100
+ idata, var_name, prior=prior, ref_val=ref_val, return_ref_vals=True
101
+ )
102
+ bf_10 = bf["BF10"]
103
+ bf_01 = bf["BF01"]
120
104
 
121
- bf_10 = prior_at_ref_val / posterior_at_ref_val
122
- bf_01 = 1 / bf_10
105
+ posterior = extract(idata, var_names=var_name)
123
106
 
124
107
  bfplot_kwargs = dict(
125
108
  ax=ax,
@@ -128,8 +111,8 @@ def plot_bf(
128
111
  prior=prior,
129
112
  posterior=posterior,
130
113
  ref_val=ref_val,
131
- prior_at_ref_val=prior_at_ref_val,
132
- posterior_at_ref_val=posterior_at_ref_val,
114
+ prior_at_ref_val=p_at_ref_val["prior"],
115
+ posterior_at_ref_val=p_at_ref_val["posterior"],
133
116
  var_name=var_name,
134
117
  colors=colors,
135
118
  figsize=figsize,
arviz/plots/bpvplot.py CHANGED
@@ -16,6 +16,7 @@ def plot_bpv(
16
16
  bpv=True,
17
17
  plot_mean=True,
18
18
  reference="analytical",
19
+ smoothing=None,
19
20
  mse=False,
20
21
  n_ref=100,
21
22
  hdi_prob=0.94,
@@ -72,6 +73,9 @@ def plot_bpv(
72
73
  reference : {"analytical", "samples", None}, default "analytical"
73
74
  How to compute the distributions used as reference for ``kind=u_values``
74
75
  or ``kind=p_values``. Use `None` to not plot any reference.
76
+ smoothing : bool, optional
77
+ If True and the data has integer dtype, smooth the data before computing the p-values,
78
+ u-values or tstat. By default, True when `kind` is "u_value" and False otherwise.
75
79
  mse : bool, default False
76
80
  Show scaled mean square error between uniform distribution and marginal p_value
77
81
  distribution.
@@ -166,7 +170,8 @@ def plot_bpv(
166
170
  Notes
167
171
  -----
168
172
  Discrete data is smoothed before computing either p-values or u-values using the
169
- function :func:`~arviz.smooth_data`
173
+ function :func:`~arviz.smooth_data` if the data is integer type
174
+ and the smoothing parameter is True.
170
175
 
171
176
  Examples
172
177
  --------
@@ -206,6 +211,9 @@ def plot_bpv(
206
211
  elif not 1 >= hdi_prob > 0:
207
212
  raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
208
213
 
214
+ if smoothing is None:
215
+ smoothing = kind.lower() == "u_value"
216
+
209
217
  if data_pairs is None:
210
218
  data_pairs = {}
211
219
 
@@ -291,6 +299,7 @@ def plot_bpv(
291
299
  plot_ref_kwargs=plot_ref_kwargs,
292
300
  backend_kwargs=backend_kwargs,
293
301
  show=show,
302
+ smoothing=smoothing,
294
303
  )
295
304
 
296
305
  # TODO: Add backend kwargs
arviz/plots/hdiplot.py CHANGED
@@ -136,6 +136,11 @@ def plot_hdi(
136
136
  x = np.asarray(x)
137
137
  x_shape = x.shape
138
138
 
139
+ if isinstance(x[0], str):
140
+ raise NotImplementedError(
141
+ "The `arviz.plot_hdi()` function does not support categorical data. "
142
+ "Consider using `arviz.plot_forest()`."
143
+ )
139
144
  if y is None and hdi_data is None:
140
145
  raise ValueError("One of {y, hdi_data} is required")
141
146
  if hdi_data is not None and y is not None:
arviz/plots/kdeplot.py CHANGED
@@ -255,6 +255,10 @@ def plot_kde(
255
255
  "or plot_pair instead of plot_kde"
256
256
  )
257
257
 
258
+ if backend is None:
259
+ backend = rcParams["plot.backend"]
260
+ backend = backend.lower()
261
+
258
262
  if values2 is None:
259
263
  if bw == "default":
260
264
  bw = "taylor" if is_circular else "experimental"
@@ -346,10 +350,6 @@ def plot_kde(
346
350
  **kwargs,
347
351
  )
348
352
 
349
- if backend is None:
350
- backend = rcParams["plot.backend"]
351
- backend = backend.lower()
352
-
353
353
  # TODO: Add backend kwargs
354
354
  plot = get_plotting_function("plot_kde", "kdeplot", backend)
355
355
  ax = plot(**kde_plot_args)