arviz 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. arviz/__init__.py +1 -1
  2. arviz/data/base.py +20 -9
  3. arviz/data/converters.py +7 -3
  4. arviz/data/inference_data.py +28 -7
  5. arviz/plots/backends/__init__.py +8 -7
  6. arviz/plots/backends/bokeh/bpvplot.py +4 -3
  7. arviz/plots/backends/bokeh/densityplot.py +5 -1
  8. arviz/plots/backends/bokeh/dotplot.py +5 -2
  9. arviz/plots/backends/bokeh/essplot.py +4 -2
  10. arviz/plots/backends/bokeh/forestplot.py +11 -4
  11. arviz/plots/backends/bokeh/khatplot.py +4 -2
  12. arviz/plots/backends/bokeh/lmplot.py +9 -3
  13. arviz/plots/backends/bokeh/mcseplot.py +2 -2
  14. arviz/plots/backends/bokeh/pairplot.py +10 -5
  15. arviz/plots/backends/bokeh/ppcplot.py +2 -1
  16. arviz/plots/backends/bokeh/rankplot.py +2 -1
  17. arviz/plots/backends/bokeh/traceplot.py +2 -1
  18. arviz/plots/backends/bokeh/violinplot.py +2 -1
  19. arviz/plots/backends/matplotlib/bpvplot.py +2 -1
  20. arviz/plots/bfplot.py +9 -26
  21. arviz/plots/bpvplot.py +10 -1
  22. arviz/plots/compareplot.py +4 -4
  23. arviz/plots/ecdfplot.py +16 -8
  24. arviz/plots/forestplot.py +2 -2
  25. arviz/plots/hdiplot.py +5 -0
  26. arviz/plots/kdeplot.py +9 -2
  27. arviz/plots/plot_utils.py +5 -3
  28. arviz/preview.py +36 -5
  29. arviz/stats/__init__.py +1 -0
  30. arviz/stats/diagnostics.py +18 -14
  31. arviz/stats/ecdf_utils.py +157 -2
  32. arviz/stats/stats.py +99 -7
  33. arviz/tests/base_tests/test_data.py +41 -7
  34. arviz/tests/base_tests/test_diagnostics.py +5 -4
  35. arviz/tests/base_tests/test_plots_matplotlib.py +32 -13
  36. arviz/tests/base_tests/test_stats.py +11 -0
  37. arviz/tests/base_tests/test_stats_ecdf_utils.py +15 -2
  38. arviz/utils.py +4 -0
  39. {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/METADATA +22 -22
  40. {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/RECORD +43 -43
  41. {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/WHEEL +1 -1
  42. {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/LICENSE +0 -0
  43. {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/top_level.txt +0 -0
arviz/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # pylint: disable=wildcard-import,invalid-name,wrong-import-position
2
2
  """ArviZ is a library for exploratory analysis of Bayesian models."""
3
- __version__ = "0.19.0"
3
+ __version__ = "0.21.0"
4
4
 
5
5
  import logging
6
6
  import os
arviz/data/base.py CHANGED
@@ -9,9 +9,13 @@ from copy import deepcopy
9
9
  from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
10
10
 
11
11
  import numpy as np
12
- import tree
13
12
  import xarray as xr
14
13
 
14
+ try:
15
+ import tree
16
+ except ImportError:
17
+ tree = None
18
+
15
19
  try:
16
20
  import ujson as json
17
21
  except ImportError:
@@ -89,6 +93,9 @@ def _yield_flat_up_to(shallow_tree, input_tree, path=()):
89
93
  input_tree.
90
94
  """
91
95
  # pylint: disable=protected-access
96
+ if tree is None:
97
+ raise ImportError("Missing optional dependency 'dm-tree'. Use pip or conda to install it")
98
+
92
99
  if isinstance(shallow_tree, tree._TEXT_OR_BYTES) or not (
93
100
  isinstance(shallow_tree, tree.collections_abc.Mapping)
94
101
  or tree._is_namedtuple(shallow_tree)
@@ -194,10 +201,10 @@ def generate_dims_coords(
194
201
  for i, dim_len in enumerate(shape):
195
202
  idx = i + len([dim for dim in default_dims if dim in dims])
196
203
  if len(dims) < idx + 1:
197
- dim_name = f"{var_name}_dim_{idx}"
204
+ dim_name = f"{var_name}_dim_{i}"
198
205
  dims.append(dim_name)
199
206
  elif dims[idx] is None:
200
- dim_name = f"{var_name}_dim_{idx}"
207
+ dim_name = f"{var_name}_dim_{i}"
201
208
  dims[idx] = dim_name
202
209
  dim_name = dims[idx]
203
210
  if dim_name not in coords:
@@ -299,7 +306,7 @@ def numpy_to_data_array(
299
306
  return xr.DataArray(ary, coords=coords, dims=dims)
300
307
 
301
308
 
302
- def pytree_to_dataset(
309
+ def dict_to_dataset(
303
310
  data,
304
311
  *,
305
312
  attrs=None,
@@ -312,6 +319,8 @@ def pytree_to_dataset(
312
319
  ):
313
320
  """Convert a dictionary or pytree of numpy arrays to an xarray.Dataset.
314
321
 
322
+ ArviZ itself supports conversion of flat dictionaries.
323
+ Support for pytrees requires ``dm-tree`` which is an optional dependency.
315
324
  See https://jax.readthedocs.io/en/latest/pytrees.html for what a pytree is, but
316
325
  this includes at least dictionaries and tuple types.
317
326
 
@@ -386,10 +395,12 @@ def pytree_to_dataset(
386
395
  """
387
396
  if dims is None:
388
397
  dims = {}
389
- try:
390
- data = {k[0] if len(k) == 1 else k: v for k, v in _flatten_with_path(data)}
391
- except TypeError: # probably unsortable keys -- the function will still work if
392
- pass # it is an honest dictionary.
398
+
399
+ if tree is not None:
400
+ try:
401
+ data = {k[0] if len(k) == 1 else k: v for k, v in _flatten_with_path(data)}
402
+ except TypeError: # probably unsortable keys -- the function will still work if
403
+ pass # it is an honest dictionary.
393
404
 
394
405
  data_vars = {
395
406
  key: numpy_to_data_array(
@@ -406,7 +417,7 @@ def pytree_to_dataset(
406
417
  return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))
407
418
 
408
419
 
409
- dict_to_dataset = pytree_to_dataset
420
+ pytree_to_dataset = dict_to_dataset
410
421
 
411
422
 
412
423
  def make_attrs(attrs=None, library=None):
arviz/data/converters.py CHANGED
@@ -1,9 +1,13 @@
1
1
  """High level conversion functions."""
2
2
 
3
3
  import numpy as np
4
- import tree
5
4
  import xarray as xr
6
5
 
6
+ try:
7
+ from tree import is_nested
8
+ except ImportError:
9
+ is_nested = lambda obj: False
10
+
7
11
  from .base import dict_to_dataset
8
12
  from .inference_data import InferenceData
9
13
  from .io_beanmachine import from_beanmachine
@@ -107,7 +111,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
107
111
  dataset = obj.to_dataset()
108
112
  elif isinstance(obj, dict):
109
113
  dataset = dict_to_dataset(obj, coords=coords, dims=dims)
110
- elif tree.is_nested(obj) and not isinstance(obj, (list, tuple)):
114
+ elif is_nested(obj) and not isinstance(obj, (list, tuple)):
111
115
  dataset = dict_to_dataset(obj, coords=coords, dims=dims)
112
116
  elif isinstance(obj, np.ndarray):
113
117
  dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
@@ -122,7 +126,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
122
126
  "xarray dataarray",
123
127
  "xarray dataset",
124
128
  "dict",
125
- "pytree",
129
+ "pytree (if 'dm-tree' is installed)",
126
130
  "netcdf filename",
127
131
  "numpy array",
128
132
  "pystan fit",
@@ -102,6 +102,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
102
102
  def __init__(
103
103
  self,
104
104
  attrs: Union[None, Mapping[Any, Any]] = None,
105
+ warn_on_custom_groups: bool = False,
105
106
  **kwargs: Union[xr.Dataset, List[xr.Dataset], Tuple[xr.Dataset, xr.Dataset]],
106
107
  ) -> None:
107
108
  """Initialize InferenceData object from keyword xarray datasets.
@@ -110,6 +111,9 @@ class InferenceData(Mapping[str, xr.Dataset]):
110
111
  ----------
111
112
  attrs : dict
112
113
  sets global attribute for InferenceData object.
114
+ warn_on_custom_groups : bool, default False
115
+ Emit a warning when custom groups are present in the InferenceData.
116
+ "custom group" means any group whose name isn't defined in :ref:`schema`
113
117
  kwargs :
114
118
  Keyword arguments of xarray datasets
115
119
 
@@ -154,9 +158,10 @@ class InferenceData(Mapping[str, xr.Dataset]):
154
158
  for key in kwargs:
155
159
  if key not in SUPPORTED_GROUPS_ALL:
156
160
  key_list.append(key)
157
- warnings.warn(
158
- f"{key} group is not defined in the InferenceData scheme", UserWarning
159
- )
161
+ if warn_on_custom_groups:
162
+ warnings.warn(
163
+ f"{key} group is not defined in the InferenceData scheme", UserWarning
164
+ )
160
165
  for key in key_list:
161
166
  dataset = kwargs[key]
162
167
  dataset_warmup = None
@@ -266,6 +271,14 @@ class InferenceData(Mapping[str, xr.Dataset]):
266
271
  raise KeyError(key)
267
272
  return getattr(self, key)
268
273
 
274
+ def __setitem__(self, key: str, value: xr.Dataset):
275
+ """Set item by key and update group list accordingly."""
276
+ if key.startswith(WARMUP_TAG):
277
+ self._groups_warmup.append(key)
278
+ else:
279
+ self._groups.append(key)
280
+ setattr(self, key, value)
281
+
269
282
  def groups(self) -> List[str]:
270
283
  """Return all groups present in InferenceData object."""
271
284
  return self._groups_all
@@ -1459,7 +1472,9 @@ class InferenceData(Mapping[str, xr.Dataset]):
1459
1472
  else:
1460
1473
  return out
1461
1474
 
1462
- def add_groups(self, group_dict=None, coords=None, dims=None, **kwargs):
1475
+ def add_groups(
1476
+ self, group_dict=None, coords=None, dims=None, warn_on_custom_groups=False, **kwargs
1477
+ ):
1463
1478
  """Add new groups to InferenceData object.
1464
1479
 
1465
1480
  Parameters
@@ -1471,6 +1486,9 @@ class InferenceData(Mapping[str, xr.Dataset]):
1471
1486
  dims : dict of {str : list of str}, optional
1472
1487
  Dimensions of each variable. The keys are variable names, values are lists of
1473
1488
  coordinates.
1489
+ warn_on_custom_groups : bool, default False
1490
+ Emit a warning when custom groups are present in the InferenceData.
1491
+ "custom group" means any group whose name isn't defined in :ref:`schema`
1474
1492
  kwargs : dict, optional
1475
1493
  The keyword arguments form of group_dict. One of group_dict or kwargs must be provided.
1476
1494
 
@@ -1534,7 +1552,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
1534
1552
  if repeated_groups:
1535
1553
  raise ValueError(f"{repeated_groups} group(s) already exists.")
1536
1554
  for group, dataset in group_dict.items():
1537
- if group not in SUPPORTED_GROUPS_ALL:
1555
+ if warn_on_custom_groups and group not in SUPPORTED_GROUPS_ALL:
1538
1556
  warnings.warn(
1539
1557
  f"The group {group} is not defined in the InferenceData scheme",
1540
1558
  UserWarning,
@@ -1589,7 +1607,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
1589
1607
  else:
1590
1608
  self._groups.append(group)
1591
1609
 
1592
- def extend(self, other, join="left"):
1610
+ def extend(self, other, join="left", warn_on_custom_groups=False):
1593
1611
  """Extend InferenceData with groups from another InferenceData.
1594
1612
 
1595
1613
  Parameters
@@ -1600,6 +1618,9 @@ class InferenceData(Mapping[str, xr.Dataset]):
1600
1618
  Defines how the two decide which group to keep when the same group is
1601
1619
  present in both objects. 'left' will discard the group in ``other`` whereas 'right'
1602
1620
  will keep the group in ``other`` and discard the one in ``self``.
1621
+ warn_on_custom_groups : bool, default False
1622
+ Emit a warning when custom groups are present in the InferenceData.
1623
+ "custom group" means any group whose name isn't defined in :ref:`schema`
1603
1624
 
1604
1625
  Examples
1605
1626
  --------
@@ -1643,7 +1664,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
1643
1664
  for group in other._groups_all: # pylint: disable=protected-access
1644
1665
  if hasattr(self, group) and join == "left":
1645
1666
  continue
1646
- if group not in SUPPORTED_GROUPS_ALL:
1667
+ if warn_on_custom_groups and group not in SUPPORTED_GROUPS_ALL:
1647
1668
  warnings.warn(
1648
1669
  f"{group} group is not defined in the InferenceData scheme", UserWarning
1649
1670
  )
@@ -213,10 +213,11 @@ def _copy_docstring(lib, function):
213
213
 
214
214
 
215
215
  # TODO: try copying substitutions too, or autoreplace them ourselves
216
- output_notebook.__doc__ += "\n\n" + _copy_docstring("bokeh.plotting", "output_notebook").replace(
217
- "|save|", "save"
218
- ).replace("|show|", "show")
219
- output_file.__doc__ += "\n\n" + _copy_docstring("bokeh.plotting", "output_file").replace(
220
- "|save|", "save"
221
- ).replace("|show|", "show")
222
- ColumnDataSource.__doc__ += "\n\n" + _copy_docstring("bokeh.models", "ColumnDataSource")
216
+ if output_notebook.__doc__ is not None: # if run with python -OO, __doc__ is stripped
217
+ output_notebook.__doc__ += "\n\n" + _copy_docstring(
218
+ "bokeh.plotting", "output_notebook"
219
+ ).replace("|save|", "save").replace("|show|", "show")
220
+ output_file.__doc__ += "\n\n" + _copy_docstring("bokeh.plotting", "output_file").replace(
221
+ "|save|", "save"
222
+ ).replace("|show|", "show")
223
+ ColumnDataSource.__doc__ += "\n\n" + _copy_docstring("bokeh.models", "ColumnDataSource")
@@ -41,6 +41,7 @@ def plot_bpv(
41
41
  plot_ref_kwargs,
42
42
  backend_kwargs,
43
43
  show,
44
+ smoothing,
44
45
  ):
45
46
  """Bokeh bpv plot."""
46
47
  if backend_kwargs is None:
@@ -90,6 +91,9 @@ def plot_bpv(
90
91
  obs_vals = obs_vals.flatten()
91
92
  pp_vals = pp_vals.reshape(total_pp_samples, -1)
92
93
 
94
+ if (obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i") and smoothing is True:
95
+ obs_vals, pp_vals = smooth_data(obs_vals, pp_vals)
96
+
93
97
  if kind == "p_value":
94
98
  tstat_pit = np.mean(pp_vals <= obs_vals, axis=-1)
95
99
  x_s, tstat_pit_dens = kde(tstat_pit)
@@ -115,9 +119,6 @@ def plot_bpv(
115
119
  )
116
120
 
117
121
  elif kind == "u_value":
118
- if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
119
- obs_vals, pp_vals = smooth_data(obs_vals, pp_vals)
120
-
121
122
  tstat_pit = np.mean(pp_vals <= obs_vals, axis=0)
122
123
  x_s, tstat_pit_dens = kde(tstat_pit)
123
124
  ax_i.line(x_s, tstat_pit_dens, color=color)
@@ -225,7 +225,11 @@ def _d_helper(
225
225
 
226
226
  if point_estimate is not None:
227
227
  est = calculate_point_estimate(point_estimate, vec, bw, circular)
228
- plotted.append(ax.circle(est, 0, fill_color=color, line_color="black", size=markersize))
228
+ plotted.append(
229
+ ax.scatter(
230
+ est, 0, marker="circle", fill_color=color, line_color="black", size=markersize
231
+ )
232
+ )
229
233
 
230
234
  _title = Title()
231
235
  _title.text = vname
@@ -47,7 +47,10 @@ def plot_dot(
47
47
 
48
48
  if plot_kwargs is None:
49
49
  plot_kwargs = {}
50
- plot_kwargs.setdefault("color", dotcolor)
50
+ else:
51
+ plot_kwargs = plot_kwargs.copy()
52
+ plot_kwargs.setdefault("color", dotcolor)
53
+ plot_kwargs.setdefault("marker", "circle")
51
54
 
52
55
  if linewidth is None:
53
56
  linewidth = auto_linewidth
@@ -95,7 +98,7 @@ def plot_dot(
95
98
  stack_locs, stack_count = wilkinson_algorithm(values, binwidth)
96
99
  x, y = layout_stacks(stack_locs, stack_count, binwidth, stackratio, rotated)
97
100
 
98
- ax.circle(x, y, radius=dotsize * (binwidth / 2), **plot_kwargs, radius_dimension="y")
101
+ ax.scatter(x, y, radius=dotsize * (binwidth / 2), **plot_kwargs, radius_dimension="y")
99
102
  if rotated:
100
103
  ax.xaxis.major_tick_line_color = None
101
104
  ax.xaxis.minor_tick_line_color = None
@@ -73,12 +73,14 @@ def plot_ess(
73
73
  for (var_name, selection, isel, x), ax_ in zip(
74
74
  plotters, (item for item in ax.flatten() if item is not None)
75
75
  ):
76
- bulk_points = ax_.circle(np.asarray(xdata), np.asarray(x), size=6)
76
+ bulk_points = ax_.scatter(np.asarray(xdata), np.asarray(x), marker="circle", size=6)
77
77
  if kind == "evolution":
78
78
  bulk_line = ax_.line(np.asarray(xdata), np.asarray(x))
79
79
  ess_tail = ess_tail_dataset[var_name].sel(**selection)
80
80
  tail_points = ax_.line(np.asarray(xdata), np.asarray(ess_tail), color="orange")
81
- tail_line = ax_.circle(np.asarray(xdata), np.asarray(ess_tail), size=6, color="orange")
81
+ tail_line = ax_.scatter(
82
+ np.asarray(xdata), np.asarray(ess_tail), marker="circle", size=6, color="orange"
83
+ )
82
84
  elif rug:
83
85
  if rug_kwargs is None:
84
86
  rug_kwargs = {}
@@ -535,9 +535,10 @@ class PlotHandler:
535
535
  )
536
536
  )
537
537
  plotted[model_name].append(
538
- ax.circle(
538
+ ax.scatter(
539
539
  x=values[mid],
540
540
  y=y,
541
+ marker="circle",
541
542
  size=markersize * 0.75,
542
543
  fill_color=color,
543
544
  )
@@ -555,9 +556,10 @@ class PlotHandler:
555
556
  for y, ess, color, model_name in plotter.ess():
556
557
  if ess is not None:
557
558
  plotted[model_name].append(
558
- ax.circle(
559
+ ax.scatter(
559
560
  x=ess,
560
561
  y=y,
562
+ marker="circle",
561
563
  fill_color=color,
562
564
  size=markersize,
563
565
  line_color="black",
@@ -582,8 +584,13 @@ class PlotHandler:
582
584
  for y, r_hat, color, model_name in plotter.r_hat():
583
585
  if r_hat is not None:
584
586
  plotted[model_name].append(
585
- ax.circle(
586
- x=r_hat, y=y, fill_color=color, size=markersize, line_color="black"
587
+ ax.scatter(
588
+ x=r_hat,
589
+ y=y,
590
+ marker="circle",
591
+ fill_color=color,
592
+ size=markersize,
593
+ line_color="black",
587
594
  )
588
595
  )
589
596
  ax.x_range._property_values["start"] = 0.9 # pylint: disable=protected-access
@@ -94,9 +94,10 @@ def plot_khat(
94
94
 
95
95
  if not isinstance(rgba_c, str) and isinstance(rgba_c, Iterable):
96
96
  for idx, (alpha, rgba_c_) in enumerate(zip(alphas, rgba_c)):
97
- ax.cross(
97
+ ax.scatter(
98
98
  xdata[idx],
99
99
  khats[idx],
100
+ marker="cross",
100
101
  line_color=rgba_c_,
101
102
  fill_color=rgba_c_,
102
103
  line_alpha=alpha,
@@ -104,9 +105,10 @@ def plot_khat(
104
105
  size=10,
105
106
  )
106
107
  else:
107
- ax.cross(
108
+ ax.scatter(
108
109
  xdata,
109
110
  khats,
111
+ marker="cross",
110
112
  line_color=rgba_c,
111
113
  fill_color=rgba_c,
112
114
  size=10,
@@ -51,12 +51,18 @@ def plot_lm(
51
51
 
52
52
  if y_kwargs is None:
53
53
  y_kwargs = {}
54
+ else:
55
+ y_kwargs = y_kwargs.copy()
56
+ y_kwargs.setdefault("marker", "circle")
54
57
  y_kwargs.setdefault("fill_color", "red")
55
58
  y_kwargs.setdefault("line_width", 0)
56
59
  y_kwargs.setdefault("size", 3)
57
60
 
58
61
  if y_hat_plot_kwargs is None:
59
62
  y_hat_plot_kwargs = {}
63
+ else:
64
+ y_hat_plot_kwargs = y_hat_plot_kwargs.copy()
65
+ y_hat_plot_kwargs.setdefault("marker", "circle")
60
66
  y_hat_plot_kwargs.setdefault("fill_color", "orange")
61
67
  y_hat_plot_kwargs.setdefault("line_width", 0)
62
68
 
@@ -84,7 +90,7 @@ def plot_lm(
84
90
  _, _, _, y_plotters = y[i]
85
91
  _, _, _, x_plotters = x[i]
86
92
  legend_it = []
87
- observed_legend = ax_i.circle(x_plotters, y_plotters, **y_kwargs)
93
+ observed_legend = ax_i.scatter(x_plotters, y_plotters, **y_kwargs)
88
94
  legend_it.append(("Observed", [observed_legend]))
89
95
 
90
96
  if y_hat is not None:
@@ -98,14 +104,14 @@ def plot_lm(
98
104
  x_plotters_jitter = x_plotters + np.random.uniform(
99
105
  low=-scale_high, high=scale_high, size=len(x_plotters)
100
106
  )
101
- posterior_circle = ax_i.circle(
107
+ posterior_circle = ax_i.scatter(
102
108
  x_plotters_jitter,
103
109
  y_hat_plotters[..., j],
104
110
  alpha=0.2,
105
111
  **y_hat_plot_kwargs,
106
112
  )
107
113
  else:
108
- posterior_circle = ax_i.circle(
114
+ posterior_circle = ax_i.scatter(
109
115
  x_plotters, y_hat_plotters[..., j], alpha=0.2, **y_hat_plot_kwargs
110
116
  )
111
117
  posterior_legend.append(posterior_circle)
@@ -71,13 +71,13 @@ def plot_mcse(
71
71
  values = data[var_name].sel(**selection).values.flatten()
72
72
  if errorbar:
73
73
  quantile_values = _quantile(values, probs)
74
- ax_.dash(probs, quantile_values)
74
+ ax_.scatter(probs, quantile_values, marker="dash")
75
75
  ax_.multi_line(
76
76
  list(zip(probs, probs)),
77
77
  [(quant - err, quant + err) for quant, err in zip(quantile_values, x)],
78
78
  )
79
79
  else:
80
- ax_.circle(probs, x)
80
+ ax_.scatter(probs, x, marker="circle")
81
81
  if extra_methods:
82
82
  mean_mcse_i = mean_mcse[var_name].sel(**selection).values.item()
83
83
  sd_mcse_i = sd_mcse[var_name].sel(**selection).values.item()
@@ -121,12 +121,14 @@ def plot_pair(
121
121
  )
122
122
 
123
123
  reference_values_kwargs = _init_kwargs_dict(reference_values_kwargs)
124
+ reference_values_kwargs.setdefault("marker", "circle")
124
125
  reference_values_kwargs.setdefault("line_color", "black")
125
126
  reference_values_kwargs.setdefault("fill_color", vectorized_to_hex("C2"))
126
127
  reference_values_kwargs.setdefault("line_width", 1)
127
128
  reference_values_kwargs.setdefault("size", 10)
128
129
 
129
130
  divergences_kwargs = _init_kwargs_dict(divergences_kwargs)
131
+ divergences_kwargs.setdefault("marker", "circle")
130
132
  divergences_kwargs.setdefault("line_color", "black")
131
133
  divergences_kwargs.setdefault("fill_color", vectorized_to_hex("C1"))
132
134
  divergences_kwargs.setdefault("line_width", 1)
@@ -155,6 +157,7 @@ def plot_pair(
155
157
  )
156
158
 
157
159
  point_estimate_marker_kwargs = _init_kwargs_dict(point_estimate_marker_kwargs)
160
+ point_estimate_marker_kwargs.setdefault("marker", "square")
158
161
  point_estimate_marker_kwargs.setdefault("size", markersize)
159
162
  point_estimate_marker_kwargs.setdefault("color", "black")
160
163
  point_estimate_kwargs.setdefault("line_color", "black")
@@ -265,9 +268,11 @@ def plot_pair(
265
268
  elif j + marginals_offset > i:
266
269
  if "scatter" in kind:
267
270
  if divergences:
268
- ax[j, i].circle(var1, var2, source=source, view=source_nondiv)
271
+ ax[j, i].scatter(
272
+ var1, var2, marker="circle", source=source, view=source_nondiv
273
+ )
269
274
  else:
270
- ax[j, i].circle(var1, var2, source=source)
275
+ ax[j, i].scatter(var1, var2, marker="circle", source=source)
271
276
 
272
277
  if "kde" in kind:
273
278
  var1_kde = plotters[i][-1].flatten()
@@ -293,7 +298,7 @@ def plot_pair(
293
298
  )
294
299
 
295
300
  if divergences:
296
- ax[j, i].circle(
301
+ ax[j, i].scatter(
297
302
  var1,
298
303
  var2,
299
304
  source=source,
@@ -306,7 +311,7 @@ def plot_pair(
306
311
  var2_pe = plotters[j][-1].flatten()
307
312
  pe_x = calculate_point_estimate(point_estimate, var1_pe)
308
313
  pe_y = calculate_point_estimate(point_estimate, var2_pe)
309
- ax[j, i].square(pe_x, pe_y, **point_estimate_marker_kwargs)
314
+ ax[j, i].scatter(pe_x, pe_y, **point_estimate_marker_kwargs)
310
315
 
311
316
  ax_hline = Span(
312
317
  location=pe_y,
@@ -344,7 +349,7 @@ def plot_pair(
344
349
  x = reference_values_copy[flat_var_names[j + marginals_offset]]
345
350
  y = reference_values_copy[flat_var_names[i]]
346
351
  if x and y:
347
- ax[j, i].circle(y, x, **reference_values_kwargs)
352
+ ax[j, i].scatter(y, x, **reference_values_kwargs)
348
353
  ax[j, i].xaxis.axis_label = flat_var_names[i]
349
354
  ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset]
350
355
 
@@ -313,9 +313,10 @@ def plot_ppc(
313
313
  obs_yvals += np.random.uniform(
314
314
  low=scale_low, high=scale_high, size=len(obs_vals)
315
315
  )
316
- glyph = ax_i.circle(
316
+ glyph = ax_i.scatter(
317
317
  obs_vals,
318
318
  obs_yvals,
319
+ marker="circle",
319
320
  line_color=colors[1],
320
321
  fill_color=colors[1],
321
322
  size=markersize,
@@ -49,6 +49,7 @@ def plot_rank(
49
49
 
50
50
  if marker_vlines_kwargs is None:
51
51
  marker_vlines_kwargs = {}
52
+ marker_vlines_kwargs.setdefault("marker", "circle")
52
53
 
53
54
  if backend_kwargs is None:
54
55
  backend_kwargs = {}
@@ -109,7 +110,7 @@ def plot_rank(
109
110
  elif kind == "vlines":
110
111
  ymin = np.full(len(all_counts), all_counts.mean())
111
112
  for idx, counts in enumerate(all_counts):
112
- ax.circle(
113
+ ax.scatter(
113
114
  bin_ary,
114
115
  counts,
115
116
  fill_color=colors[idx],
@@ -385,9 +385,10 @@ def _plot_chains_bokeh(
385
385
  **dealiase_sel_kwargs(trace_kwargs, chain_prop, chain_idx),
386
386
  )
387
387
  if marker:
388
- ax_trace.circle(
388
+ ax_trace.scatter(
389
389
  x=x_name,
390
390
  y=y_name,
391
+ marker="circle",
391
392
  source=cds,
392
393
  radius=0.30,
393
394
  alpha=0.5,
@@ -81,9 +81,10 @@ def plot_violin(
81
81
  [0, 0], per[:2], line_width=linewidth * 3, line_color="black", line_cap="round"
82
82
  )
83
83
  ax_.line([0, 0], hdi_probs, line_width=linewidth, line_color="black", line_cap="round")
84
- ax_.circle(
84
+ ax_.scatter(
85
85
  0,
86
86
  per[-1],
87
+ marker="circle",
87
88
  line_color="white",
88
89
  fill_color="white",
89
90
  size=linewidth * 1.5,
@@ -38,6 +38,7 @@ def plot_bpv(
38
38
  plot_ref_kwargs,
39
39
  backend_kwargs,
40
40
  show,
41
+ smoothing,
41
42
  ):
42
43
  """Matplotlib bpv plot."""
43
44
  if backend_kwargs is None:
@@ -87,7 +88,7 @@ def plot_bpv(
87
88
  obs_vals = obs_vals.flatten()
88
89
  pp_vals = pp_vals.reshape(total_pp_samples, -1)
89
90
 
90
- if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
91
+ if (obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i") and smoothing is True:
91
92
  obs_vals, pp_vals = smooth_data(obs_vals, pp_vals)
92
93
 
93
94
  if kind == "p_value":
arviz/plots/bfplot.py CHANGED
@@ -2,11 +2,9 @@
2
2
  # pylint: disable=unbalanced-tuple-unpacking
3
3
  import logging
4
4
 
5
- from numpy import interp
6
-
7
5
  from ..data.utils import extract
8
6
  from .plot_utils import get_plotting_function
9
- from ..stats.density_utils import _kde_linear
7
+ from ..stats import bayes_factor
10
8
 
11
9
  _log = logging.getLogger(__name__)
12
10
 
@@ -94,32 +92,17 @@ def plot_bf(
94
92
  >>> az.plot_bf(idata, var_name="a", ref_val=0)
95
93
 
96
94
  """
97
- posterior = extract(idata, var_names=var_name).values
98
-
99
- if ref_val > posterior.max() or ref_val < posterior.min():
100
- _log.warning(
101
- "The reference value is outside of the posterior. "
102
- "This translate into infinite support for H1, which is most likely an overstatement."
103
- )
104
-
105
- if posterior.ndim > 1:
106
- _log.warning("Posterior distribution has {posterior.ndim} dimensions")
107
95
 
108
96
  if prior is None:
109
97
  prior = extract(idata, var_names=var_name, group="prior").values
110
98
 
111
- if posterior.dtype.kind == "f":
112
- posterior_grid, posterior_pdf = _kde_linear(posterior)
113
- prior_grid, prior_pdf = _kde_linear(prior)
114
- posterior_at_ref_val = interp(ref_val, posterior_grid, posterior_pdf)
115
- prior_at_ref_val = interp(ref_val, prior_grid, prior_pdf)
116
-
117
- elif posterior.dtype.kind == "i":
118
- posterior_at_ref_val = (posterior == ref_val).mean()
119
- prior_at_ref_val = (prior == ref_val).mean()
99
+ bf, p_at_ref_val = bayes_factor(
100
+ idata, var_name, prior=prior, ref_val=ref_val, return_ref_vals=True
101
+ )
102
+ bf_10 = bf["BF10"]
103
+ bf_01 = bf["BF01"]
120
104
 
121
- bf_10 = prior_at_ref_val / posterior_at_ref_val
122
- bf_01 = 1 / bf_10
105
+ posterior = extract(idata, var_names=var_name)
123
106
 
124
107
  bfplot_kwargs = dict(
125
108
  ax=ax,
@@ -128,8 +111,8 @@ def plot_bf(
128
111
  prior=prior,
129
112
  posterior=posterior,
130
113
  ref_val=ref_val,
131
- prior_at_ref_val=prior_at_ref_val,
132
- posterior_at_ref_val=posterior_at_ref_val,
114
+ prior_at_ref_val=p_at_ref_val["prior"],
115
+ posterior_at_ref_val=p_at_ref_val["posterior"],
133
116
  var_name=var_name,
134
117
  colors=colors,
135
118
  figsize=figsize,