arviz 0.19.0__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +1 -1
- arviz/data/base.py +20 -9
- arviz/data/converters.py +7 -3
- arviz/data/inference_data.py +28 -7
- arviz/plots/backends/__init__.py +8 -7
- arviz/plots/backends/bokeh/bpvplot.py +4 -3
- arviz/plots/backends/bokeh/densityplot.py +5 -1
- arviz/plots/backends/bokeh/dotplot.py +5 -2
- arviz/plots/backends/bokeh/essplot.py +4 -2
- arviz/plots/backends/bokeh/forestplot.py +11 -4
- arviz/plots/backends/bokeh/khatplot.py +4 -2
- arviz/plots/backends/bokeh/lmplot.py +9 -3
- arviz/plots/backends/bokeh/mcseplot.py +2 -2
- arviz/plots/backends/bokeh/pairplot.py +10 -5
- arviz/plots/backends/bokeh/ppcplot.py +2 -1
- arviz/plots/backends/bokeh/rankplot.py +2 -1
- arviz/plots/backends/bokeh/traceplot.py +2 -1
- arviz/plots/backends/bokeh/violinplot.py +2 -1
- arviz/plots/backends/matplotlib/bpvplot.py +2 -1
- arviz/plots/bfplot.py +9 -26
- arviz/plots/bpvplot.py +10 -1
- arviz/plots/compareplot.py +4 -4
- arviz/plots/ecdfplot.py +16 -8
- arviz/plots/forestplot.py +2 -2
- arviz/plots/hdiplot.py +5 -0
- arviz/plots/kdeplot.py +9 -2
- arviz/plots/plot_utils.py +5 -3
- arviz/preview.py +36 -5
- arviz/stats/__init__.py +1 -0
- arviz/stats/diagnostics.py +18 -14
- arviz/stats/ecdf_utils.py +157 -2
- arviz/stats/stats.py +99 -7
- arviz/tests/base_tests/test_data.py +41 -7
- arviz/tests/base_tests/test_diagnostics.py +5 -4
- arviz/tests/base_tests/test_plots_matplotlib.py +32 -13
- arviz/tests/base_tests/test_stats.py +11 -0
- arviz/tests/base_tests/test_stats_ecdf_utils.py +15 -2
- arviz/utils.py +4 -0
- {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/METADATA +22 -22
- {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/RECORD +43 -43
- {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/WHEEL +1 -1
- {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/LICENSE +0 -0
- {arviz-0.19.0.dist-info → arviz-0.21.0.dist-info}/top_level.txt +0 -0
arviz/__init__.py
CHANGED
arviz/data/base.py
CHANGED
|
@@ -9,9 +9,13 @@ from copy import deepcopy
|
|
|
9
9
|
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
|
|
10
10
|
|
|
11
11
|
import numpy as np
|
|
12
|
-
import tree
|
|
13
12
|
import xarray as xr
|
|
14
13
|
|
|
14
|
+
try:
|
|
15
|
+
import tree
|
|
16
|
+
except ImportError:
|
|
17
|
+
tree = None
|
|
18
|
+
|
|
15
19
|
try:
|
|
16
20
|
import ujson as json
|
|
17
21
|
except ImportError:
|
|
@@ -89,6 +93,9 @@ def _yield_flat_up_to(shallow_tree, input_tree, path=()):
|
|
|
89
93
|
input_tree.
|
|
90
94
|
"""
|
|
91
95
|
# pylint: disable=protected-access
|
|
96
|
+
if tree is None:
|
|
97
|
+
raise ImportError("Missing optional dependency 'dm-tree'. Use pip or conda to install it")
|
|
98
|
+
|
|
92
99
|
if isinstance(shallow_tree, tree._TEXT_OR_BYTES) or not (
|
|
93
100
|
isinstance(shallow_tree, tree.collections_abc.Mapping)
|
|
94
101
|
or tree._is_namedtuple(shallow_tree)
|
|
@@ -194,10 +201,10 @@ def generate_dims_coords(
|
|
|
194
201
|
for i, dim_len in enumerate(shape):
|
|
195
202
|
idx = i + len([dim for dim in default_dims if dim in dims])
|
|
196
203
|
if len(dims) < idx + 1:
|
|
197
|
-
dim_name = f"{var_name}_dim_{
|
|
204
|
+
dim_name = f"{var_name}_dim_{i}"
|
|
198
205
|
dims.append(dim_name)
|
|
199
206
|
elif dims[idx] is None:
|
|
200
|
-
dim_name = f"{var_name}_dim_{
|
|
207
|
+
dim_name = f"{var_name}_dim_{i}"
|
|
201
208
|
dims[idx] = dim_name
|
|
202
209
|
dim_name = dims[idx]
|
|
203
210
|
if dim_name not in coords:
|
|
@@ -299,7 +306,7 @@ def numpy_to_data_array(
|
|
|
299
306
|
return xr.DataArray(ary, coords=coords, dims=dims)
|
|
300
307
|
|
|
301
308
|
|
|
302
|
-
def
|
|
309
|
+
def dict_to_dataset(
|
|
303
310
|
data,
|
|
304
311
|
*,
|
|
305
312
|
attrs=None,
|
|
@@ -312,6 +319,8 @@ def pytree_to_dataset(
|
|
|
312
319
|
):
|
|
313
320
|
"""Convert a dictionary or pytree of numpy arrays to an xarray.Dataset.
|
|
314
321
|
|
|
322
|
+
ArviZ itself supports conversion of flat dictionaries.
|
|
323
|
+
Suport for pytrees requires ``dm-tree`` which is an optional dependency.
|
|
315
324
|
See https://jax.readthedocs.io/en/latest/pytrees.html for what a pytree is, but
|
|
316
325
|
this inclues at least dictionaries and tuple types.
|
|
317
326
|
|
|
@@ -386,10 +395,12 @@ def pytree_to_dataset(
|
|
|
386
395
|
"""
|
|
387
396
|
if dims is None:
|
|
388
397
|
dims = {}
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
398
|
+
|
|
399
|
+
if tree is not None:
|
|
400
|
+
try:
|
|
401
|
+
data = {k[0] if len(k) == 1 else k: v for k, v in _flatten_with_path(data)}
|
|
402
|
+
except TypeError: # probably unsortable keys -- the function will still work if
|
|
403
|
+
pass # it is an honest dictionary.
|
|
393
404
|
|
|
394
405
|
data_vars = {
|
|
395
406
|
key: numpy_to_data_array(
|
|
@@ -406,7 +417,7 @@ def pytree_to_dataset(
|
|
|
406
417
|
return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))
|
|
407
418
|
|
|
408
419
|
|
|
409
|
-
|
|
420
|
+
pytree_to_dataset = dict_to_dataset
|
|
410
421
|
|
|
411
422
|
|
|
412
423
|
def make_attrs(attrs=None, library=None):
|
arviz/data/converters.py
CHANGED
|
@@ -1,9 +1,13 @@
|
|
|
1
1
|
"""High level conversion functions."""
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
|
-
import tree
|
|
5
4
|
import xarray as xr
|
|
6
5
|
|
|
6
|
+
try:
|
|
7
|
+
from tree import is_nested
|
|
8
|
+
except ImportError:
|
|
9
|
+
is_nested = lambda obj: False
|
|
10
|
+
|
|
7
11
|
from .base import dict_to_dataset
|
|
8
12
|
from .inference_data import InferenceData
|
|
9
13
|
from .io_beanmachine import from_beanmachine
|
|
@@ -107,7 +111,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
|
|
|
107
111
|
dataset = obj.to_dataset()
|
|
108
112
|
elif isinstance(obj, dict):
|
|
109
113
|
dataset = dict_to_dataset(obj, coords=coords, dims=dims)
|
|
110
|
-
elif
|
|
114
|
+
elif is_nested(obj) and not isinstance(obj, (list, tuple)):
|
|
111
115
|
dataset = dict_to_dataset(obj, coords=coords, dims=dims)
|
|
112
116
|
elif isinstance(obj, np.ndarray):
|
|
113
117
|
dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
|
|
@@ -122,7 +126,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
|
|
|
122
126
|
"xarray dataarray",
|
|
123
127
|
"xarray dataset",
|
|
124
128
|
"dict",
|
|
125
|
-
"pytree",
|
|
129
|
+
"pytree (if 'dm-tree' is installed)",
|
|
126
130
|
"netcdf filename",
|
|
127
131
|
"numpy array",
|
|
128
132
|
"pystan fit",
|
arviz/data/inference_data.py
CHANGED
|
@@ -102,6 +102,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
102
102
|
def __init__(
|
|
103
103
|
self,
|
|
104
104
|
attrs: Union[None, Mapping[Any, Any]] = None,
|
|
105
|
+
warn_on_custom_groups: bool = False,
|
|
105
106
|
**kwargs: Union[xr.Dataset, List[xr.Dataset], Tuple[xr.Dataset, xr.Dataset]],
|
|
106
107
|
) -> None:
|
|
107
108
|
"""Initialize InferenceData object from keyword xarray datasets.
|
|
@@ -110,6 +111,9 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
110
111
|
----------
|
|
111
112
|
attrs : dict
|
|
112
113
|
sets global attribute for InferenceData object.
|
|
114
|
+
warn_on_custom_groups : bool, default False
|
|
115
|
+
Emit a warning when custom groups are present in the InferenceData.
|
|
116
|
+
"custom group" means any group whose name isn't defined in :ref:`schema`
|
|
113
117
|
kwargs :
|
|
114
118
|
Keyword arguments of xarray datasets
|
|
115
119
|
|
|
@@ -154,9 +158,10 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
154
158
|
for key in kwargs:
|
|
155
159
|
if key not in SUPPORTED_GROUPS_ALL:
|
|
156
160
|
key_list.append(key)
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
161
|
+
if warn_on_custom_groups:
|
|
162
|
+
warnings.warn(
|
|
163
|
+
f"{key} group is not defined in the InferenceData scheme", UserWarning
|
|
164
|
+
)
|
|
160
165
|
for key in key_list:
|
|
161
166
|
dataset = kwargs[key]
|
|
162
167
|
dataset_warmup = None
|
|
@@ -266,6 +271,14 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
266
271
|
raise KeyError(key)
|
|
267
272
|
return getattr(self, key)
|
|
268
273
|
|
|
274
|
+
def __setitem__(self, key: str, value: xr.Dataset):
|
|
275
|
+
"""Set item by key and update group list accordingly."""
|
|
276
|
+
if key.startswith(WARMUP_TAG):
|
|
277
|
+
self._groups_warmup.append(key)
|
|
278
|
+
else:
|
|
279
|
+
self._groups.append(key)
|
|
280
|
+
setattr(self, key, value)
|
|
281
|
+
|
|
269
282
|
def groups(self) -> List[str]:
|
|
270
283
|
"""Return all groups present in InferenceData object."""
|
|
271
284
|
return self._groups_all
|
|
@@ -1459,7 +1472,9 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1459
1472
|
else:
|
|
1460
1473
|
return out
|
|
1461
1474
|
|
|
1462
|
-
def add_groups(
|
|
1475
|
+
def add_groups(
|
|
1476
|
+
self, group_dict=None, coords=None, dims=None, warn_on_custom_groups=False, **kwargs
|
|
1477
|
+
):
|
|
1463
1478
|
"""Add new groups to InferenceData object.
|
|
1464
1479
|
|
|
1465
1480
|
Parameters
|
|
@@ -1471,6 +1486,9 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1471
1486
|
dims : dict of {str : list of str}, optional
|
|
1472
1487
|
Dimensions of each variable. The keys are variable names, values are lists of
|
|
1473
1488
|
coordinates.
|
|
1489
|
+
warn_on_custom_groups : bool, default False
|
|
1490
|
+
Emit a warning when custom groups are present in the InferenceData.
|
|
1491
|
+
"custom group" means any group whose name isn't defined in :ref:`schema`
|
|
1474
1492
|
kwargs : dict, optional
|
|
1475
1493
|
The keyword arguments form of group_dict. One of group_dict or kwargs must be provided.
|
|
1476
1494
|
|
|
@@ -1534,7 +1552,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1534
1552
|
if repeated_groups:
|
|
1535
1553
|
raise ValueError(f"{repeated_groups} group(s) already exists.")
|
|
1536
1554
|
for group, dataset in group_dict.items():
|
|
1537
|
-
if group not in SUPPORTED_GROUPS_ALL:
|
|
1555
|
+
if warn_on_custom_groups and group not in SUPPORTED_GROUPS_ALL:
|
|
1538
1556
|
warnings.warn(
|
|
1539
1557
|
f"The group {group} is not defined in the InferenceData scheme",
|
|
1540
1558
|
UserWarning,
|
|
@@ -1589,7 +1607,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1589
1607
|
else:
|
|
1590
1608
|
self._groups.append(group)
|
|
1591
1609
|
|
|
1592
|
-
def extend(self, other, join="left"):
|
|
1610
|
+
def extend(self, other, join="left", warn_on_custom_groups=False):
|
|
1593
1611
|
"""Extend InferenceData with groups from another InferenceData.
|
|
1594
1612
|
|
|
1595
1613
|
Parameters
|
|
@@ -1600,6 +1618,9 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1600
1618
|
Defines how the two decide which group to keep when the same group is
|
|
1601
1619
|
present in both objects. 'left' will discard the group in ``other`` whereas 'right'
|
|
1602
1620
|
will keep the group in ``other`` and discard the one in ``self``.
|
|
1621
|
+
warn_on_custom_groups : bool, default False
|
|
1622
|
+
Emit a warning when custom groups are present in the InferenceData.
|
|
1623
|
+
"custom group" means any group whose name isn't defined in :ref:`schema`
|
|
1603
1624
|
|
|
1604
1625
|
Examples
|
|
1605
1626
|
--------
|
|
@@ -1643,7 +1664,7 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1643
1664
|
for group in other._groups_all: # pylint: disable=protected-access
|
|
1644
1665
|
if hasattr(self, group) and join == "left":
|
|
1645
1666
|
continue
|
|
1646
|
-
if group not in SUPPORTED_GROUPS_ALL:
|
|
1667
|
+
if warn_on_custom_groups and group not in SUPPORTED_GROUPS_ALL:
|
|
1647
1668
|
warnings.warn(
|
|
1648
1669
|
f"{group} group is not defined in the InferenceData scheme", UserWarning
|
|
1649
1670
|
)
|
arviz/plots/backends/__init__.py
CHANGED
|
@@ -213,10 +213,11 @@ def _copy_docstring(lib, function):
|
|
|
213
213
|
|
|
214
214
|
|
|
215
215
|
# TODO: try copying substitutions too, or autoreplace them ourselves
|
|
216
|
-
output_notebook.__doc__
|
|
217
|
-
"
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
"
|
|
221
|
-
|
|
222
|
-
|
|
216
|
+
if output_notebook.__doc__ is not None: # if run with python -OO, __doc__ is stripped
|
|
217
|
+
output_notebook.__doc__ += "\n\n" + _copy_docstring(
|
|
218
|
+
"bokeh.plotting", "output_notebook"
|
|
219
|
+
).replace("|save|", "save").replace("|show|", "show")
|
|
220
|
+
output_file.__doc__ += "\n\n" + _copy_docstring("bokeh.plotting", "output_file").replace(
|
|
221
|
+
"|save|", "save"
|
|
222
|
+
).replace("|show|", "show")
|
|
223
|
+
ColumnDataSource.__doc__ += "\n\n" + _copy_docstring("bokeh.models", "ColumnDataSource")
|
|
@@ -41,6 +41,7 @@ def plot_bpv(
|
|
|
41
41
|
plot_ref_kwargs,
|
|
42
42
|
backend_kwargs,
|
|
43
43
|
show,
|
|
44
|
+
smoothing,
|
|
44
45
|
):
|
|
45
46
|
"""Bokeh bpv plot."""
|
|
46
47
|
if backend_kwargs is None:
|
|
@@ -90,6 +91,9 @@ def plot_bpv(
|
|
|
90
91
|
obs_vals = obs_vals.flatten()
|
|
91
92
|
pp_vals = pp_vals.reshape(total_pp_samples, -1)
|
|
92
93
|
|
|
94
|
+
if (obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i") and smoothing is True:
|
|
95
|
+
obs_vals, pp_vals = smooth_data(obs_vals, pp_vals)
|
|
96
|
+
|
|
93
97
|
if kind == "p_value":
|
|
94
98
|
tstat_pit = np.mean(pp_vals <= obs_vals, axis=-1)
|
|
95
99
|
x_s, tstat_pit_dens = kde(tstat_pit)
|
|
@@ -115,9 +119,6 @@ def plot_bpv(
|
|
|
115
119
|
)
|
|
116
120
|
|
|
117
121
|
elif kind == "u_value":
|
|
118
|
-
if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
|
|
119
|
-
obs_vals, pp_vals = smooth_data(obs_vals, pp_vals)
|
|
120
|
-
|
|
121
122
|
tstat_pit = np.mean(pp_vals <= obs_vals, axis=0)
|
|
122
123
|
x_s, tstat_pit_dens = kde(tstat_pit)
|
|
123
124
|
ax_i.line(x_s, tstat_pit_dens, color=color)
|
|
@@ -225,7 +225,11 @@ def _d_helper(
|
|
|
225
225
|
|
|
226
226
|
if point_estimate is not None:
|
|
227
227
|
est = calculate_point_estimate(point_estimate, vec, bw, circular)
|
|
228
|
-
plotted.append(
|
|
228
|
+
plotted.append(
|
|
229
|
+
ax.scatter(
|
|
230
|
+
est, 0, marker="circle", fill_color=color, line_color="black", size=markersize
|
|
231
|
+
)
|
|
232
|
+
)
|
|
229
233
|
|
|
230
234
|
_title = Title()
|
|
231
235
|
_title.text = vname
|
|
@@ -47,7 +47,10 @@ def plot_dot(
|
|
|
47
47
|
|
|
48
48
|
if plot_kwargs is None:
|
|
49
49
|
plot_kwargs = {}
|
|
50
|
-
|
|
50
|
+
else:
|
|
51
|
+
plot_kwargs = plot_kwargs.copy()
|
|
52
|
+
plot_kwargs.setdefault("color", dotcolor)
|
|
53
|
+
plot_kwargs.setdefault("marker", "circle")
|
|
51
54
|
|
|
52
55
|
if linewidth is None:
|
|
53
56
|
linewidth = auto_linewidth
|
|
@@ -95,7 +98,7 @@ def plot_dot(
|
|
|
95
98
|
stack_locs, stack_count = wilkinson_algorithm(values, binwidth)
|
|
96
99
|
x, y = layout_stacks(stack_locs, stack_count, binwidth, stackratio, rotated)
|
|
97
100
|
|
|
98
|
-
ax.
|
|
101
|
+
ax.scatter(x, y, radius=dotsize * (binwidth / 2), **plot_kwargs, radius_dimension="y")
|
|
99
102
|
if rotated:
|
|
100
103
|
ax.xaxis.major_tick_line_color = None
|
|
101
104
|
ax.xaxis.minor_tick_line_color = None
|
|
@@ -73,12 +73,14 @@ def plot_ess(
|
|
|
73
73
|
for (var_name, selection, isel, x), ax_ in zip(
|
|
74
74
|
plotters, (item for item in ax.flatten() if item is not None)
|
|
75
75
|
):
|
|
76
|
-
bulk_points = ax_.
|
|
76
|
+
bulk_points = ax_.scatter(np.asarray(xdata), np.asarray(x), marker="circle", size=6)
|
|
77
77
|
if kind == "evolution":
|
|
78
78
|
bulk_line = ax_.line(np.asarray(xdata), np.asarray(x))
|
|
79
79
|
ess_tail = ess_tail_dataset[var_name].sel(**selection)
|
|
80
80
|
tail_points = ax_.line(np.asarray(xdata), np.asarray(ess_tail), color="orange")
|
|
81
|
-
tail_line = ax_.
|
|
81
|
+
tail_line = ax_.scatter(
|
|
82
|
+
np.asarray(xdata), np.asarray(ess_tail), marker="circle", size=6, color="orange"
|
|
83
|
+
)
|
|
82
84
|
elif rug:
|
|
83
85
|
if rug_kwargs is None:
|
|
84
86
|
rug_kwargs = {}
|
|
@@ -535,9 +535,10 @@ class PlotHandler:
|
|
|
535
535
|
)
|
|
536
536
|
)
|
|
537
537
|
plotted[model_name].append(
|
|
538
|
-
ax.
|
|
538
|
+
ax.scatter(
|
|
539
539
|
x=values[mid],
|
|
540
540
|
y=y,
|
|
541
|
+
marker="circle",
|
|
541
542
|
size=markersize * 0.75,
|
|
542
543
|
fill_color=color,
|
|
543
544
|
)
|
|
@@ -555,9 +556,10 @@ class PlotHandler:
|
|
|
555
556
|
for y, ess, color, model_name in plotter.ess():
|
|
556
557
|
if ess is not None:
|
|
557
558
|
plotted[model_name].append(
|
|
558
|
-
ax.
|
|
559
|
+
ax.scatter(
|
|
559
560
|
x=ess,
|
|
560
561
|
y=y,
|
|
562
|
+
marker="circle",
|
|
561
563
|
fill_color=color,
|
|
562
564
|
size=markersize,
|
|
563
565
|
line_color="black",
|
|
@@ -582,8 +584,13 @@ class PlotHandler:
|
|
|
582
584
|
for y, r_hat, color, model_name in plotter.r_hat():
|
|
583
585
|
if r_hat is not None:
|
|
584
586
|
plotted[model_name].append(
|
|
585
|
-
ax.
|
|
586
|
-
x=r_hat,
|
|
587
|
+
ax.scatter(
|
|
588
|
+
x=r_hat,
|
|
589
|
+
y=y,
|
|
590
|
+
marker="circle",
|
|
591
|
+
fill_color=color,
|
|
592
|
+
size=markersize,
|
|
593
|
+
line_color="black",
|
|
587
594
|
)
|
|
588
595
|
)
|
|
589
596
|
ax.x_range._property_values["start"] = 0.9 # pylint: disable=protected-access
|
|
@@ -94,9 +94,10 @@ def plot_khat(
|
|
|
94
94
|
|
|
95
95
|
if not isinstance(rgba_c, str) and isinstance(rgba_c, Iterable):
|
|
96
96
|
for idx, (alpha, rgba_c_) in enumerate(zip(alphas, rgba_c)):
|
|
97
|
-
ax.
|
|
97
|
+
ax.scatter(
|
|
98
98
|
xdata[idx],
|
|
99
99
|
khats[idx],
|
|
100
|
+
marker="cross",
|
|
100
101
|
line_color=rgba_c_,
|
|
101
102
|
fill_color=rgba_c_,
|
|
102
103
|
line_alpha=alpha,
|
|
@@ -104,9 +105,10 @@ def plot_khat(
|
|
|
104
105
|
size=10,
|
|
105
106
|
)
|
|
106
107
|
else:
|
|
107
|
-
ax.
|
|
108
|
+
ax.scatter(
|
|
108
109
|
xdata,
|
|
109
110
|
khats,
|
|
111
|
+
marker="cross",
|
|
110
112
|
line_color=rgba_c,
|
|
111
113
|
fill_color=rgba_c,
|
|
112
114
|
size=10,
|
|
@@ -51,12 +51,18 @@ def plot_lm(
|
|
|
51
51
|
|
|
52
52
|
if y_kwargs is None:
|
|
53
53
|
y_kwargs = {}
|
|
54
|
+
else:
|
|
55
|
+
y_kwargs = y_kwargs.copy()
|
|
56
|
+
y_kwargs.setdefault("marker", "circle")
|
|
54
57
|
y_kwargs.setdefault("fill_color", "red")
|
|
55
58
|
y_kwargs.setdefault("line_width", 0)
|
|
56
59
|
y_kwargs.setdefault("size", 3)
|
|
57
60
|
|
|
58
61
|
if y_hat_plot_kwargs is None:
|
|
59
62
|
y_hat_plot_kwargs = {}
|
|
63
|
+
else:
|
|
64
|
+
y_hat_plot_kwargs = y_hat_plot_kwargs.copy()
|
|
65
|
+
y_hat_plot_kwargs.setdefault("marker", "circle")
|
|
60
66
|
y_hat_plot_kwargs.setdefault("fill_color", "orange")
|
|
61
67
|
y_hat_plot_kwargs.setdefault("line_width", 0)
|
|
62
68
|
|
|
@@ -84,7 +90,7 @@ def plot_lm(
|
|
|
84
90
|
_, _, _, y_plotters = y[i]
|
|
85
91
|
_, _, _, x_plotters = x[i]
|
|
86
92
|
legend_it = []
|
|
87
|
-
observed_legend = ax_i.
|
|
93
|
+
observed_legend = ax_i.scatter(x_plotters, y_plotters, **y_kwargs)
|
|
88
94
|
legend_it.append(("Observed", [observed_legend]))
|
|
89
95
|
|
|
90
96
|
if y_hat is not None:
|
|
@@ -98,14 +104,14 @@ def plot_lm(
|
|
|
98
104
|
x_plotters_jitter = x_plotters + np.random.uniform(
|
|
99
105
|
low=-scale_high, high=scale_high, size=len(x_plotters)
|
|
100
106
|
)
|
|
101
|
-
posterior_circle = ax_i.
|
|
107
|
+
posterior_circle = ax_i.scatter(
|
|
102
108
|
x_plotters_jitter,
|
|
103
109
|
y_hat_plotters[..., j],
|
|
104
110
|
alpha=0.2,
|
|
105
111
|
**y_hat_plot_kwargs,
|
|
106
112
|
)
|
|
107
113
|
else:
|
|
108
|
-
posterior_circle = ax_i.
|
|
114
|
+
posterior_circle = ax_i.scatter(
|
|
109
115
|
x_plotters, y_hat_plotters[..., j], alpha=0.2, **y_hat_plot_kwargs
|
|
110
116
|
)
|
|
111
117
|
posterior_legend.append(posterior_circle)
|
|
@@ -71,13 +71,13 @@ def plot_mcse(
|
|
|
71
71
|
values = data[var_name].sel(**selection).values.flatten()
|
|
72
72
|
if errorbar:
|
|
73
73
|
quantile_values = _quantile(values, probs)
|
|
74
|
-
ax_.
|
|
74
|
+
ax_.scatter(probs, quantile_values, marker="dash")
|
|
75
75
|
ax_.multi_line(
|
|
76
76
|
list(zip(probs, probs)),
|
|
77
77
|
[(quant - err, quant + err) for quant, err in zip(quantile_values, x)],
|
|
78
78
|
)
|
|
79
79
|
else:
|
|
80
|
-
ax_.
|
|
80
|
+
ax_.scatter(probs, x, marker="circle")
|
|
81
81
|
if extra_methods:
|
|
82
82
|
mean_mcse_i = mean_mcse[var_name].sel(**selection).values.item()
|
|
83
83
|
sd_mcse_i = sd_mcse[var_name].sel(**selection).values.item()
|
|
@@ -121,12 +121,14 @@ def plot_pair(
|
|
|
121
121
|
)
|
|
122
122
|
|
|
123
123
|
reference_values_kwargs = _init_kwargs_dict(reference_values_kwargs)
|
|
124
|
+
reference_values_kwargs.setdefault("marker", "circle")
|
|
124
125
|
reference_values_kwargs.setdefault("line_color", "black")
|
|
125
126
|
reference_values_kwargs.setdefault("fill_color", vectorized_to_hex("C2"))
|
|
126
127
|
reference_values_kwargs.setdefault("line_width", 1)
|
|
127
128
|
reference_values_kwargs.setdefault("size", 10)
|
|
128
129
|
|
|
129
130
|
divergences_kwargs = _init_kwargs_dict(divergences_kwargs)
|
|
131
|
+
divergences_kwargs.setdefault("marker", "circle")
|
|
130
132
|
divergences_kwargs.setdefault("line_color", "black")
|
|
131
133
|
divergences_kwargs.setdefault("fill_color", vectorized_to_hex("C1"))
|
|
132
134
|
divergences_kwargs.setdefault("line_width", 1)
|
|
@@ -155,6 +157,7 @@ def plot_pair(
|
|
|
155
157
|
)
|
|
156
158
|
|
|
157
159
|
point_estimate_marker_kwargs = _init_kwargs_dict(point_estimate_marker_kwargs)
|
|
160
|
+
point_estimate_marker_kwargs.setdefault("marker", "square")
|
|
158
161
|
point_estimate_marker_kwargs.setdefault("size", markersize)
|
|
159
162
|
point_estimate_marker_kwargs.setdefault("color", "black")
|
|
160
163
|
point_estimate_kwargs.setdefault("line_color", "black")
|
|
@@ -265,9 +268,11 @@ def plot_pair(
|
|
|
265
268
|
elif j + marginals_offset > i:
|
|
266
269
|
if "scatter" in kind:
|
|
267
270
|
if divergences:
|
|
268
|
-
ax[j, i].
|
|
271
|
+
ax[j, i].scatter(
|
|
272
|
+
var1, var2, marker="circle", source=source, view=source_nondiv
|
|
273
|
+
)
|
|
269
274
|
else:
|
|
270
|
-
ax[j, i].
|
|
275
|
+
ax[j, i].scatter(var1, var2, marker="circle", source=source)
|
|
271
276
|
|
|
272
277
|
if "kde" in kind:
|
|
273
278
|
var1_kde = plotters[i][-1].flatten()
|
|
@@ -293,7 +298,7 @@ def plot_pair(
|
|
|
293
298
|
)
|
|
294
299
|
|
|
295
300
|
if divergences:
|
|
296
|
-
ax[j, i].
|
|
301
|
+
ax[j, i].scatter(
|
|
297
302
|
var1,
|
|
298
303
|
var2,
|
|
299
304
|
source=source,
|
|
@@ -306,7 +311,7 @@ def plot_pair(
|
|
|
306
311
|
var2_pe = plotters[j][-1].flatten()
|
|
307
312
|
pe_x = calculate_point_estimate(point_estimate, var1_pe)
|
|
308
313
|
pe_y = calculate_point_estimate(point_estimate, var2_pe)
|
|
309
|
-
ax[j, i].
|
|
314
|
+
ax[j, i].scatter(pe_x, pe_y, **point_estimate_marker_kwargs)
|
|
310
315
|
|
|
311
316
|
ax_hline = Span(
|
|
312
317
|
location=pe_y,
|
|
@@ -344,7 +349,7 @@ def plot_pair(
|
|
|
344
349
|
x = reference_values_copy[flat_var_names[j + marginals_offset]]
|
|
345
350
|
y = reference_values_copy[flat_var_names[i]]
|
|
346
351
|
if x and y:
|
|
347
|
-
ax[j, i].
|
|
352
|
+
ax[j, i].scatter(y, x, **reference_values_kwargs)
|
|
348
353
|
ax[j, i].xaxis.axis_label = flat_var_names[i]
|
|
349
354
|
ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset]
|
|
350
355
|
|
|
@@ -313,9 +313,10 @@ def plot_ppc(
|
|
|
313
313
|
obs_yvals += np.random.uniform(
|
|
314
314
|
low=scale_low, high=scale_high, size=len(obs_vals)
|
|
315
315
|
)
|
|
316
|
-
glyph = ax_i.
|
|
316
|
+
glyph = ax_i.scatter(
|
|
317
317
|
obs_vals,
|
|
318
318
|
obs_yvals,
|
|
319
|
+
marker="circle",
|
|
319
320
|
line_color=colors[1],
|
|
320
321
|
fill_color=colors[1],
|
|
321
322
|
size=markersize,
|
|
@@ -49,6 +49,7 @@ def plot_rank(
|
|
|
49
49
|
|
|
50
50
|
if marker_vlines_kwargs is None:
|
|
51
51
|
marker_vlines_kwargs = {}
|
|
52
|
+
marker_vlines_kwargs.setdefault("marker", "circle")
|
|
52
53
|
|
|
53
54
|
if backend_kwargs is None:
|
|
54
55
|
backend_kwargs = {}
|
|
@@ -109,7 +110,7 @@ def plot_rank(
|
|
|
109
110
|
elif kind == "vlines":
|
|
110
111
|
ymin = np.full(len(all_counts), all_counts.mean())
|
|
111
112
|
for idx, counts in enumerate(all_counts):
|
|
112
|
-
ax.
|
|
113
|
+
ax.scatter(
|
|
113
114
|
bin_ary,
|
|
114
115
|
counts,
|
|
115
116
|
fill_color=colors[idx],
|
|
@@ -385,9 +385,10 @@ def _plot_chains_bokeh(
|
|
|
385
385
|
**dealiase_sel_kwargs(trace_kwargs, chain_prop, chain_idx),
|
|
386
386
|
)
|
|
387
387
|
if marker:
|
|
388
|
-
ax_trace.
|
|
388
|
+
ax_trace.scatter(
|
|
389
389
|
x=x_name,
|
|
390
390
|
y=y_name,
|
|
391
|
+
marker="circle",
|
|
391
392
|
source=cds,
|
|
392
393
|
radius=0.30,
|
|
393
394
|
alpha=0.5,
|
|
@@ -81,9 +81,10 @@ def plot_violin(
|
|
|
81
81
|
[0, 0], per[:2], line_width=linewidth * 3, line_color="black", line_cap="round"
|
|
82
82
|
)
|
|
83
83
|
ax_.line([0, 0], hdi_probs, line_width=linewidth, line_color="black", line_cap="round")
|
|
84
|
-
ax_.
|
|
84
|
+
ax_.scatter(
|
|
85
85
|
0,
|
|
86
86
|
per[-1],
|
|
87
|
+
marker="circle",
|
|
87
88
|
line_color="white",
|
|
88
89
|
fill_color="white",
|
|
89
90
|
size=linewidth * 1.5,
|
|
@@ -38,6 +38,7 @@ def plot_bpv(
|
|
|
38
38
|
plot_ref_kwargs,
|
|
39
39
|
backend_kwargs,
|
|
40
40
|
show,
|
|
41
|
+
smoothing,
|
|
41
42
|
):
|
|
42
43
|
"""Matplotlib bpv plot."""
|
|
43
44
|
if backend_kwargs is None:
|
|
@@ -87,7 +88,7 @@ def plot_bpv(
|
|
|
87
88
|
obs_vals = obs_vals.flatten()
|
|
88
89
|
pp_vals = pp_vals.reshape(total_pp_samples, -1)
|
|
89
90
|
|
|
90
|
-
if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
|
|
91
|
+
if (obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i") and smoothing is True:
|
|
91
92
|
obs_vals, pp_vals = smooth_data(obs_vals, pp_vals)
|
|
92
93
|
|
|
93
94
|
if kind == "p_value":
|
arviz/plots/bfplot.py
CHANGED
|
@@ -2,11 +2,9 @@
|
|
|
2
2
|
# pylint: disable=unbalanced-tuple-unpacking
|
|
3
3
|
import logging
|
|
4
4
|
|
|
5
|
-
from numpy import interp
|
|
6
|
-
|
|
7
5
|
from ..data.utils import extract
|
|
8
6
|
from .plot_utils import get_plotting_function
|
|
9
|
-
from ..stats
|
|
7
|
+
from ..stats import bayes_factor
|
|
10
8
|
|
|
11
9
|
_log = logging.getLogger(__name__)
|
|
12
10
|
|
|
@@ -94,32 +92,17 @@ def plot_bf(
|
|
|
94
92
|
>>> az.plot_bf(idata, var_name="a", ref_val=0)
|
|
95
93
|
|
|
96
94
|
"""
|
|
97
|
-
posterior = extract(idata, var_names=var_name).values
|
|
98
|
-
|
|
99
|
-
if ref_val > posterior.max() or ref_val < posterior.min():
|
|
100
|
-
_log.warning(
|
|
101
|
-
"The reference value is outside of the posterior. "
|
|
102
|
-
"This translate into infinite support for H1, which is most likely an overstatement."
|
|
103
|
-
)
|
|
104
|
-
|
|
105
|
-
if posterior.ndim > 1:
|
|
106
|
-
_log.warning("Posterior distribution has {posterior.ndim} dimensions")
|
|
107
95
|
|
|
108
96
|
if prior is None:
|
|
109
97
|
prior = extract(idata, var_names=var_name, group="prior").values
|
|
110
98
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
elif posterior.dtype.kind == "i":
|
|
118
|
-
posterior_at_ref_val = (posterior == ref_val).mean()
|
|
119
|
-
prior_at_ref_val = (prior == ref_val).mean()
|
|
99
|
+
bf, p_at_ref_val = bayes_factor(
|
|
100
|
+
idata, var_name, prior=prior, ref_val=ref_val, return_ref_vals=True
|
|
101
|
+
)
|
|
102
|
+
bf_10 = bf["BF10"]
|
|
103
|
+
bf_01 = bf["BF01"]
|
|
120
104
|
|
|
121
|
-
|
|
122
|
-
bf_01 = 1 / bf_10
|
|
105
|
+
posterior = extract(idata, var_names=var_name)
|
|
123
106
|
|
|
124
107
|
bfplot_kwargs = dict(
|
|
125
108
|
ax=ax,
|
|
@@ -128,8 +111,8 @@ def plot_bf(
|
|
|
128
111
|
prior=prior,
|
|
129
112
|
posterior=posterior,
|
|
130
113
|
ref_val=ref_val,
|
|
131
|
-
prior_at_ref_val=
|
|
132
|
-
posterior_at_ref_val=
|
|
114
|
+
prior_at_ref_val=p_at_ref_val["prior"],
|
|
115
|
+
posterior_at_ref_val=p_at_ref_val["posterior"],
|
|
133
116
|
var_name=var_name,
|
|
134
117
|
colors=colors,
|
|
135
118
|
figsize=figsize,
|