arviz 0.18.0__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +2 -1
- arviz/data/base.py +18 -7
- arviz/data/converters.py +7 -3
- arviz/data/inference_data.py +8 -0
- arviz/data/io_cmdstan.py +4 -0
- arviz/data/io_numpyro.py +1 -1
- arviz/plots/backends/bokeh/ecdfplot.py +1 -2
- arviz/plots/backends/bokeh/khatplot.py +8 -3
- arviz/plots/backends/bokeh/pairplot.py +2 -6
- arviz/plots/backends/matplotlib/ecdfplot.py +1 -2
- arviz/plots/backends/matplotlib/khatplot.py +7 -3
- arviz/plots/backends/matplotlib/traceplot.py +1 -1
- arviz/plots/bpvplot.py +2 -2
- arviz/plots/compareplot.py +4 -4
- arviz/plots/densityplot.py +1 -1
- arviz/plots/dotplot.py +2 -2
- arviz/plots/ecdfplot.py +213 -89
- arviz/plots/essplot.py +2 -2
- arviz/plots/forestplot.py +3 -3
- arviz/plots/hdiplot.py +2 -2
- arviz/plots/kdeplot.py +9 -2
- arviz/plots/khatplot.py +23 -6
- arviz/plots/loopitplot.py +2 -2
- arviz/plots/mcseplot.py +3 -1
- arviz/plots/plot_utils.py +2 -4
- arviz/plots/posteriorplot.py +1 -1
- arviz/plots/rankplot.py +2 -2
- arviz/plots/violinplot.py +1 -1
- arviz/preview.py +17 -0
- arviz/rcparams.py +27 -2
- arviz/stats/diagnostics.py +13 -9
- arviz/stats/ecdf_utils.py +168 -10
- arviz/stats/stats.py +41 -20
- arviz/stats/stats_utils.py +8 -6
- arviz/tests/base_tests/test_data.py +11 -2
- arviz/tests/base_tests/test_data_zarr.py +0 -1
- arviz/tests/base_tests/test_diagnostics_numba.py +2 -7
- arviz/tests/base_tests/test_helpers.py +2 -2
- arviz/tests/base_tests/test_plot_utils.py +5 -13
- arviz/tests/base_tests/test_plots_matplotlib.py +95 -2
- arviz/tests/base_tests/test_rcparams.py +12 -0
- arviz/tests/base_tests/test_stats.py +1 -1
- arviz/tests/base_tests/test_stats_ecdf_utils.py +15 -2
- arviz/tests/base_tests/test_stats_numba.py +2 -7
- arviz/tests/base_tests/test_utils_numba.py +2 -5
- arviz/tests/external_tests/test_data_pystan.py +5 -5
- arviz/tests/helpers.py +17 -9
- arviz/utils.py +4 -0
- {arviz-0.18.0.dist-info → arviz-0.20.0.dist-info}/METADATA +23 -19
- {arviz-0.18.0.dist-info → arviz-0.20.0.dist-info}/RECORD +53 -52
- {arviz-0.18.0.dist-info → arviz-0.20.0.dist-info}/WHEEL +1 -1
- {arviz-0.18.0.dist-info → arviz-0.20.0.dist-info}/LICENSE +0 -0
- {arviz-0.18.0.dist-info → arviz-0.20.0.dist-info}/top_level.txt +0 -0
arviz/__init__.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# pylint: disable=wildcard-import,invalid-name,wrong-import-position
|
|
2
2
|
"""ArviZ is a library for exploratory analysis of Bayesian models."""
|
|
3
|
-
__version__ = "0.18.0"
|
|
3
|
+
__version__ = "0.20.0"
|
|
4
4
|
|
|
5
5
|
import logging
|
|
6
6
|
import os
|
|
@@ -37,6 +37,7 @@ from .stats import *
|
|
|
37
37
|
from .rcparams import rc_context, rcParams
|
|
38
38
|
from .utils import Numba, Dask, interactive_backend
|
|
39
39
|
from .wrappers import *
|
|
40
|
+
from . import preview
|
|
40
41
|
|
|
41
42
|
# add ArviZ's styles to matplotlib's styles
|
|
42
43
|
_arviz_style_path = os.path.join(os.path.dirname(__file__), "plots", "styles")
|
arviz/data/base.py
CHANGED
|
@@ -9,9 +9,13 @@ from copy import deepcopy
|
|
|
9
9
|
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
|
|
10
10
|
|
|
11
11
|
import numpy as np
|
|
12
|
-
import tree
|
|
13
12
|
import xarray as xr
|
|
14
13
|
|
|
14
|
+
try:
|
|
15
|
+
import tree
|
|
16
|
+
except ImportError:
|
|
17
|
+
tree = None
|
|
18
|
+
|
|
15
19
|
try:
|
|
16
20
|
import ujson as json
|
|
17
21
|
except ImportError:
|
|
@@ -89,6 +93,9 @@ def _yield_flat_up_to(shallow_tree, input_tree, path=()):
|
|
|
89
93
|
input_tree.
|
|
90
94
|
"""
|
|
91
95
|
# pylint: disable=protected-access
|
|
96
|
+
if tree is None:
|
|
97
|
+
raise ImportError("Missing optional dependency 'dm-tree'. Use pip or conda to install it")
|
|
98
|
+
|
|
92
99
|
if isinstance(shallow_tree, tree._TEXT_OR_BYTES) or not (
|
|
93
100
|
isinstance(shallow_tree, tree.collections_abc.Mapping)
|
|
94
101
|
or tree._is_namedtuple(shallow_tree)
|
|
@@ -299,7 +306,7 @@ def numpy_to_data_array(
|
|
|
299
306
|
return xr.DataArray(ary, coords=coords, dims=dims)
|
|
300
307
|
|
|
301
308
|
|
|
302
|
-
def pytree_to_dataset(
|
|
309
|
+
def dict_to_dataset(
|
|
303
310
|
data,
|
|
304
311
|
*,
|
|
305
312
|
attrs=None,
|
|
@@ -312,6 +319,8 @@ def pytree_to_dataset(
|
|
|
312
319
|
):
|
|
313
320
|
"""Convert a dictionary or pytree of numpy arrays to an xarray.Dataset.
|
|
314
321
|
|
|
322
|
+
ArviZ itself supports conversion of flat dictionaries.
|
|
323
|
+
Support for pytrees requires ``dm-tree`` which is an optional dependency.
|
|
315
324
|
See https://jax.readthedocs.io/en/latest/pytrees.html for what a pytree is, but
|
|
316
325
|
this includes at least dictionaries and tuple types.
|
|
317
326
|
|
|
@@ -386,10 +395,12 @@ def pytree_to_dataset(
|
|
|
386
395
|
"""
|
|
387
396
|
if dims is None:
|
|
388
397
|
dims = {}
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
398
|
+
|
|
399
|
+
if tree is not None:
|
|
400
|
+
try:
|
|
401
|
+
data = {k[0] if len(k) == 1 else k: v for k, v in _flatten_with_path(data)}
|
|
402
|
+
except TypeError: # probably unsortable keys -- the function will still work if
|
|
403
|
+
pass # it is an honest dictionary.
|
|
393
404
|
|
|
394
405
|
data_vars = {
|
|
395
406
|
key: numpy_to_data_array(
|
|
@@ -406,7 +417,7 @@ def pytree_to_dataset(
|
|
|
406
417
|
return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))
|
|
407
418
|
|
|
408
419
|
|
|
409
|
-
|
|
420
|
+
pytree_to_dataset = dict_to_dataset
|
|
410
421
|
|
|
411
422
|
|
|
412
423
|
def make_attrs(attrs=None, library=None):
|
arviz/data/converters.py
CHANGED
|
@@ -1,9 +1,13 @@
|
|
|
1
1
|
"""High level conversion functions."""
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
|
-
import tree
|
|
5
4
|
import xarray as xr
|
|
6
5
|
|
|
6
|
+
try:
|
|
7
|
+
from tree import is_nested
|
|
8
|
+
except ImportError:
|
|
9
|
+
is_nested = lambda obj: False
|
|
10
|
+
|
|
7
11
|
from .base import dict_to_dataset
|
|
8
12
|
from .inference_data import InferenceData
|
|
9
13
|
from .io_beanmachine import from_beanmachine
|
|
@@ -107,7 +111,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
|
|
|
107
111
|
dataset = obj.to_dataset()
|
|
108
112
|
elif isinstance(obj, dict):
|
|
109
113
|
dataset = dict_to_dataset(obj, coords=coords, dims=dims)
|
|
110
|
-
elif tree.is_nested(obj) and not isinstance(obj, (list, tuple)):
|
|
114
|
+
elif is_nested(obj) and not isinstance(obj, (list, tuple)):
|
|
111
115
|
dataset = dict_to_dataset(obj, coords=coords, dims=dims)
|
|
112
116
|
elif isinstance(obj, np.ndarray):
|
|
113
117
|
dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
|
|
@@ -122,7 +126,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
|
|
|
122
126
|
"xarray dataarray",
|
|
123
127
|
"xarray dataset",
|
|
124
128
|
"dict",
|
|
125
|
-
"pytree",
|
|
129
|
+
"pytree (if 'dm-tree' is installed)",
|
|
126
130
|
"netcdf filename",
|
|
127
131
|
"numpy array",
|
|
128
132
|
"pystan fit",
|
arviz/data/inference_data.py
CHANGED
|
@@ -266,6 +266,14 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
266
266
|
raise KeyError(key)
|
|
267
267
|
return getattr(self, key)
|
|
268
268
|
|
|
269
|
+
def __setitem__(self, key: str, value: xr.Dataset):
|
|
270
|
+
"""Set item by key and update group list accordingly."""
|
|
271
|
+
if key.startswith(WARMUP_TAG):
|
|
272
|
+
self._groups_warmup.append(key)
|
|
273
|
+
else:
|
|
274
|
+
self._groups.append(key)
|
|
275
|
+
setattr(self, key, value)
|
|
276
|
+
|
|
269
277
|
def groups(self) -> List[str]:
|
|
270
278
|
"""Return all groups present in InferenceData object."""
|
|
271
279
|
return self._groups_all
|
arviz/data/io_cmdstan.py
CHANGED
|
@@ -738,6 +738,7 @@ def _process_configuration(comments):
|
|
|
738
738
|
elif "=" in comment:
|
|
739
739
|
match_int = re.search(r"^(\S+)\s*=\s*([-+]?[0-9]+)$", comment)
|
|
740
740
|
match_float = re.search(r"^(\S+)\s*=\s*([-+]?[0-9]+\.[0-9]+)$", comment)
|
|
741
|
+
match_str_bool = re.search(r"^(\S+)\s*=\s*(true|false)$", comment)
|
|
741
742
|
match_str = re.search(r"^(\S+)\s*=\s*(\S+)$", comment)
|
|
742
743
|
match_empty = re.search(r"^(\S+)\s*=\s*$", comment)
|
|
743
744
|
if match_int:
|
|
@@ -746,6 +747,9 @@ def _process_configuration(comments):
|
|
|
746
747
|
elif match_float:
|
|
747
748
|
key, value = match_float.group(1), match_float.group(2)
|
|
748
749
|
results[key] = float(value)
|
|
750
|
+
elif match_str_bool:
|
|
751
|
+
key, value = match_str_bool.group(1), match_str_bool.group(2)
|
|
752
|
+
results[key] = int(value == "true")
|
|
749
753
|
elif match_str:
|
|
750
754
|
key, value = match_str.group(1), match_str.group(2)
|
|
751
755
|
results[key] = value
|
arviz/data/io_numpyro.py
CHANGED
|
@@ -194,7 +194,7 @@ class NumPyroConverter:
|
|
|
194
194
|
)
|
|
195
195
|
for obs_name, log_like in log_likelihood_dict.items():
|
|
196
196
|
shape = (self.nchains, self.ndraws) + log_like.shape[1:]
|
|
197
|
-
data[obs_name] = np.reshape(
|
|
197
|
+
data[obs_name] = np.reshape(np.asarray(log_like), shape)
|
|
198
198
|
return dict_to_dataset(
|
|
199
199
|
data,
|
|
200
200
|
library=self.numpyro,
|
|
arviz/plots/backends/bokeh/ecdfplot.py
CHANGED
|
@@ -13,7 +13,6 @@ def plot_ecdf(
|
|
|
13
13
|
x_bands,
|
|
14
14
|
lower,
|
|
15
15
|
higher,
|
|
16
|
-
confidence_bands,
|
|
17
16
|
plot_kwargs,
|
|
18
17
|
fill_kwargs,
|
|
19
18
|
plot_outline_kwargs,
|
|
@@ -58,7 +57,7 @@ def plot_ecdf(
|
|
|
58
57
|
plot_outline_kwargs.setdefault("color", to_hex("C0"))
|
|
59
58
|
plot_outline_kwargs.setdefault("alpha", 0.2)
|
|
60
59
|
|
|
61
|
-
if confidence_bands:
|
|
60
|
+
if x_bands is not None:
|
|
62
61
|
ax.step(x_coord, y_coord, **plot_kwargs)
|
|
63
62
|
|
|
64
63
|
if fill_band:
|
|
arviz/plots/backends/bokeh/khatplot.py
CHANGED
|
@@ -21,6 +21,7 @@ def plot_khat(
|
|
|
21
21
|
figsize,
|
|
22
22
|
xdata,
|
|
23
23
|
khats,
|
|
24
|
+
good_k,
|
|
24
25
|
kwargs,
|
|
25
26
|
threshold,
|
|
26
27
|
coord_labels,
|
|
@@ -53,7 +54,11 @@ def plot_khat(
|
|
|
53
54
|
|
|
54
55
|
if hlines_kwargs is None:
|
|
55
56
|
hlines_kwargs = {}
|
|
56
|
-
|
|
57
|
+
|
|
58
|
+
if good_k is None:
|
|
59
|
+
good_k = 0.7
|
|
60
|
+
|
|
61
|
+
hlines_kwargs.setdefault("hlines", [0, good_k, 1])
|
|
57
62
|
|
|
58
63
|
cmap = None
|
|
59
64
|
if isinstance(color, str):
|
|
@@ -75,7 +80,7 @@ def plot_khat(
|
|
|
75
80
|
rgba_c = cmap(color)
|
|
76
81
|
|
|
77
82
|
khats = khats if isinstance(khats, np.ndarray) else khats.values.flatten()
|
|
78
|
-
alphas = 0.5 + 0.2 * (khats >
|
|
83
|
+
alphas = 0.5 + 0.2 * (khats > good_k) + 0.3 * (khats > 1)
|
|
79
84
|
|
|
80
85
|
rgba_c = vectorized_to_hex(rgba_c)
|
|
81
86
|
|
|
@@ -130,7 +135,7 @@ def plot_khat(
|
|
|
130
135
|
xmax = len(khats)
|
|
131
136
|
|
|
132
137
|
if show_bins:
|
|
133
|
-
bin_edges = np.array([ymin,
|
|
138
|
+
bin_edges = np.array([ymin, good_k, 1, ymax])
|
|
134
139
|
bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)]
|
|
135
140
|
hist, _, _ = histogram(khats, bin_edges)
|
|
136
141
|
for idx, count in enumerate(hist):
|
|
arviz/plots/backends/bokeh/pairplot.py
CHANGED
|
@@ -174,12 +174,8 @@ def plot_pair(
|
|
|
174
174
|
source = ColumnDataSource(data=source_dict)
|
|
175
175
|
|
|
176
176
|
if divergences:
|
|
177
|
-
source_nondiv = CDSView(
|
|
178
|
-
|
|
179
|
-
)
|
|
180
|
-
source_div = CDSView(
|
|
181
|
-
source=source, filters=[GroupFilter(column_name=divergenve_name, group="1")]
|
|
182
|
-
)
|
|
177
|
+
source_nondiv = CDSView(filter=GroupFilter(column_name=divergenve_name, group="0"))
|
|
178
|
+
source_div = CDSView(filter=GroupFilter(column_name=divergenve_name, group="1"))
|
|
183
179
|
|
|
184
180
|
def get_width_and_height(jointplot, rotate):
|
|
185
181
|
"""Compute subplots dimensions for two or more variables."""
|
|
arviz/plots/backends/matplotlib/ecdfplot.py
CHANGED
|
@@ -13,7 +13,6 @@ def plot_ecdf(
|
|
|
13
13
|
x_bands,
|
|
14
14
|
lower,
|
|
15
15
|
higher,
|
|
16
|
-
confidence_bands,
|
|
17
16
|
plot_kwargs,
|
|
18
17
|
fill_kwargs,
|
|
19
18
|
plot_outline_kwargs,
|
|
@@ -59,7 +58,7 @@ def plot_ecdf(
|
|
|
59
58
|
|
|
60
59
|
ax.step(x_coord, y_coord, **plot_kwargs)
|
|
61
60
|
|
|
62
|
-
if confidence_bands:
|
|
61
|
+
if x_bands is not None:
|
|
63
62
|
if fill_band:
|
|
64
63
|
ax.fill_between(x_bands, lower, higher, **fill_kwargs)
|
|
65
64
|
else:
|
|
arviz/plots/backends/matplotlib/khatplot.py
CHANGED
|
@@ -20,6 +20,7 @@ def plot_khat(
|
|
|
20
20
|
figsize,
|
|
21
21
|
xdata,
|
|
22
22
|
khats,
|
|
23
|
+
good_k,
|
|
23
24
|
kwargs,
|
|
24
25
|
threshold,
|
|
25
26
|
coord_labels,
|
|
@@ -61,8 +62,11 @@ def plot_khat(
|
|
|
61
62
|
backend_kwargs.setdefault("figsize", figsize)
|
|
62
63
|
backend_kwargs["squeeze"] = True
|
|
63
64
|
|
|
65
|
+
if good_k is None:
|
|
66
|
+
good_k = 0.7
|
|
67
|
+
|
|
64
68
|
hlines_kwargs = matplotlib_kwarg_dealiaser(hlines_kwargs, "hlines")
|
|
65
|
-
hlines_kwargs.setdefault("hlines", [0,
|
|
69
|
+
hlines_kwargs.setdefault("hlines", [0, good_k, 1])
|
|
66
70
|
hlines_kwargs.setdefault("linestyle", [":", "-.", "--", "-"])
|
|
67
71
|
hlines_kwargs.setdefault("alpha", 0.7)
|
|
68
72
|
hlines_kwargs.setdefault("zorder", -1)
|
|
@@ -102,7 +106,7 @@ def plot_khat(
|
|
|
102
106
|
rgba_c = cmap(norm_fun(color))
|
|
103
107
|
|
|
104
108
|
khats = khats if isinstance(khats, np.ndarray) else khats.values.flatten()
|
|
105
|
-
alphas = 0.5 + 0.2 * (khats >
|
|
109
|
+
alphas = 0.5 + 0.2 * (khats > good_k) + 0.3 * (khats > 1)
|
|
106
110
|
rgba_c[:, 3] = alphas
|
|
107
111
|
rgba_c = vectorized_to_hex(rgba_c)
|
|
108
112
|
kwargs["c"] = rgba_c
|
|
@@ -151,7 +155,7 @@ def plot_khat(
|
|
|
151
155
|
)
|
|
152
156
|
|
|
153
157
|
if show_bins:
|
|
154
|
-
bin_edges = np.array([ymin,
|
|
158
|
+
bin_edges = np.array([ymin, good_k, 1, ymax])
|
|
155
159
|
bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)]
|
|
156
160
|
hist, _, _ = histogram(khats, bin_edges)
|
|
157
161
|
for idx, count in enumerate(hist):
|
|
arviz/plots/backends/matplotlib/traceplot.py
CHANGED
|
@@ -440,7 +440,7 @@ def plot_trace(
|
|
|
440
440
|
[], [], label="combined", **dealiase_sel_kwargs(plot_kwargs, chain_prop, -1)
|
|
441
441
|
),
|
|
442
442
|
)
|
|
443
|
-
ax.figure.axes[
|
|
443
|
+
ax.figure.axes[1].legend(handles=handles, title="chain", loc="upper right")
|
|
444
444
|
|
|
445
445
|
if axes is None:
|
|
446
446
|
axes = np.array(ax.figure.axes).reshape(-1, 2)
|
arviz/plots/bpvplot.py
CHANGED
|
@@ -80,7 +80,7 @@ def plot_bpv(
|
|
|
80
80
|
hdi_prob : float, optional
|
|
81
81
|
Probability for the highest density interval for the analytical reference distribution when
|
|
82
82
|
``kind=u_values``. Should be in the interval (0, 1]. Defaults to the
|
|
83
|
-
rcParam ``stats.hdi_prob``. See :ref:`this section <common_hdi_prob>` for usage examples.
|
|
83
|
+
rcParam ``stats.ci_prob``. See :ref:`this section <common_hdi_prob>` for usage examples.
|
|
84
84
|
color : str, optional
|
|
85
85
|
Matplotlib color
|
|
86
86
|
grid : tuple, optional
|
|
@@ -202,7 +202,7 @@ def plot_bpv(
|
|
|
202
202
|
raise TypeError("`reference` argument must be either `analytical`, `samples`, or `None`")
|
|
203
203
|
|
|
204
204
|
if hdi_prob is None:
|
|
205
|
-
hdi_prob = rcParams["stats.hdi_prob"]
|
|
205
|
+
hdi_prob = rcParams["stats.ci_prob"]
|
|
206
206
|
elif not 1 >= hdi_prob > 0:
|
|
207
207
|
raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
|
|
208
208
|
|
arviz/plots/compareplot.py
CHANGED
|
@@ -11,9 +11,9 @@ def plot_compare(
|
|
|
11
11
|
comp_df,
|
|
12
12
|
insample_dev=False,
|
|
13
13
|
plot_standard_error=True,
|
|
14
|
-
plot_ic_diff=True,
|
|
14
|
+
plot_ic_diff=False,
|
|
15
15
|
order_by_rank=True,
|
|
16
|
-
legend=True,
|
|
16
|
+
legend=False,
|
|
17
17
|
title=True,
|
|
18
18
|
figsize=None,
|
|
19
19
|
textsize=None,
|
|
@@ -45,12 +45,12 @@ def plot_compare(
|
|
|
45
45
|
penalization given by the effective number of parameters (p_loo or p_waic).
|
|
46
46
|
plot_standard_error : bool, default True
|
|
47
47
|
Plot the standard error of the ELPD.
|
|
48
|
-
plot_ic_diff : bool, default True
|
|
48
|
+
plot_ic_diff : bool, default False
|
|
49
49
|
Plot standard error of the difference in ELPD between each model
|
|
50
50
|
and the top-ranked model.
|
|
51
51
|
order_by_rank : bool, default True
|
|
52
52
|
If True ensure the best model is used as reference.
|
|
53
|
-
legend : bool, default True
|
|
53
|
+
legend : bool, default False
|
|
54
54
|
Add legend to figure.
|
|
55
55
|
figsize : (float, float), optional
|
|
56
56
|
If `None`, size is (6, num of models) inches.
|
arviz/plots/densityplot.py
CHANGED
arviz/plots/dotplot.py
CHANGED
|
@@ -67,7 +67,7 @@ def plot_dot(
|
|
|
67
67
|
The shape of the marker. Valid for matplotlib backend.
|
|
68
68
|
hdi_prob : float, optional
|
|
69
69
|
Valid only when point_interval is True. Plots HDI for chosen percentage of density.
|
|
70
|
-
Defaults to ``stats.hdi_prob`` rcParam. See :ref:`this section <common_hdi_prob>`
|
|
70
|
+
Defaults to ``stats.ci_prob`` rcParam. See :ref:`this section <common_hdi_prob>`
|
|
71
71
|
for usage examples.
|
|
72
72
|
rotated : bool, default False
|
|
73
73
|
Whether to rotate the dot plot by 90 degrees.
|
|
@@ -151,7 +151,7 @@ def plot_dot(
|
|
|
151
151
|
values.sort()
|
|
152
152
|
|
|
153
153
|
if hdi_prob is None:
|
|
154
|
-
hdi_prob = rcParams["stats.hdi_prob"]
|
|
154
|
+
hdi_prob = rcParams["stats.ci_prob"]
|
|
155
155
|
elif not 1 >= hdi_prob > 0:
|
|
156
156
|
raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
|
|
157
157
|
|