arviz 0.16.1__py3-none-any.whl → 0.17.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. arviz/__init__.py +1 -1
  2. arviz/data/inference_data.py +34 -7
  3. arviz/data/io_beanmachine.py +6 -1
  4. arviz/data/io_cmdstanpy.py +439 -50
  5. arviz/data/io_pyjags.py +5 -2
  6. arviz/data/io_pystan.py +1 -2
  7. arviz/labels.py +2 -0
  8. arviz/plots/backends/bokeh/bpvplot.py +7 -2
  9. arviz/plots/backends/bokeh/compareplot.py +7 -4
  10. arviz/plots/backends/bokeh/densityplot.py +0 -1
  11. arviz/plots/backends/bokeh/distplot.py +0 -2
  12. arviz/plots/backends/bokeh/forestplot.py +3 -5
  13. arviz/plots/backends/bokeh/kdeplot.py +0 -2
  14. arviz/plots/backends/bokeh/pairplot.py +0 -4
  15. arviz/plots/backends/matplotlib/bfplot.py +0 -1
  16. arviz/plots/backends/matplotlib/bpvplot.py +3 -3
  17. arviz/plots/backends/matplotlib/compareplot.py +1 -1
  18. arviz/plots/backends/matplotlib/dotplot.py +1 -1
  19. arviz/plots/backends/matplotlib/forestplot.py +2 -4
  20. arviz/plots/backends/matplotlib/kdeplot.py +0 -1
  21. arviz/plots/backends/matplotlib/khatplot.py +0 -1
  22. arviz/plots/backends/matplotlib/lmplot.py +4 -5
  23. arviz/plots/backends/matplotlib/pairplot.py +0 -1
  24. arviz/plots/backends/matplotlib/ppcplot.py +8 -5
  25. arviz/plots/backends/matplotlib/traceplot.py +1 -2
  26. arviz/plots/bfplot.py +7 -6
  27. arviz/plots/bpvplot.py +7 -2
  28. arviz/plots/compareplot.py +2 -2
  29. arviz/plots/ecdfplot.py +37 -112
  30. arviz/plots/elpdplot.py +1 -1
  31. arviz/plots/essplot.py +2 -2
  32. arviz/plots/kdeplot.py +0 -1
  33. arviz/plots/pairplot.py +1 -1
  34. arviz/plots/plot_utils.py +0 -1
  35. arviz/plots/ppcplot.py +51 -45
  36. arviz/plots/separationplot.py +0 -1
  37. arviz/stats/__init__.py +2 -0
  38. arviz/stats/density_utils.py +2 -2
  39. arviz/stats/diagnostics.py +2 -3
  40. arviz/stats/ecdf_utils.py +165 -0
  41. arviz/stats/stats.py +241 -38
  42. arviz/stats/stats_utils.py +36 -7
  43. arviz/tests/base_tests/test_data.py +73 -5
  44. arviz/tests/base_tests/test_plots_bokeh.py +0 -1
  45. arviz/tests/base_tests/test_plots_matplotlib.py +24 -1
  46. arviz/tests/base_tests/test_stats.py +43 -1
  47. arviz/tests/base_tests/test_stats_ecdf_utils.py +153 -0
  48. arviz/tests/base_tests/test_stats_utils.py +3 -3
  49. arviz/tests/external_tests/test_data_beanmachine.py +2 -0
  50. arviz/tests/external_tests/test_data_numpyro.py +3 -3
  51. arviz/tests/external_tests/test_data_pyjags.py +3 -1
  52. arviz/tests/external_tests/test_data_pyro.py +3 -3
  53. arviz/tests/helpers.py +8 -8
  54. arviz/utils.py +15 -7
  55. arviz/wrappers/wrap_pymc.py +1 -1
  56. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/METADATA +16 -15
  57. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/RECORD +60 -58
  58. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/WHEEL +1 -1
  59. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/LICENSE +0 -0
  60. {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/top_level.txt +0 -0
@@ -69,13 +69,14 @@ def plot_compare(
69
69
  err_ys.append((y, y))
70
70
 
71
71
  # plot them
72
- dif_tri = ax.triangle(
72
+ dif_tri = ax.scatter(
73
73
  comp_df[information_criterion].iloc[1:],
74
74
  yticks_pos[1::2],
75
75
  line_color=plot_kwargs.get("color_dse", "grey"),
76
76
  fill_color=plot_kwargs.get("color_dse", "grey"),
77
77
  line_width=2,
78
78
  size=6,
79
+ marker="triangle",
79
80
  )
80
81
  dif_line = ax.multi_line(err_xs, err_ys, line_color=plot_kwargs.get("color_dse", "grey"))
81
82
 
@@ -85,13 +86,14 @@ def plot_compare(
85
86
  ax.yaxis.ticker = yticks_pos[::2]
86
87
  ax.yaxis.major_label_overrides = dict(zip(yticks_pos[::2], yticks_labels))
87
88
 
88
- elpd_circ = ax.circle(
89
+ elpd_circ = ax.scatter(
89
90
  comp_df[information_criterion],
90
91
  yticks_pos[::2],
91
92
  line_color=plot_kwargs.get("color_ic", "black"),
92
93
  fill_color=None,
93
94
  line_width=2,
94
95
  size=6,
96
+ marker="circle",
95
97
  )
96
98
  elpd_label = [elpd_circ]
97
99
 
@@ -110,7 +112,7 @@ def plot_compare(
110
112
 
111
113
  labels.append(("ELPD", elpd_label))
112
114
 
113
- scale = comp_df["scale"][0]
115
+ scale = comp_df["scale"].iloc[0]
114
116
 
115
117
  if insample_dev:
116
118
  p_ic = comp_df[f"p_{information_criterion.split('_')[1]}"]
@@ -120,13 +122,14 @@ def plot_compare(
120
122
  correction = -p_ic
121
123
  elif scale == "deviance":
122
124
  correction = -(2 * p_ic)
123
- insample_circ = ax.circle(
125
+ insample_circ = ax.scatter(
124
126
  comp_df[information_criterion] + correction,
125
127
  yticks_pos[::2],
126
128
  line_color=plot_kwargs.get("color_insample_dev", "black"),
127
129
  fill_color=plot_kwargs.get("color_insample_dev", "black"),
128
130
  line_width=2,
129
131
  size=6,
132
+ marker="circle",
130
133
  )
131
134
  labels.append(("In-sample ELPD", [insample_circ]))
132
135
 
@@ -129,7 +129,6 @@ def _d_helper(
129
129
  shade,
130
130
  ax,
131
131
  ):
132
-
133
132
  extra = {}
134
133
  plotted = []
135
134
 
@@ -145,12 +145,10 @@ def _histplot_bokeh_op(values, values2, rotated, ax, hist_kwargs, is_circular):
145
145
  edges = edges.astype(float) - 0.5
146
146
 
147
147
  if is_circular:
148
-
149
148
  if is_circular == "degrees":
150
149
  edges = np.deg2rad(edges)
151
150
  labels = ["0°", "45°", "90°", "135°", "180°", "225°", "270°", "315°"]
152
151
  else:
153
-
154
152
  labels = [
155
153
  r"0",
156
154
  r"π/4",
@@ -15,7 +15,6 @@ from ....rcparams import rcParams
15
15
  from ....stats import hdi
16
16
  from ....stats.density_utils import get_bins, histogram, kde
17
17
  from ....stats.diagnostics import _ess, _rhat
18
- from ....utils import conditional_jit
19
18
  from ...plot_utils import _scale_fig_size
20
19
  from .. import show_layout
21
20
  from . import backend_kwarg_defaults
@@ -277,7 +276,6 @@ class PlotHandler:
277
276
  """Collect labels and ticks from plotters."""
278
277
  val = self.plotters.values()
279
278
 
280
- @conditional_jit(forceobj=True, nopython=False)
281
279
  def label_idxs():
282
280
  labels, idxs = [], []
283
281
  for plotter in val:
@@ -299,7 +297,7 @@ class PlotHandler:
299
297
  def legend(self, ax, plotted):
300
298
  """Add interactive legend with colorcoded model info."""
301
299
  legend_it = []
302
- for (model_name, glyphs) in plotted.items():
300
+ for model_name, glyphs in plotted.items():
303
301
  legend_it.append((model_name, glyphs))
304
302
 
305
303
  legend = Legend(items=legend_it, orientation="vertical", location="top_left")
@@ -640,7 +638,7 @@ class VarHandler:
640
638
  grouped_data = [[(0, datum)] for datum in self.data]
641
639
  skip_dims = self.combine_dims.union({"chain"})
642
640
  else:
643
- grouped_data = [datum.groupby("chain") for datum in self.data]
641
+ grouped_data = [datum.groupby("chain", squeeze=False) for datum in self.data]
644
642
  skip_dims = self.combine_dims
645
643
 
646
644
  label_dict = OrderedDict()
@@ -648,7 +646,7 @@ class VarHandler:
648
646
  for name, grouped_datum in zip(self.model_names, grouped_data):
649
647
  for _, sub_data in grouped_datum:
650
648
  datum_iter = xarray_var_iter(
651
- sub_data,
649
+ sub_data.squeeze(),
652
650
  var_names=[self.var_name],
653
651
  skip_dims=skip_dims,
654
652
  reverse_selections=True,
@@ -165,7 +165,6 @@ def plot_kde(
165
165
  x_x, y_y = np.mgrid[xmin:xmax:g_s, ymin:ymax:g_s]
166
166
 
167
167
  if contour:
168
-
169
168
  scaled_density, *scaled_density_args = _scale_axis(density)
170
169
 
171
170
  contourpy_kwargs = _init_kwargs_dict(contour_kwargs.pop("contourpy_kwargs", {}))
@@ -224,7 +223,6 @@ def plot_kde(
224
223
  ax.ygrid.grid_line_color = None
225
224
 
226
225
  else:
227
-
228
226
  cmap = pcolormesh_kwargs.pop("cmap", "viridis")
229
227
  if isinstance(cmap, str):
230
228
  cmap = get_cmap(cmap)
@@ -241,11 +241,9 @@ def plot_pair(
241
241
 
242
242
  # pylint: disable=too-many-nested-blocks
243
243
  for i in range(0, numvars - marginals_offset):
244
-
245
244
  var1 = flat_var_names[i] if tmp_flat_var_names is None else tmp_flat_var_names[i]
246
245
 
247
246
  for j in range(0, numvars - marginals_offset):
248
-
249
247
  var2 = (
250
248
  flat_var_names[j + marginals_offset]
251
249
  if tmp_flat_var_names is None
@@ -268,7 +266,6 @@ def plot_pair(
268
266
  ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset]
269
267
 
270
268
  elif j + marginals_offset > i:
271
-
272
269
  if "scatter" in kind:
273
270
  if divergences:
274
271
  ax[j, i].circle(var1, var2, source=source, view=source_nondiv)
@@ -328,7 +325,6 @@ def plot_pair(
328
325
  ax[j, i].add_layout(ax_vline)
329
326
 
330
327
  if marginals:
331
-
332
328
  ax[j - 1, i].add_layout(ax_vline)
333
329
 
334
330
  pe_last = calculate_point_estimate(point_estimate, plotters[-1][-1])
@@ -23,7 +23,6 @@ def plot_bf(
23
23
  backend_kwargs,
24
24
  show,
25
25
  ):
26
-
27
26
  """Matplotlib Bayes Factor plot."""
28
27
  if backend_kwargs is None:
29
28
  backend_kwargs = {}
@@ -86,6 +86,9 @@ def plot_bpv(
86
86
  obs_vals = obs_vals.flatten()
87
87
  pp_vals = pp_vals.reshape(total_pp_samples, -1)
88
88
 
89
+ if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
90
+ obs_vals, pp_vals = smooth_data(obs_vals, pp_vals)
91
+
89
92
  if kind == "p_value":
90
93
  tstat_pit = np.mean(pp_vals <= obs_vals, axis=-1)
91
94
  x_s, tstat_pit_dens = kde(tstat_pit)
@@ -110,9 +113,6 @@ def plot_bpv(
110
113
  ax_i.plot(x_ss, u_dens, linewidth=linewidth, **plot_ref_kwargs)
111
114
 
112
115
  elif kind == "u_value":
113
- if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
114
- obs_vals, pp_vals = smooth_data(obs_vals, pp_vals)
115
-
116
116
  tstat_pit = np.mean(pp_vals <= obs_vals, axis=0)
117
117
  x_s, tstat_pit_dens = kde(tstat_pit)
118
118
  ax_i.plot(x_s, tstat_pit_dens, color=color)
@@ -84,7 +84,7 @@ def plot_compare(
84
84
  else:
85
85
  ax.set_yticks(yticks_pos[::2])
86
86
 
87
- scale = comp_df["scale"][0]
87
+ scale = comp_df["scale"].iloc[0]
88
88
 
89
89
  if insample_dev:
90
90
  p_ic = comp_df[f"p_{information_criterion.split('_')[1]}"]
@@ -97,7 +97,7 @@ def plot_dot(
97
97
  stack_locs, stack_count = wilkinson_algorithm(values, binwidth)
98
98
  x, y = layout_stacks(stack_locs, stack_count, binwidth, stackratio, rotated)
99
99
 
100
- for (x_i, y_i) in zip(x, y):
100
+ for x_i, y_i in zip(x, y):
101
101
  dot = plt.Circle((x_i, y_i), dotsize * binwidth / 2, **plot_kwargs)
102
102
  ax.add_patch(dot)
103
103
 
@@ -11,7 +11,6 @@ from ....stats import hdi
11
11
  from ....stats.density_utils import get_bins, histogram, kde
12
12
  from ....stats.diagnostics import _ess, _rhat
13
13
  from ....sel_utils import xarray_var_iter
14
- from ....utils import conditional_jit
15
14
  from ...plot_utils import _scale_fig_size
16
15
  from . import backend_kwarg_defaults, backend_show
17
16
 
@@ -236,7 +235,6 @@ class PlotHandler:
236
235
  """Collect labels and ticks from plotters."""
237
236
  val = self.plotters.values()
238
237
 
239
- @conditional_jit(forceobj=True, nopython=False)
240
238
  def label_idxs():
241
239
  labels, idxs = [], []
242
240
  for plotter in val:
@@ -536,7 +534,7 @@ class VarHandler:
536
534
  grouped_data = [[(0, datum)] for datum in self.data]
537
535
  skip_dims = self.combine_dims.union({"chain"})
538
536
  else:
539
- grouped_data = [datum.groupby("chain") for datum in self.data]
537
+ grouped_data = [datum.groupby("chain", squeeze=False) for datum in self.data]
540
538
  skip_dims = self.combine_dims
541
539
 
542
540
  label_dict = OrderedDict()
@@ -544,7 +542,7 @@ class VarHandler:
544
542
  for name, grouped_datum in zip(self.model_names, grouped_data):
545
543
  for _, sub_data in grouped_datum:
546
544
  datum_iter = xarray_var_iter(
547
- sub_data,
545
+ sub_data.squeeze(),
548
546
  var_names=[self.var_name],
549
547
  skip_dims=skip_dims,
550
548
  reverse_selections=True,
@@ -88,7 +88,6 @@ def plot_kde(
88
88
  rug_space = max(density) * rug_kwargs.pop("space")
89
89
 
90
90
  if is_circular:
91
-
92
91
  if is_circular == "radians":
93
92
  labels = [
94
93
  "0",
@@ -201,7 +201,6 @@ def _make_hover_annotation(fig, ax, sc_plot, coord_labels, rgba_c, hover_format)
201
201
  offset = 10
202
202
 
203
203
  def update_annot(ind):
204
-
205
204
  idx = ind["ind"][0]
206
205
  pos = sc_plot.get_offsets()[idx]
207
206
  annot_text = hover_format.format(idx, coord_labels[idx])
@@ -50,7 +50,6 @@ def plot_lm(
50
50
  _, axes = create_axes_grid(length_plotters, rows, cols, backend_kwargs=backend_kwargs)
51
51
 
52
52
  for i, ax_i in enumerate(np.ravel(axes)[:length_plotters]):
53
-
54
53
  # All the kwargs are defined here beforehand
55
54
  y_kwargs = matplotlib_kwarg_dealiaser(y_kwargs, "plot")
56
55
  y_kwargs.setdefault("color", "C3")
@@ -68,22 +67,22 @@ def plot_lm(
68
67
  y_hat_plot_kwargs.setdefault("linewidth", 0)
69
68
 
70
69
  y_hat_fill_kwargs = matplotlib_kwarg_dealiaser(y_hat_fill_kwargs, "fill_between")
71
- y_hat_fill_kwargs.setdefault("color", "C1")
70
+ y_hat_fill_kwargs.setdefault("color", "C3")
72
71
 
73
72
  y_model_plot_kwargs = matplotlib_kwarg_dealiaser(y_model_plot_kwargs, "plot")
74
- y_model_plot_kwargs.setdefault("color", "k")
73
+ y_model_plot_kwargs.setdefault("color", "C6")
75
74
  y_model_plot_kwargs.setdefault("alpha", 0.5)
76
75
  y_model_plot_kwargs.setdefault("linewidth", 0.5)
77
76
  y_model_plot_kwargs.setdefault("zorder", 9)
78
77
 
79
78
  y_model_fill_kwargs = matplotlib_kwarg_dealiaser(y_model_fill_kwargs, "fill_between")
80
- y_model_fill_kwargs.setdefault("color", "k")
79
+ y_model_fill_kwargs.setdefault("color", "C0")
81
80
  y_model_fill_kwargs.setdefault("linewidth", 0.5)
82
81
  y_model_fill_kwargs.setdefault("zorder", 9)
83
82
  y_model_fill_kwargs.setdefault("alpha", 0.5)
84
83
 
85
84
  y_model_mean_kwargs = matplotlib_kwarg_dealiaser(y_model_mean_kwargs, "plot")
86
- y_model_mean_kwargs.setdefault("color", "y")
85
+ y_model_mean_kwargs.setdefault("color", "C6")
87
86
  y_model_mean_kwargs.setdefault("linewidth", 0.8)
88
87
  y_model_mean_kwargs.setdefault("zorder", 11)
89
88
 
@@ -291,7 +291,6 @@ def plot_pair(
291
291
  ax[j, i].scatter(var1, var2, **scatter_kwargs)
292
292
 
293
293
  if "kde" in kind:
294
-
295
294
  plot_kde(
296
295
  var1,
297
296
  var2,
@@ -371,8 +371,6 @@ def plot_ppc(
371
371
  if legend:
372
372
  if i == 0:
373
373
  ax_i.legend(fontsize=xt_labelsize * 0.75)
374
- else:
375
- ax_i.legend([])
376
374
 
377
375
  if backend_show(show):
378
376
  plt.show()
@@ -414,15 +412,20 @@ def _set_animation(
414
412
 
415
413
  else:
416
414
  vals = pp_sampled_vals[0]
417
- _, y_vals, x_vals = histogram(vals, bins="auto")
415
+ bins = get_bins(vals)
416
+ _, y_vals, x_vals = histogram(vals, bins=bins)
418
417
  (line,) = ax.plot(x_vals[:-1], y_vals, **plot_kwargs)
419
418
 
420
- max_max = max(max(histogram(pp_sampled_vals[i], bins="auto")[1]) for i in range(length))
419
+ max_max = max(
420
+ max(histogram(pp_sampled_vals[i], bins=get_bins(pp_sampled_vals[i]))[1])
421
+ for i in range(length)
422
+ )
421
423
 
422
424
  ax.set_ylim(0, max_max)
423
425
 
424
426
  def animate(i):
425
- _, y_vals, x_vals = histogram(pp_sampled_vals[i], bins="auto")
427
+ pp_vals = pp_sampled_vals[i]
428
+ _, y_vals, x_vals = histogram(pp_vals, bins=get_bins(pp_vals))
426
429
  line.set_data(x_vals[:-1], y_vals)
427
430
  return (line,)
428
431
 
@@ -430,7 +430,7 @@ def plot_trace(
430
430
  Line2D(
431
431
  [], [], label=chain_id, **dealiase_sel_kwargs(legend_kwargs, chain_prop, chain_id)
432
432
  )
433
- for chain_id in range(data.dims["chain"])
433
+ for chain_id in range(data.sizes["chain"])
434
434
  ]
435
435
  if combined:
436
436
  handles.insert(
@@ -470,7 +470,6 @@ def _plot_chains_mpl(
470
470
  circ_var_units,
471
471
  circ_units_trace,
472
472
  ):
473
-
474
473
  if not circular:
475
474
  circ_var_units = False
476
475
 
arviz/plots/bfplot.py CHANGED
@@ -38,7 +38,7 @@ def plot_bf(
38
38
  algorithm presented in [1]_.
39
39
 
40
40
  Parameters
41
- -----------
41
+ ----------
42
42
  idata : InferenceData
43
43
  Any object that can be converted to an :class:`arviz.InferenceData` object
44
44
  Refer to documentation of :func:`arviz.convert_to_dataset` for details.
@@ -52,16 +52,16 @@ def plot_bf(
52
52
  Tuple of valid Matplotlib colors. First element for the prior, second for the posterior.
53
53
  figsize : (float, float), optional
54
54
  Figure size. If `None` it will be defined automatically.
55
- textsize: float, optional
55
+ textsize : float, optional
56
56
  Text size scaling factor for labels, titles and lines. If `None` it will be auto
57
57
  scaled based on `figsize`.
58
- plot_kwargs : dicts, optional
58
+ plot_kwargs : dict, optional
59
59
  Additional keywords passed to :func:`matplotlib.pyplot.plot`.
60
- hist_kwargs : dicts, optional
60
+ hist_kwargs : dict, optional
61
61
  Additional keywords passed to :func:`arviz.plot_dist`. Only works for discrete variables.
62
62
  ax : axes, optional
63
63
  :class:`matplotlib.axes.Axes` or :class:`bokeh.plotting.Figure`.
64
- backend :{"matplotlib", "bokeh"}, default "matplotlib"
64
+ backend : {"matplotlib", "bokeh"}, default "matplotlib"
65
65
  Select plotting backend.
66
66
  backend_kwargs : dict, optional
67
67
  These are kwargs specific to the backend being used, passed to
@@ -78,7 +78,7 @@ def plot_bf(
78
78
  References
79
79
  ----------
80
80
  .. [1] Heck, D., 2019. A caveat on the Savage-Dickey density ratio:
81
- The case of computing Bayes factors for regression parameters.
81
+ The case of computing Bayes factors for regression parameters.
82
82
 
83
83
  Examples
84
84
  --------
@@ -92,6 +92,7 @@ def plot_bf(
92
92
  >>> idata = az.from_dict(posterior={"a":np.random.normal(1, 0.5, 5000)},
93
93
  ... prior={"a":np.random.normal(0, 1, 5000)})
94
94
  >>> az.plot_bf(idata, var_name="a", ref_val=0)
95
+
95
96
  """
96
97
  posterior = extract(idata, var_names=var_name).values
97
98
 
arviz/plots/bpvplot.py CHANGED
@@ -162,6 +162,11 @@ def plot_bpv(
162
162
  ----------
163
163
  * Gelman et al. (2013) see http://www.stat.columbia.edu/~gelman/book/ pages 151-153 for details
164
164
 
165
+ Notes
166
+ -----
167
+ Discrete data is smoothed before computing either p-values or u-values using the
168
+ function :func:`~arviz.smooth_data`
169
+
165
170
  Examples
166
171
  --------
167
172
  Plot Bayesian p_values.
@@ -225,11 +230,11 @@ def plot_bpv(
225
230
 
226
231
  if flatten_pp is None:
227
232
  if flatten is None:
228
- flatten_pp = list(predictive_dataset.dims.keys())
233
+ flatten_pp = list(predictive_dataset.dims)
229
234
  else:
230
235
  flatten_pp = flatten
231
236
  if flatten is None:
232
- flatten = list(observed.dims.keys())
237
+ flatten = list(observed.dims)
233
238
 
234
239
  if coords is None:
235
240
  coords = {}
@@ -90,10 +90,10 @@ def plot_compare(
90
90
  References
91
91
  ----------
92
92
  .. [1] Vehtari et al. (2016). Practical Bayesian model evaluation using leave-one-out
93
- cross-validation and WAIC https://arxiv.org/abs/1507.04544
93
+ cross-validation and WAIC https://arxiv.org/abs/1507.04544
94
94
 
95
95
  .. [2] McElreath R. (2022). Statistical Rethinking A Bayesian Course with Examples in
96
- R and Stan, Second edition, CRC Press.
96
+ R and Stan, Second edition, CRC Press.
97
97
 
98
98
  Examples
99
99
  --------
arviz/plots/ecdfplot.py CHANGED
@@ -1,8 +1,9 @@
1
1
  """Plot ecdf or ecdf-difference plot with confidence bands."""
2
2
  import numpy as np
3
- from scipy.stats import uniform, binom
3
+ from scipy.stats import uniform
4
4
 
5
5
  from ..rcparams import rcParams
6
+ from ..stats.ecdf_utils import compute_ecdf, ecdf_confidence_band, _get_ecdf_points
6
7
  from .plot_utils import get_plotting_function
7
8
 
8
9
 
@@ -26,7 +27,7 @@ def plot_ecdf(
26
27
  show=None,
27
28
  backend=None,
28
29
  backend_kwargs=None,
29
- **kwargs
30
+ **kwargs,
30
31
  ):
31
32
  r"""Plot ECDF or ECDF-Difference Plot with Confidence bands.
32
33
 
@@ -48,6 +49,7 @@ def plot_ecdf(
48
49
  Values to compare to the original sample.
49
50
  cdf : callable, optional
50
51
  Cumulative distribution function of the distribution to compare the original sample.
52
+ The function must take as input a numpy array of draws from the distribution.
51
53
  difference : bool, default False
52
54
  If True then plot ECDF-difference plot otherwise ECDF plot.
53
55
  pit : bool, default False
@@ -180,75 +182,47 @@ def plot_ecdf(
180
182
  values = np.ravel(values)
181
183
  values.sort()
182
184
 
183
- ## This block computes gamma and uses it to get the upper and lower confidence bands
184
- ## Here we check if we want confidence bands or not
185
- if confidence_bands:
186
- ## If plotting PIT then we find the PIT values of sample.
187
- ## Basically here we generate the evaluation points(x) and find the PIT values.
188
- ## z is the evaluation point for our uniform distribution in compute_gamma()
189
- if pit:
190
- x = np.linspace(1 / npoints, 1, npoints)
191
- z = x
192
- ## Finding PIT for our sample
193
- probs = cdf(values) if cdf else compute_ecdf(values2, values) / len(values2)
194
- else:
195
- ## If not PIT use sample for plots and for evaluation points(x) use equally spaced
196
- ## points between minimum and maximum of sample
197
- ## For z we have used cdf(x)
198
- x = np.linspace(values[0], values[-1], npoints)
199
- z = cdf(x) if cdf else compute_ecdf(values2, x)
200
- probs = values
201
-
202
- n = len(values) # number of samples
203
- ## Computing gamma
204
- gamma = fpr if pointwise else compute_gamma(n, z, npoints, num_trials, fpr)
205
- ## Using gamma to get the confidence intervals
206
- lower, higher = get_lims(gamma, n, z)
207
-
208
- ## This block is for whether to plot ECDF or ECDF-difference
209
- if not difference:
210
- ## We store the coordinates of our ecdf in x_coord, y_coord
211
- x_coord, y_coord = get_ecdf_points(x, probs, difference)
185
+ if pit:
186
+ eval_points = np.linspace(1 / npoints, 1, npoints)
187
+ if cdf:
188
+ sample = cdf(values)
212
189
  else:
213
- ## Here we subtract the ecdf value as here we are plotting the ECDF-difference
214
- x_coord, y_coord = get_ecdf_points(x, probs, difference)
215
- for i, x_i in enumerate(x):
216
- y_coord[i] = y_coord[i] - (
217
- x_i if pit else cdf(x_i) if cdf else compute_ecdf(values2, x_i)
218
- )
219
-
220
- ## Similarly we subtract from the upper and lower bounds
221
- if pit:
222
- lower = lower - x
223
- higher = higher - x
224
- else:
225
- lower = lower - (cdf(x) if cdf else compute_ecdf(values2, x))
226
- higher = higher - (cdf(x) if cdf else compute_ecdf(values2, x))
227
-
190
+ sample = compute_ecdf(values2, values) / len(values2)
191
+ cdf_at_eval_points = eval_points
192
+ rvs = uniform(0, 1).rvs
228
193
  else:
229
- if pit:
230
- x = np.linspace(1 / npoints, 1, npoints)
231
- probs = cdf(values)
194
+ eval_points = np.linspace(values[0], values[-1], npoints)
195
+ sample = values
196
+ if confidence_bands or difference:
197
+ if cdf:
198
+ cdf_at_eval_points = cdf(eval_points)
199
+ else:
200
+ cdf_at_eval_points = compute_ecdf(values2, eval_points)
232
201
  else:
233
- x = np.linspace(values[0], values[-1], npoints)
234
- probs = values
202
+ cdf_at_eval_points = np.zeros_like(eval_points)
203
+ rvs = None
204
+
205
+ x_coord, y_coord = _get_ecdf_points(sample, eval_points, difference)
235
206
 
207
+ if difference:
208
+ y_coord -= cdf_at_eval_points
209
+
210
+ if confidence_bands:
211
+ ndraws = len(values)
212
+ band_kwargs = {"prob": 1 - fpr, "num_trials": num_trials, "rvs": rvs, "random_state": None}
213
+ band_kwargs["method"] = "pointwise" if pointwise else "simulated"
214
+ lower, higher = ecdf_confidence_band(ndraws, eval_points, cdf_at_eval_points, **band_kwargs)
215
+
216
+ if difference:
217
+ lower -= cdf_at_eval_points
218
+ higher -= cdf_at_eval_points
219
+ else:
236
220
  lower, higher = None, None
237
- ## This block is for whether to plot ECDF or ECDF-difference
238
- if not difference:
239
- x_coord, y_coord = get_ecdf_points(x, probs, difference)
240
- else:
241
- ## Here we subtract the ecdf value as here we are plotting the ECDF-difference
242
- x_coord, y_coord = get_ecdf_points(x, probs, difference)
243
- for i, x_i in enumerate(x):
244
- y_coord[i] = y_coord[i] - (
245
- x_i if pit else cdf(x_i) if cdf else compute_ecdf(values2, x_i)
246
- )
247
221
 
248
222
  ecdf_plot_args = dict(
249
223
  x_coord=x_coord,
250
224
  y_coord=y_coord,
251
- x_bands=x,
225
+ x_bands=eval_points,
252
226
  lower=lower,
253
227
  higher=higher,
254
228
  confidence_bands=confidence_bands,
@@ -260,7 +234,7 @@ def plot_ecdf(
260
234
  ax=ax,
261
235
  show=show,
262
236
  backend_kwargs=backend_kwargs,
263
- **kwargs
237
+ **kwargs,
264
238
  )
265
239
 
266
240
  if backend is None:
@@ -271,52 +245,3 @@ def plot_ecdf(
271
245
  ax = plot(**ecdf_plot_args)
272
246
 
273
247
  return ax
274
-
275
-
276
- def compute_ecdf(sample, z):
277
- """Compute ECDF.
278
-
279
- This function computes the ecdf value at the evaluation point
280
- or a sorted set of evaluation points.
281
- """
282
- return np.searchsorted(sample, z, side="right") / len(sample)
283
-
284
-
285
- def get_ecdf_points(x, probs, difference):
286
- """Compute the coordinates for the ecdf points using compute_ecdf."""
287
- y = compute_ecdf(probs, x)
288
-
289
- if not difference:
290
- x = np.insert(x, 0, x[0])
291
- y = np.insert(y, 0, 0)
292
- return x, y
293
-
294
-
295
- def compute_gamma(n, z, npoints=None, num_trials=1000, fpr=0.05):
296
- """Compute gamma for confidence interval calculation.
297
-
298
- This function simulates an adjusted value of gamma to account for multiplicity
299
- when forming an 1-fpr level confidence envelope for the ECDF of a sample.
300
- """
301
- if npoints is None:
302
- npoints = n
303
- gamma = []
304
- for _ in range(num_trials):
305
- unif_samples = uniform.rvs(0, 1, n)
306
- unif_samples = np.sort(unif_samples)
307
- gamma_m = 1000
308
- ## Can compute ecdf for all the z together or one at a time.
309
- f_z = compute_ecdf(unif_samples, z)
310
- f_z = compute_ecdf(unif_samples, z)
311
- gamma_m = 2 * min(
312
- np.amin(binom.cdf(n * f_z, n, z)), np.amin(1 - binom.cdf(n * f_z - 1, n, z))
313
- )
314
- gamma.append(gamma_m)
315
- return np.quantile(gamma, fpr)
316
-
317
-
318
- def get_lims(gamma, n, z):
319
- """Compute the simultaneous 1 - fpr level confidence bands."""
320
- lower = binom.ppf(gamma / 2, n, z)
321
- upper = binom.ppf(1 - gamma / 2, n, z)
322
- return lower / n, upper / n
arviz/plots/elpdplot.py CHANGED
@@ -98,7 +98,7 @@ def plot_elpd(
98
98
  References
99
99
  ----------
100
100
  .. [1] Vehtari et al. (2016). Practical Bayesian model evaluation using leave-one-out
101
- cross-validation and WAIC https://arxiv.org/abs/1507.04544
101
+ cross-validation and WAIC https://arxiv.org/abs/1507.04544
102
102
 
103
103
  Examples
104
104
  --------
arviz/plots/essplot.py CHANGED
@@ -202,8 +202,8 @@ def plot_ess(
202
202
 
203
203
  data = get_coords(convert_to_dataset(idata, group="posterior"), coords)
204
204
  var_names = _var_names(var_names, data, filter_vars)
205
- n_draws = data.dims["draw"]
206
- n_samples = n_draws * data.dims["chain"]
205
+ n_draws = data.sizes["draw"]
206
+ n_samples = n_draws * data.sizes["chain"]
207
207
 
208
208
  ess_tail_dataset = None
209
209
  mean_ess = None