arviz 0.21.0__py3-none-any.whl → 0.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +8 -3
- arviz/data/inference_data.py +37 -19
- arviz/data/io_datatree.py +2 -2
- arviz/data/io_numpyro.py +112 -4
- arviz/plots/autocorrplot.py +12 -2
- arviz/plots/backends/bokeh/hdiplot.py +7 -6
- arviz/plots/backends/bokeh/lmplot.py +19 -3
- arviz/plots/backends/bokeh/pairplot.py +18 -48
- arviz/plots/backends/matplotlib/khatplot.py +8 -1
- arviz/plots/backends/matplotlib/lmplot.py +13 -7
- arviz/plots/backends/matplotlib/pairplot.py +14 -22
- arviz/plots/kdeplot.py +4 -4
- arviz/plots/lmplot.py +41 -14
- arviz/plots/pairplot.py +10 -3
- arviz/stats/density_utils.py +1 -1
- arviz/stats/stats.py +19 -7
- arviz/tests/base_tests/test_data.py +0 -4
- arviz/tests/base_tests/test_plots_bokeh.py +60 -2
- arviz/tests/base_tests/test_plots_matplotlib.py +77 -1
- arviz/tests/base_tests/test_stats.py +42 -1
- arviz/tests/external_tests/test_data_numpyro.py +130 -3
- arviz/wrappers/base.py +1 -1
- arviz/wrappers/wrap_stan.py +1 -1
- {arviz-0.21.0.dist-info → arviz-0.22.0.dist-info}/METADATA +7 -7
- {arviz-0.21.0.dist-info → arviz-0.22.0.dist-info}/RECORD +28 -28
- {arviz-0.21.0.dist-info → arviz-0.22.0.dist-info}/LICENSE +0 -0
- {arviz-0.21.0.dist-info → arviz-0.22.0.dist-info}/WHEEL +0 -0
- {arviz-0.21.0.dist-info → arviz-0.22.0.dist-info}/top_level.txt +0 -0
arviz/__init__.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# pylint: disable=wildcard-import,invalid-name,wrong-import-position
|
|
2
2
|
"""ArviZ is a library for exploratory analysis of Bayesian models."""
|
|
3
|
-
__version__ = "0.
|
|
3
|
+
__version__ = "0.22.0"
|
|
4
4
|
|
|
5
5
|
import logging
|
|
6
6
|
import os
|
|
@@ -8,6 +8,7 @@ import os
|
|
|
8
8
|
from matplotlib.colors import LinearSegmentedColormap
|
|
9
9
|
from matplotlib.pyplot import style
|
|
10
10
|
import matplotlib as mpl
|
|
11
|
+
from packaging import version
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class Logger(logging.Logger):
|
|
@@ -41,8 +42,12 @@ from . import preview
|
|
|
41
42
|
|
|
42
43
|
# add ArviZ's styles to matplotlib's styles
|
|
43
44
|
_arviz_style_path = os.path.join(os.path.dirname(__file__), "plots", "styles")
|
|
44
|
-
|
|
45
|
-
style.
|
|
45
|
+
if version.parse(mpl.__version__) >= version.parse("3.11.0.dev0"):
|
|
46
|
+
style.USER_LIBRARY_PATHS.append(_arviz_style_path)
|
|
47
|
+
style.reload_library()
|
|
48
|
+
else:
|
|
49
|
+
style.core.USER_LIBRARY_PATHS.append(_arviz_style_path)
|
|
50
|
+
style.core.reload_library()
|
|
46
51
|
|
|
47
52
|
|
|
48
53
|
if not logging.root.handlers:
|
arviz/data/inference_data.py
CHANGED
|
@@ -532,24 +532,27 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
532
532
|
return filename
|
|
533
533
|
|
|
534
534
|
def to_datatree(self):
|
|
535
|
-
"""Convert InferenceData object to a :class:`~
|
|
535
|
+
"""Convert InferenceData object to a :class:`~xarray.DataTree`."""
|
|
536
536
|
try:
|
|
537
|
-
from
|
|
538
|
-
except
|
|
539
|
-
raise
|
|
540
|
-
"
|
|
537
|
+
from xarray import DataTree
|
|
538
|
+
except ImportError as err:
|
|
539
|
+
raise ImportError(
|
|
540
|
+
"xarray must be have DataTree in order to use InferenceData.to_datatree. "
|
|
541
|
+
"Update to xarray>=2024.11.0"
|
|
541
542
|
) from err
|
|
542
543
|
return DataTree.from_dict({group: ds for group, ds in self.items()})
|
|
543
544
|
|
|
544
545
|
@staticmethod
|
|
545
546
|
def from_datatree(datatree):
|
|
546
|
-
"""Create an InferenceData object from a :class:`~
|
|
547
|
+
"""Create an InferenceData object from a :class:`~xarray.DataTree`.
|
|
547
548
|
|
|
548
549
|
Parameters
|
|
549
550
|
----------
|
|
550
551
|
datatree : DataTree
|
|
551
552
|
"""
|
|
552
|
-
return InferenceData(
|
|
553
|
+
return InferenceData(
|
|
554
|
+
**{group: child.to_dataset() for group, child in datatree.children.items()}
|
|
555
|
+
)
|
|
553
556
|
|
|
554
557
|
def to_dict(self, groups=None, filter_groups=None):
|
|
555
558
|
"""Convert InferenceData to a dictionary following xarray naming conventions.
|
|
@@ -797,12 +800,20 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
797
800
|
----------
|
|
798
801
|
https://zarr.readthedocs.io/
|
|
799
802
|
"""
|
|
800
|
-
try:
|
|
803
|
+
try:
|
|
801
804
|
import zarr
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
raise ImportError(
|
|
805
|
+
except ImportError as err:
|
|
806
|
+
raise ImportError("'to_zarr' method needs Zarr (>=2.5.0,<3) installed.") from err
|
|
807
|
+
if version.parse(zarr.__version__) < version.parse("2.5.0"):
|
|
808
|
+
raise ImportError(
|
|
809
|
+
"Found zarr<2.5.0, please upgrade to a zarr (>=2.5.0,<3) to use 'to_zarr'"
|
|
810
|
+
)
|
|
811
|
+
if version.parse(zarr.__version__) >= version.parse("3.0.0.dev0"):
|
|
812
|
+
raise ImportError(
|
|
813
|
+
"Found zarr>=3, which is not supported by ArviZ. Instead, you can use "
|
|
814
|
+
"'dt = InfereceData.to_datatree' followed by 'dt.to_zarr()' "
|
|
815
|
+
"(needs xarray>=2024.11.0)"
|
|
816
|
+
)
|
|
806
817
|
|
|
807
818
|
# Check store type and create store if necessary
|
|
808
819
|
if store is None:
|
|
@@ -851,10 +862,18 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
851
862
|
"""
|
|
852
863
|
try:
|
|
853
864
|
import zarr
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
raise ImportError(
|
|
865
|
+
except ImportError as err:
|
|
866
|
+
raise ImportError("'from_zarr' method needs Zarr (>=2.5.0,<3) installed.") from err
|
|
867
|
+
if version.parse(zarr.__version__) < version.parse("2.5.0"):
|
|
868
|
+
raise ImportError(
|
|
869
|
+
"Found zarr<2.5.0, please upgrade to a zarr (>=2.5.0,<3) to use 'from_zarr'"
|
|
870
|
+
)
|
|
871
|
+
if version.parse(zarr.__version__) >= version.parse("3.0.0.dev0"):
|
|
872
|
+
raise ImportError(
|
|
873
|
+
"Found zarr>=3, which is not supported by ArviZ. Instead, you can use "
|
|
874
|
+
"'xarray.open_datatree' followed by 'arviz.InferenceData.from_datatree' "
|
|
875
|
+
"(needs xarray>=2024.11.0)"
|
|
876
|
+
)
|
|
858
877
|
|
|
859
878
|
# Check store type and create store if necessary
|
|
860
879
|
if isinstance(store, str):
|
|
@@ -1531,9 +1550,8 @@ class InferenceData(Mapping[str, xr.Dataset]):
|
|
|
1531
1550
|
import xarray as xr
|
|
1532
1551
|
from xarray_einstats.stats import XrDiscreteRV
|
|
1533
1552
|
from scipy.stats import poisson
|
|
1534
|
-
dist = XrDiscreteRV(poisson)
|
|
1535
|
-
log_lik =
|
|
1536
|
-
log_lik["home_points"] = dist.logpmf(obs["home_points"], np.exp(post["atts"]))
|
|
1553
|
+
dist = XrDiscreteRV(poisson, np.exp(post["atts"]))
|
|
1554
|
+
log_lik = dist.logpmf(obs["home_points"]).to_dataset(name="home_points")
|
|
1537
1555
|
idata2.add_groups({"log_likelihood": log_lik})
|
|
1538
1556
|
idata2
|
|
1539
1557
|
|
arviz/data/io_datatree.py
CHANGED
|
@@ -4,7 +4,7 @@ from .inference_data import InferenceData
|
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
def to_datatree(data):
|
|
7
|
-
"""Convert InferenceData object to a :class:`~
|
|
7
|
+
"""Convert InferenceData object to a :class:`~xarray.DataTree`.
|
|
8
8
|
|
|
9
9
|
Parameters
|
|
10
10
|
----------
|
|
@@ -14,7 +14,7 @@ def to_datatree(data):
|
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
def from_datatree(datatree):
|
|
17
|
-
"""Create an InferenceData object from a :class:`~
|
|
17
|
+
"""Create an InferenceData object from a :class:`~xarray.DataTree`.
|
|
18
18
|
|
|
19
19
|
Parameters
|
|
20
20
|
----------
|
arviz/data/io_numpyro.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
"""NumPyro-specific conversion code."""
|
|
2
2
|
|
|
3
|
+
from collections import defaultdict
|
|
3
4
|
import logging
|
|
4
|
-
from typing import Callable, Optional
|
|
5
|
+
from typing import Any, Callable, Optional, Dict, List, Tuple
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
|
|
@@ -13,6 +14,70 @@ from .inference_data import InferenceData
|
|
|
13
14
|
_log = logging.getLogger(__name__)
|
|
14
15
|
|
|
15
16
|
|
|
17
|
+
def _add_dims(dims_a: Dict[str, List[str]], dims_b: Dict[str, List[str]]) -> Dict[str, List[str]]:
|
|
18
|
+
merged = defaultdict(list)
|
|
19
|
+
|
|
20
|
+
for k, v in dims_a.items():
|
|
21
|
+
merged[k].extend(v)
|
|
22
|
+
|
|
23
|
+
for k, v in dims_b.items():
|
|
24
|
+
merged[k].extend(v)
|
|
25
|
+
|
|
26
|
+
# Convert back to a regular dict
|
|
27
|
+
return dict(merged)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def infer_dims(
    model: Callable,
    model_args: Optional[Tuple[Any, ...]] = None,
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, List[str]]:
    """Infer dimension names for the sample and deterministic sites of a NumPyro model.

    The model is traced once under ``handlers.seed(model, 0)``, with values substituted via
    ``init_to_sample``. Tracing happens inside :func:`jax.eval_shape`, so no array values are
    actually computed.

    For each ``"sample"`` or ``"deterministic"`` site, two sets of names are collected:

    * Batch dimension names come from the frames in the site's ``cond_indep_stack``, sorted
      by ``frame.dim``.
    * Event dimension names come from an optional ``infer={"event_dims": [...]}`` entry
      given to ``numpyro.sample``.

    Parameters
    ----------
    model : callable
        NumPyro model function.
    model_args : tuple, optional
        Positional arguments passed to ``model``. ``None`` means no positional arguments.
    model_kwargs : dict, optional
        Keyword arguments passed to ``model``. ``None`` means no keyword arguments.

    Returns
    -------
    dict[str, list[str]]
        Map from site name to batch dim names followed by event dim names. Sites of other
        types, and sites with neither batch nor event dim names, are omitted.
    """
    from numpyro import handlers, distributions as dist
    from numpyro.ops.pytree import PytreeTrace
    from numpyro.infer.initialization import init_to_sample
    import jax

    model_args = () if model_args is None else model_args
    # Each argument is normalised against its *own* None check. The earlier version tested
    # ``model_args`` here. That variable had already been replaced by ``()`` above, so the
    # condition was always false and the default ``model_kwargs=None`` reached
    # ``**model_kwargs`` below as ``**None``, raising TypeError.
    model_kwargs = {} if model_kwargs is None else model_kwargs

    def _get_dist_name(fn):
        # Unwrap Independent / Expanded / Masked wrappers so the recorded name is that of
        # the underlying ``base_dist``.
        if isinstance(fn, (dist.Independent, dist.ExpandedDistribution, dist.MaskedDistribution)):
            return _get_dist_name(fn.base_dist)
        return type(fn).__name__

    def get_trace():
        # We use `init_to_sample` to get around ImproperUniform distribution,
        # which does not have `sample` method.
        subs_model = handlers.substitute(
            handlers.seed(model, 0),
            substitute_fn=init_to_sample,
        )
        trace = handlers.trace(subs_model).get_trace(*model_args, **model_kwargs)
        # Work around an issue where jax.eval_shape does not work
        # for distribution output (e.g. the function `lambda: dist.Normal(0, 1)`)
        # Here we will remove `fn` and store its name in the trace.
        for site in trace.values():
            if site["type"] == "sample":
                site["fn_name"] = _get_dist_name(site.pop("fn"))
            elif site["type"] == "deterministic":
                site["fn_name"] = "Deterministic"
        return PytreeTrace(trace)

    # We use eval_shape to avoid any array computation.
    trace = jax.eval_shape(get_trace).trace

    named_dims = {}
    for name, site in trace.items():
        # Filter on site type first, so fields are only read from the sites we name.
        if site["type"] not in ("sample", "deterministic"):
            continue
        # Frames are ordered by ascending ``dim``. NumPyro plate dims are presumably
        # negative, which would put the outermost plate first; confirm against numpyro.plate.
        frames = sorted(site["cond_indep_stack"], key=lambda frame: frame.dim)
        batch_dims = [frame.name for frame in frames]
        # Tolerate a missing or None ``infer`` entry.
        event_dims = list((site.get("infer") or {}).get("event_dims", []))
        if batch_dims or event_dims:
            named_dims[name] = batch_dims + event_dims

    return named_dims
|
|
79
|
+
|
|
80
|
+
|
|
16
81
|
class NumPyroConverter:
|
|
17
82
|
"""Encapsulate NumPyro specific logic."""
|
|
18
83
|
|
|
@@ -36,6 +101,7 @@ class NumPyroConverter:
|
|
|
36
101
|
coords=None,
|
|
37
102
|
dims=None,
|
|
38
103
|
pred_dims=None,
|
|
104
|
+
extra_event_dims=None,
|
|
39
105
|
num_chains=1,
|
|
40
106
|
):
|
|
41
107
|
"""Convert NumPyro data into an InferenceData object.
|
|
@@ -58,9 +124,12 @@ class NumPyroConverter:
|
|
|
58
124
|
coords : dict[str] -> list[str]
|
|
59
125
|
Map of dimensions to coordinates
|
|
60
126
|
dims : dict[str] -> list[str]
|
|
61
|
-
Map variable names to their coordinates
|
|
127
|
+
Map variable names to their coordinates. Will be inferred if they are not provided.
|
|
62
128
|
pred_dims: dict
|
|
63
129
|
Dims for predictions data. Map variable names to their coordinates.
|
|
130
|
+
extra_event_dims: dict
|
|
131
|
+
Extra event dims for deterministic sites. Maps event dims that couldnt be inferred to
|
|
132
|
+
their coordinates.
|
|
64
133
|
num_chains: int
|
|
65
134
|
Number of chains used for sampling. Ignored if posterior is present.
|
|
66
135
|
"""
|
|
@@ -80,6 +149,7 @@ class NumPyroConverter:
|
|
|
80
149
|
self.coords = coords
|
|
81
150
|
self.dims = dims
|
|
82
151
|
self.pred_dims = pred_dims
|
|
152
|
+
self.extra_event_dims = extra_event_dims
|
|
83
153
|
self.numpyro = numpyro
|
|
84
154
|
|
|
85
155
|
def arbitrary_element(dct):
|
|
@@ -107,6 +177,10 @@ class NumPyroConverter:
|
|
|
107
177
|
# model arguments and keyword arguments
|
|
108
178
|
self._args = self.posterior._args # pylint: disable=protected-access
|
|
109
179
|
self._kwargs = self.posterior._kwargs # pylint: disable=protected-access
|
|
180
|
+
self.dims = self.dims if self.dims is not None else self.infer_dims()
|
|
181
|
+
self.pred_dims = (
|
|
182
|
+
self.pred_dims if self.pred_dims is not None else self.infer_pred_dims()
|
|
183
|
+
)
|
|
110
184
|
else:
|
|
111
185
|
self.nchains = num_chains
|
|
112
186
|
get_from = None
|
|
@@ -325,6 +399,23 @@ class NumPyroConverter:
|
|
|
325
399
|
}
|
|
326
400
|
)
|
|
327
401
|
|
|
402
|
+
@requires("posterior")
|
|
403
|
+
@requires("model")
|
|
404
|
+
def infer_dims(self) -> Dict[str, List[str]]:
|
|
405
|
+
dims = infer_dims(self.model, self._args, self._kwargs)
|
|
406
|
+
if self.extra_event_dims:
|
|
407
|
+
dims = _add_dims(dims, self.extra_event_dims)
|
|
408
|
+
return dims
|
|
409
|
+
|
|
410
|
+
@requires("posterior")
|
|
411
|
+
@requires("model")
|
|
412
|
+
@requires("predictions")
|
|
413
|
+
def infer_pred_dims(self) -> Dict[str, List[str]]:
|
|
414
|
+
dims = infer_dims(self.model, self._args, self._kwargs)
|
|
415
|
+
if self.extra_event_dims:
|
|
416
|
+
dims = _add_dims(dims, self.extra_event_dims)
|
|
417
|
+
return dims
|
|
418
|
+
|
|
328
419
|
|
|
329
420
|
def from_numpyro(
|
|
330
421
|
posterior=None,
|
|
@@ -339,10 +430,25 @@ def from_numpyro(
|
|
|
339
430
|
coords=None,
|
|
340
431
|
dims=None,
|
|
341
432
|
pred_dims=None,
|
|
433
|
+
extra_event_dims=None,
|
|
342
434
|
num_chains=1,
|
|
343
435
|
):
|
|
344
436
|
"""Convert NumPyro data into an InferenceData object.
|
|
345
437
|
|
|
438
|
+
If no dims are provided, this will infer batch dim names from NumPyro model plates.
|
|
439
|
+
For event dim names, such as with the ZeroSumNormal, `infer={"event_dims":dim_names}`
|
|
440
|
+
can be provided in numpyro.sample, i.e.::
|
|
441
|
+
|
|
442
|
+
# equivalent to dims entry, {"gamma": ["groups"]}
|
|
443
|
+
gamma = numpyro.sample(
|
|
444
|
+
"gamma",
|
|
445
|
+
dist.ZeroSumNormal(1, event_shape=(n_groups,)),
|
|
446
|
+
infer={"event_dims":["groups"]}
|
|
447
|
+
)
|
|
448
|
+
|
|
449
|
+
There is also an additional `extra_event_dims` input to cover any edge cases, for instance
|
|
450
|
+
deterministic sites with event dims (which dont have an `infer` argument to provide metadata).
|
|
451
|
+
|
|
346
452
|
For a usage example read the
|
|
347
453
|
:ref:`Creating InferenceData section on from_numpyro <creating_InferenceData>`
|
|
348
454
|
|
|
@@ -364,9 +470,10 @@ def from_numpyro(
|
|
|
364
470
|
coords : dict[str] -> list[str]
|
|
365
471
|
Map of dimensions to coordinates
|
|
366
472
|
dims : dict[str] -> list[str]
|
|
367
|
-
Map variable names to their coordinates
|
|
473
|
+
Map variable names to their coordinates. Will be inferred if they are not provided.
|
|
368
474
|
pred_dims: dict
|
|
369
|
-
Dims for predictions data. Map variable names to their coordinates.
|
|
475
|
+
Dims for predictions data. Map variable names to their coordinates. Default behavior is to
|
|
476
|
+
infer dims if this is not provided
|
|
370
477
|
num_chains: int
|
|
371
478
|
Number of chains used for sampling. Ignored if posterior is present.
|
|
372
479
|
"""
|
|
@@ -382,5 +489,6 @@ def from_numpyro(
|
|
|
382
489
|
coords=coords,
|
|
383
490
|
dims=dims,
|
|
384
491
|
pred_dims=pred_dims,
|
|
492
|
+
extra_event_dims=extra_event_dims,
|
|
385
493
|
num_chains=num_chains,
|
|
386
494
|
).to_inference_data()
|
arviz/plots/autocorrplot.py
CHANGED
|
@@ -4,7 +4,7 @@ from ..data import convert_to_dataset
|
|
|
4
4
|
from ..labels import BaseLabeller
|
|
5
5
|
from ..sel_utils import xarray_var_iter
|
|
6
6
|
from ..rcparams import rcParams
|
|
7
|
-
from ..utils import _var_names
|
|
7
|
+
from ..utils import _var_names, get_coords
|
|
8
8
|
from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
|
|
9
9
|
|
|
10
10
|
|
|
@@ -14,6 +14,7 @@ def plot_autocorr(
|
|
|
14
14
|
filter_vars=None,
|
|
15
15
|
max_lag=None,
|
|
16
16
|
combined=False,
|
|
17
|
+
coords=None,
|
|
17
18
|
grid=None,
|
|
18
19
|
figsize=None,
|
|
19
20
|
textsize=None,
|
|
@@ -42,6 +43,8 @@ def plot_autocorr(
|
|
|
42
43
|
interpret `var_names` as substrings of the real variables names. If "regex",
|
|
43
44
|
interpret `var_names` as regular expressions on the real variables names. See
|
|
44
45
|
:ref:`this section <common_filter_vars>` for usage examples.
|
|
46
|
+
coords: mapping, optional
|
|
47
|
+
Coordinates of var_names to be plotted. Passed to :meth:`xarray.Dataset.sel`
|
|
45
48
|
max_lag : int, optional
|
|
46
49
|
Maximum lag to calculate autocorrelation. By Default, the plot displays the
|
|
47
50
|
first 100 lag or the total number of draws, whichever is smaller.
|
|
@@ -124,11 +127,18 @@ def plot_autocorr(
|
|
|
124
127
|
if max_lag is None:
|
|
125
128
|
max_lag = min(100, data["draw"].shape[0])
|
|
126
129
|
|
|
130
|
+
if coords is None:
|
|
131
|
+
coords = {}
|
|
132
|
+
|
|
127
133
|
if labeller is None:
|
|
128
134
|
labeller = BaseLabeller()
|
|
129
135
|
|
|
130
136
|
plotters = filter_plotters_list(
|
|
131
|
-
list(
|
|
137
|
+
list(
|
|
138
|
+
xarray_var_iter(
|
|
139
|
+
get_coords(data, coords), var_names, combined, dim_order=["chain", "draw"]
|
|
140
|
+
)
|
|
141
|
+
),
|
|
132
142
|
"plot_autocorr",
|
|
133
143
|
)
|
|
134
144
|
rows, cols = default_grid(len(plotters), grid=grid)
|
|
@@ -21,9 +21,13 @@ def plot_hdi(ax, x_data, y_data, color, figsize, plot_kwargs, fill_kwargs, backe
|
|
|
21
21
|
plot_kwargs["color"] = vectorized_to_hex(plot_kwargs.get("color", color))
|
|
22
22
|
plot_kwargs.setdefault("alpha", 0)
|
|
23
23
|
|
|
24
|
-
fill_kwargs = {} if fill_kwargs is None else fill_kwargs
|
|
25
|
-
|
|
26
|
-
fill_kwargs
|
|
24
|
+
fill_kwargs = {} if fill_kwargs is None else fill_kwargs.copy()
|
|
25
|
+
# Convert matplotlib color to bokeh fill_color if needed
|
|
26
|
+
if "color" in fill_kwargs and "fill_color" not in fill_kwargs:
|
|
27
|
+
fill_kwargs["fill_color"] = vectorized_to_hex(fill_kwargs.pop("color"))
|
|
28
|
+
else:
|
|
29
|
+
fill_kwargs["fill_color"] = vectorized_to_hex(fill_kwargs.get("fill_color", color))
|
|
30
|
+
fill_kwargs.setdefault("fill_alpha", fill_kwargs.pop("alpha", 0.5))
|
|
27
31
|
|
|
28
32
|
figsize, *_ = _scale_fig_size(figsize, None)
|
|
29
33
|
|
|
@@ -38,9 +42,6 @@ def plot_hdi(ax, x_data, y_data, color, figsize, plot_kwargs, fill_kwargs, backe
|
|
|
38
42
|
plot_kwargs.setdefault("line_color", plot_kwargs.pop("color"))
|
|
39
43
|
plot_kwargs.setdefault("line_alpha", plot_kwargs.pop("alpha", 0))
|
|
40
44
|
|
|
41
|
-
fill_kwargs.setdefault("fill_color", fill_kwargs.pop("color"))
|
|
42
|
-
fill_kwargs.setdefault("fill_alpha", fill_kwargs.pop("alpha", 0))
|
|
43
|
-
|
|
44
45
|
ax.patch(
|
|
45
46
|
np.concatenate((x_data, x_data[::-1])),
|
|
46
47
|
np.concatenate((y_data[:, 0], y_data[:, 1][::-1])),
|
|
@@ -68,7 +68,13 @@ def plot_lm(
|
|
|
68
68
|
|
|
69
69
|
if y_hat_fill_kwargs is None:
|
|
70
70
|
y_hat_fill_kwargs = {}
|
|
71
|
-
|
|
71
|
+
else:
|
|
72
|
+
y_hat_fill_kwargs = y_hat_fill_kwargs.copy()
|
|
73
|
+
# Convert matplotlib color to bokeh fill_color if needed
|
|
74
|
+
if "color" in y_hat_fill_kwargs and "fill_color" not in y_hat_fill_kwargs:
|
|
75
|
+
y_hat_fill_kwargs["fill_color"] = y_hat_fill_kwargs.pop("color")
|
|
76
|
+
y_hat_fill_kwargs.setdefault("fill_color", "orange")
|
|
77
|
+
y_hat_fill_kwargs.setdefault("fill_alpha", 0.5)
|
|
72
78
|
|
|
73
79
|
if y_model_plot_kwargs is None:
|
|
74
80
|
y_model_plot_kwargs = {}
|
|
@@ -78,8 +84,13 @@ def plot_lm(
|
|
|
78
84
|
|
|
79
85
|
if y_model_fill_kwargs is None:
|
|
80
86
|
y_model_fill_kwargs = {}
|
|
81
|
-
|
|
82
|
-
|
|
87
|
+
else:
|
|
88
|
+
y_model_fill_kwargs = y_model_fill_kwargs.copy()
|
|
89
|
+
# Convert matplotlib color to bokeh fill_color if needed
|
|
90
|
+
if "color" in y_model_fill_kwargs and "fill_color" not in y_model_fill_kwargs:
|
|
91
|
+
y_model_fill_kwargs["fill_color"] = y_model_fill_kwargs.pop("color")
|
|
92
|
+
y_model_fill_kwargs.setdefault("fill_color", "black")
|
|
93
|
+
y_model_fill_kwargs.setdefault("fill_alpha", 0.5)
|
|
83
94
|
|
|
84
95
|
if y_model_mean_kwargs is None:
|
|
85
96
|
y_model_mean_kwargs = {}
|
|
@@ -149,6 +160,11 @@ def plot_lm(
|
|
|
149
160
|
)
|
|
150
161
|
|
|
151
162
|
y_model_mean = np.mean(y_model_plotters, axis=(0, 1))
|
|
163
|
+
# Plot mean line across all x values instead of just edges
|
|
164
|
+
mean_legend = ax_i.line(x_plotters, y_model_mean, **y_model_mean_kwargs)
|
|
165
|
+
legend_it.append(("Mean", [mean_legend]))
|
|
166
|
+
continue # Skip the edge plotting since we plotted full line
|
|
167
|
+
|
|
152
168
|
x_plotters_edge = [min(x_plotters), max(x_plotters)]
|
|
153
169
|
y_model_mean_edge = [min(y_model_mean), max(y_model_mean)]
|
|
154
170
|
mean_legend = ax_i.line(x_plotters_edge, y_model_mean_edge, **y_model_mean_kwargs)
|
|
@@ -37,6 +37,8 @@ def plot_pair(
|
|
|
37
37
|
diverging_mask,
|
|
38
38
|
divergences_kwargs,
|
|
39
39
|
flat_var_names,
|
|
40
|
+
flat_ref_slices,
|
|
41
|
+
flat_var_labels,
|
|
40
42
|
backend_kwargs,
|
|
41
43
|
marginal_kwargs,
|
|
42
44
|
show,
|
|
@@ -72,50 +74,12 @@ def plot_pair(
|
|
|
72
74
|
kde_kwargs["contour_kwargs"].setdefault("line_alpha", 1)
|
|
73
75
|
|
|
74
76
|
if reference_values:
|
|
75
|
-
|
|
76
|
-
label = []
|
|
77
|
-
for variable in list(reference_values.keys()):
|
|
78
|
-
if " " in variable:
|
|
79
|
-
variable_copy = variable.replace(" ", "\n", 1)
|
|
80
|
-
else:
|
|
81
|
-
variable_copy = variable
|
|
82
|
-
|
|
83
|
-
label.append(variable_copy)
|
|
84
|
-
reference_values_copy[variable_copy] = reference_values[variable]
|
|
85
|
-
|
|
86
|
-
difference = set(flat_var_names).difference(set(label))
|
|
87
|
-
|
|
88
|
-
if difference:
|
|
89
|
-
warn = [diff.replace("\n", " ", 1) for diff in difference]
|
|
90
|
-
warnings.warn(
|
|
91
|
-
"Argument reference_values does not include reference value for: {}".format(
|
|
92
|
-
", ".join(warn)
|
|
93
|
-
),
|
|
94
|
-
UserWarning,
|
|
95
|
-
)
|
|
96
|
-
|
|
97
|
-
if reference_values:
|
|
98
|
-
reference_values_copy = {}
|
|
99
|
-
label = []
|
|
100
|
-
for variable in list(reference_values.keys()):
|
|
101
|
-
if " " in variable:
|
|
102
|
-
variable_copy = variable.replace(" ", "\n", 1)
|
|
103
|
-
else:
|
|
104
|
-
variable_copy = variable
|
|
105
|
-
|
|
106
|
-
label.append(variable_copy)
|
|
107
|
-
reference_values_copy[variable_copy] = reference_values[variable]
|
|
108
|
-
|
|
109
|
-
difference = set(flat_var_names).difference(set(label))
|
|
110
|
-
|
|
111
|
-
for dif in difference:
|
|
112
|
-
reference_values_copy[dif] = None
|
|
77
|
+
difference = set(flat_var_names).difference(set(reference_values.keys()))
|
|
113
78
|
|
|
114
79
|
if difference:
|
|
115
|
-
warn = [dif.replace("\n", " ", 1) for dif in difference]
|
|
116
80
|
warnings.warn(
|
|
117
81
|
"Argument reference_values does not include reference value for: {}".format(
|
|
118
|
-
", ".join(
|
|
82
|
+
", ".join(difference)
|
|
119
83
|
),
|
|
120
84
|
UserWarning,
|
|
121
85
|
)
|
|
@@ -262,8 +226,8 @@ def plot_pair(
|
|
|
262
226
|
**marginal_kwargs,
|
|
263
227
|
)
|
|
264
228
|
|
|
265
|
-
ax[j, i].xaxis.axis_label =
|
|
266
|
-
ax[j, i].yaxis.axis_label =
|
|
229
|
+
ax[j, i].xaxis.axis_label = flat_var_labels[i]
|
|
230
|
+
ax[j, i].yaxis.axis_label = flat_var_labels[j + marginals_offset]
|
|
267
231
|
|
|
268
232
|
elif j + marginals_offset > i:
|
|
269
233
|
if "scatter" in kind:
|
|
@@ -346,12 +310,18 @@ def plot_pair(
|
|
|
346
310
|
ax[-1, -1].add_layout(ax_pe_hline)
|
|
347
311
|
|
|
348
312
|
if reference_values:
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
if
|
|
352
|
-
ax[j, i].scatter(
|
|
353
|
-
|
|
354
|
-
|
|
313
|
+
x_name = flat_var_names[j + marginals_offset]
|
|
314
|
+
y_name = flat_var_names[i]
|
|
315
|
+
if (x_name not in difference) and (y_name not in difference):
|
|
316
|
+
ax[j, i].scatter(
|
|
317
|
+
np.array(reference_values[y_name])[flat_ref_slices[i]],
|
|
318
|
+
np.array(reference_values[x_name])[
|
|
319
|
+
flat_ref_slices[j + marginals_offset]
|
|
320
|
+
],
|
|
321
|
+
**reference_values_kwargs,
|
|
322
|
+
)
|
|
323
|
+
ax[j, i].xaxis.axis_label = flat_var_labels[i]
|
|
324
|
+
ax[j, i].yaxis.axis_label = flat_var_labels[j + marginals_offset]
|
|
355
325
|
|
|
356
326
|
show_layout(ax, show)
|
|
357
327
|
|
|
@@ -7,6 +7,7 @@ from matplotlib import cm
|
|
|
7
7
|
import matplotlib.pyplot as plt
|
|
8
8
|
import numpy as np
|
|
9
9
|
from matplotlib.colors import to_rgba_array
|
|
10
|
+
from packaging import version
|
|
10
11
|
|
|
11
12
|
from ....stats.density_utils import histogram
|
|
12
13
|
from ...plot_utils import _scale_fig_size, color_from_dim, set_xticklabels, vectorized_to_hex
|
|
@@ -39,7 +40,13 @@ def plot_khat(
|
|
|
39
40
|
show,
|
|
40
41
|
):
|
|
41
42
|
"""Matplotlib khat plot."""
|
|
42
|
-
if
|
|
43
|
+
if version.parse(mpl.__version__) >= version.parse("3.9.0.dev0"):
|
|
44
|
+
interactive_backends = mpl.backends.backend_registry.list_builtin(
|
|
45
|
+
mpl.backends.BackendFilter.INTERACTIVE
|
|
46
|
+
)
|
|
47
|
+
else:
|
|
48
|
+
interactive_backends = mpl.rcsetup.interactive_bk
|
|
49
|
+
if hover_label and mpl.get_backend() not in interactive_backends:
|
|
43
50
|
hover_label = False
|
|
44
51
|
warnings.warn(
|
|
45
52
|
"hover labels are only available with interactive backends. To switch to an "
|
|
@@ -115,12 +115,18 @@ def plot_lm(
|
|
|
115
115
|
|
|
116
116
|
if y_model is not None:
|
|
117
117
|
_, _, _, y_model_plotters = y_model[i]
|
|
118
|
+
|
|
118
119
|
if kind_model == "lines":
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
120
|
+
# y_model_plotters should be (points, samples)
|
|
121
|
+
y_points = y_model_plotters.shape[0]
|
|
122
|
+
if x_plotters.shape[0] == y_points:
|
|
123
|
+
for j in range(num_samples):
|
|
124
|
+
ax_i.plot(x_plotters, y_model_plotters[:, j], **y_model_plot_kwargs)
|
|
125
|
+
|
|
126
|
+
ax_i.plot([], **y_model_plot_kwargs, label="Uncertainty in mean")
|
|
127
|
+
y_model_mean = np.mean(y_model_plotters, axis=1)
|
|
128
|
+
ax_i.plot(x_plotters, y_model_mean, **y_model_mean_kwargs, label="Mean")
|
|
122
129
|
|
|
123
|
-
y_model_mean = np.mean(y_model_plotters, axis=1)
|
|
124
130
|
else:
|
|
125
131
|
plot_hdi(
|
|
126
132
|
x_plotters,
|
|
@@ -128,10 +134,10 @@ def plot_lm(
|
|
|
128
134
|
fill_kwargs=y_model_fill_kwargs,
|
|
129
135
|
ax=ax_i,
|
|
130
136
|
)
|
|
131
|
-
ax_i.plot([], color=y_model_fill_kwargs["color"], label="Uncertainty in mean")
|
|
132
137
|
|
|
133
|
-
|
|
134
|
-
|
|
138
|
+
ax_i.plot([], color=y_model_fill_kwargs["color"], label="Uncertainty in mean")
|
|
139
|
+
y_model_mean = np.mean(y_model_plotters, axis=0)
|
|
140
|
+
ax_i.plot(x_plotters, y_model_mean, **y_model_mean_kwargs, label="Mean")
|
|
135
141
|
|
|
136
142
|
if legend:
|
|
137
143
|
ax_i.legend(fontsize=xt_labelsize, loc="upper left")
|