arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
|
@@ -1,87 +0,0 @@
|
|
|
1
|
-
"""Test Diagnostic methods"""
|
|
2
|
-
|
|
3
|
-
# pylint: disable=redefined-outer-name, no-member, too-many-public-methods
|
|
4
|
-
import numpy as np
|
|
5
|
-
import pytest
|
|
6
|
-
|
|
7
|
-
from ...data import load_arviz_data
|
|
8
|
-
from ...rcparams import rcParams
|
|
9
|
-
from ...stats import bfmi, mcse, rhat
|
|
10
|
-
from ...stats.diagnostics import _mc_error, ks_summary
|
|
11
|
-
from ...utils import Numba
|
|
12
|
-
from ..helpers import importorskip
|
|
13
|
-
from .test_diagnostics import data # pylint: disable=unused-import
|
|
14
|
-
|
|
15
|
-
# Skip this whole module when numba cannot be imported (per the tests of
# ``importorskip`` in test_helpers, it raises ModuleNotFoundError instead when
# ARVIZ_REQUIRE_ALL_DEPS is set); every test below compares numba vs. no numba.
importorskip("numba")

# NOTE(review): presumably makes ``load_arviz_data`` return eagerly loaded
# (in-memory) data rather than lazily loaded files — confirm in rcparams.
rcParams["data.load"] = "eager"
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
def test_numba_bfmi():
    """Check that ``bfmi`` gives the same result with numba disabled and enabled.

    Both a real posterior variable and a random 3-D array are compared, and the
    numba flag must end up equal to its value at the start of the test.
    """
    initial_flag = Numba.numba_flag
    mu_draws = load_arviz_data("centered_eight").posterior["mu"].values
    random_md = np.random.rand(100, 100, 10)

    Numba.disable_numba()
    plain_mu, plain_md = bfmi(mu_draws), bfmi(random_md)

    Numba.enable_numba()
    jit_mu, jit_md = bfmi(mu_draws), bfmi(random_md)

    assert np.allclose(plain_md, jit_md)
    assert np.allclose(jit_mu, plain_mu)
    assert initial_flag == Numba.numba_flag
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
@pytest.mark.parametrize("method", ("rank", "split", "folded", "z_scale", "identity"))
def test_numba_rhat(method):
    """Numba test for rhat.

    ``rhat`` computed with numba disabled and then enabled must agree for the
    given ``method``, and the numba flag must match its value at test start.
    """
    state = Numba.numba_flag
    school = np.random.rand(100, 100)
    Numba.disable_numba()
    non_numba = rhat(school, method=method)
    Numba.enable_numba()
    with_numba = rhat(school, method=method)
    assert np.allclose(with_numba, non_numba)
    assert Numba.numba_flag == state
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
@pytest.mark.parametrize("method", ("mean", "sd", "quantile"))
def test_numba_mcse(method, prob=None):
    """Compare ``mcse`` computed without numba and with numba for ``method``."""
    initial_flag = Numba.numba_flag
    draws = np.random.rand(100, 100)
    # Only the quantile estimator needs a probability level.
    if method == "quantile":
        prob = 0.80

    results = []
    for toggle in (Numba.disable_numba, Numba.enable_numba):
        toggle()
        results.append(mcse(draws, method=method, prob=prob))
    without_jit, with_jit = results

    assert np.allclose(with_jit, without_jit)
    assert Numba.numba_flag == initial_flag
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
def test_ks_summary_numba():
    """The ``Count`` column of ``ks_summary`` must not depend on numba being active."""
    initial_flag = Numba.numba_flag
    samples = np.random.randn(100, 100)

    def counts():
        # Extract the raw "Count" values so they can be compared numerically.
        return ks_summary(samples)["Count"].values

    Numba.disable_numba()
    plain = counts()
    Numba.enable_numba()
    jitted = counts()

    assert np.allclose(plain, jitted)
    assert Numba.numba_flag == initial_flag
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
@pytest.mark.parametrize("batches", (1, 20))
@pytest.mark.parametrize("circular", (True, False))
def test_mcse_error_numba(batches, circular):
    """``_mc_error`` must agree with numba disabled and enabled."""
    samples = np.random.randn(100, 100)
    initial_flag = Numba.numba_flag
    options = {"batches": batches, "circular": circular}

    Numba.disable_numba()
    plain = _mc_error(samples, **options)
    Numba.enable_numba()
    jitted = _mc_error(samples, **options)

    assert np.allclose(plain, jitted)
    assert initial_flag == Numba.numba_flag
|
|
@@ -1,18 +0,0 @@
|
|
|
1
|
-
import pytest
|
|
2
|
-
from _pytest.outcomes import Skipped
|
|
3
|
-
|
|
4
|
-
from ..helpers import importorskip
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def test_importorskip_local(monkeypatch):
    """Without ARVIZ_REQUIRE_ALL_DEPS (local run), a missing module must skip, not fail."""
    # Ensure the CI flag is absent regardless of the outer environment.
    monkeypatch.delenv("ARVIZ_REQUIRE_ALL_DEPS", raising=False)
    missing_module = "non-existent-function"
    with pytest.raises(Skipped):
        importorskip(missing_module)
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
def test_importorskip_ci(monkeypatch):
    """Test ``importorskip`` run on CI machine with non-existent module, which should fail.

    With ARVIZ_REQUIRE_ALL_DEPS set, a missing module must raise
    ModuleNotFoundError instead of skipping the test.
    """
    # Environment values must be strings: passing the int 1 makes pytest emit a
    # PytestWarning and convert it implicitly (an error under ``-W error``).
    monkeypatch.setenv("ARVIZ_REQUIRE_ALL_DEPS", "1")
    with pytest.raises(ModuleNotFoundError):
        importorskip("non-existent-function")
|
|
@@ -1,69 +0,0 @@
|
|
|
1
|
-
"""Tests for labeller classes."""
|
|
2
|
-
|
|
3
|
-
import pytest
|
|
4
|
-
|
|
5
|
-
from ...labels import (
|
|
6
|
-
BaseLabeller,
|
|
7
|
-
DimCoordLabeller,
|
|
8
|
-
DimIdxLabeller,
|
|
9
|
-
IdxLabeller,
|
|
10
|
-
MapLabeller,
|
|
11
|
-
NoModelLabeller,
|
|
12
|
-
NoVarLabeller,
|
|
13
|
-
)
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class Data:
    """Minimal stand-in holding label- and position-based selections for two dims."""

    def __init__(self):
        # Label-based selection: a coordinate value for each dimension.
        self.sel = dict(instrument="a", experiment=3)
        # Integer-position selection over the same two dimensions.
        self.isel = dict(instrument=0, experiment=4)
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
@pytest.fixture
def multidim_sels():
    """Provide a fresh ``Data`` selection holder to each test."""
    sels = Data()
    return sels
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
class Labellers:
    """Registry holding one default-constructed instance of each labeller class."""

    def __init__(self):
        # (lookup key, labeller class) in the order they are instantiated.
        registry = (
            ("BaseLabeller", BaseLabeller),
            ("DimCoordLabeller", DimCoordLabeller),
            ("IdxLabeller", IdxLabeller),
            ("DimIdxLabeller", DimIdxLabeller),
            ("MapLabeller", MapLabeller),
            ("NoVarLabeller", NoVarLabeller),
            ("NoModelLabeller", NoModelLabeller),
        )
        self.labellers = {key: labeller_cls() for key, labeller_cls in registry}
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
@pytest.fixture
def labellers():
    """Provide a fresh ``Labellers`` registry to each test."""
    registry = Labellers()
    return registry
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
@pytest.mark.parametrize(
    "args",
    [
        ("BaseLabeller", "theta\na, 3"),
        ("DimCoordLabeller", "theta\ninstrument: a, experiment: 3"),
        ("IdxLabeller", "theta\n0, 4"),
        ("DimIdxLabeller", "theta\ninstrument#0, experiment#4"),
        ("MapLabeller", "theta\na, 3"),
        ("NoVarLabeller", "a, 3"),
        ("NoModelLabeller", "theta\na, 3"),
    ],
)
class TestLabellers:
    """Check ``make_label_vert`` output of every labeller for a two-dim selection."""

    # pylint: disable=redefined-outer-name
    def test_make_label_vert(self, args, multidim_sels, labellers):
        labeller_name, expected = args
        labeller = labellers.labellers[labeller_name]
        produced = labeller.make_label_vert("theta", multidim_sels.sel, multidim_sels.isel)
        assert produced == expected
|
|
@@ -1,342 +0,0 @@
|
|
|
1
|
-
# pylint: disable=redefined-outer-name
|
|
2
|
-
import importlib
|
|
3
|
-
import os
|
|
4
|
-
|
|
5
|
-
import numpy as np
|
|
6
|
-
import pytest
|
|
7
|
-
import xarray as xr
|
|
8
|
-
|
|
9
|
-
from ...data import from_dict
|
|
10
|
-
from ...plots.backends.matplotlib import dealiase_sel_kwargs, matplotlib_kwarg_dealiaser
|
|
11
|
-
from ...plots.plot_utils import (
|
|
12
|
-
compute_ranks,
|
|
13
|
-
filter_plotters_list,
|
|
14
|
-
format_sig_figs,
|
|
15
|
-
get_plotting_function,
|
|
16
|
-
make_2d,
|
|
17
|
-
set_bokeh_circular_ticks_labels,
|
|
18
|
-
vectorized_to_hex,
|
|
19
|
-
)
|
|
20
|
-
from ...rcparams import rc_context
|
|
21
|
-
from ...sel_utils import xarray_sel_iter, xarray_to_ndarray
|
|
22
|
-
from ...stats.density_utils import get_bins
|
|
23
|
-
from ...utils import get_coords
|
|
24
|
-
|
|
25
|
-
# Check if Bokeh is installed. ``importlib.util`` is a submodule, so a bare
# ``import importlib`` does not guarantee that ``importlib.util`` is bound;
# import it explicitly before calling ``find_spec``.
import importlib.util  # pylint: disable=wrong-import-position

bokeh_installed = importlib.util.find_spec("bokeh") is not None  # pylint: disable=invalid-name
# Bokeh-dependent tests are skipped only when bokeh is missing AND the
# environment does not demand that every optional dependency be present.
skip_tests = (not bokeh_installed) and ("ARVIZ_REQUIRE_ALL_DEPS" not in os.environ)
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
@pytest.mark.parametrize(
    "value, default, expected",
    [
        (123.456, 2, 3),
        (-123.456, 3, 3),
        (-123.456, 4, 4),
        (12.3456, 2, 2),
        (1.23456, 2, 2),
        (0.123456, 2, 2),
    ],
)
def test_format_sig_figs(value, default, expected):
    """``format_sig_figs`` returns the expected number of significant figures."""
    sig_figs = format_sig_figs(value, default=default)
    assert sig_figs == expected
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
@pytest.fixture(scope="function")
def sample_dataset():
    """Build a tiny 2-chain x 3-draw dataset with variables ``mu`` and ``tau``.

    Returns the raw ``mu`` and ``tau`` arrays together with the
    ``xarray.Dataset`` built from them, so tests can compare conversions
    against the inputs.
    """
    mu = np.arange(1, 7).reshape(2, 3)
    tau = mu + 6  # values 7..12, same shape and dtype as mu
    dims = ["chain", "draw"]
    dataset = xr.Dataset(
        {"mu": (dims, mu), "tau": (dims, tau)},
        coords={"draw": [0, 1, 2], "chain": [0, 1]},
    )
    return mu, tau, dataset
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
def test_make_2d():
    """Exercise ``make_2d`` on a 1-D input; the result must be two-dimensional."""
    result = make_2d(np.array([2, 3, 4]))
    assert result.ndim == 2
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
def test_get_bins():
    """Exercise ``get_bins`` on data with an outlier; only checks that bins are produced."""
    bins = get_bins(np.array([1, 2, 3, 100]))
    assert bins is not None
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
def test_dataset_to_numpy_not_combined(sample_dataset):  # pylint: disable=invalid-name
    """Uncombined conversion yields one row per (variable, chain) pair."""
    mu, tau, dataset = sample_dataset
    names, array = xarray_to_ndarray(dataset, combined=False)

    # 2 variables x 2 chains -> 4 rows; the variable order is not fixed.
    assert len(names) == 4
    candidate_layouts = (np.vstack((mu, tau)), np.vstack((tau, mu)))
    assert any((array == stacked).all() for stacked in candidate_layouts)
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
def test_dataset_to_numpy_combined(sample_dataset):
    """Combined conversion flattens each variable's chains into a single row."""
    mu, tau, dataset = sample_dataset
    names, array = xarray_to_ndarray(dataset, combined=True)

    assert len(names) == 2
    for name, original in (("mu", mu), ("tau", tau)):
        assert (array[names.index(name)] == original.reshape(1, 6)).all()
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
def test_xarray_sel_iter_ordering():
    """Coordinate values must be yielded in the order in which they were provided."""
    ordered = list("dcba")
    posterior = from_dict(  # pylint: disable=no-member
        {"x": np.random.randn(1, 100, len(ordered))},
        coords={"in_order": ordered},
        dims={"x": ["in_order"]},
    ).posterior

    yielded = [selection["in_order"] for _, selection, _ in xarray_sel_iter(posterior)]
    assert yielded == ordered
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
def test_xarray_sel_iter_ordering_combined(sample_dataset):  # pylint: disable=invalid-name
    """With combined chains, every variable is yielded (a membership check, not an order check)."""
    *_, dataset = sample_dataset
    yielded = {name for name, _, _ in xarray_sel_iter(dataset, var_names=None, combined=True)}
    assert yielded == {"mu", "tau"}
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
def test_xarray_sel_iter_ordering_uncombined(sample_dataset):  # pylint: disable=invalid-name
    """Without combining, one selection per (variable, chain) pair is yielded."""
    *_, dataset = sample_dataset
    pairs = [(name, selection) for name, selection, _ in xarray_sel_iter(dataset, var_names=None)]

    allowed = [(name, {"chain": chain}) for name in ("mu", "tau") for chain in (0, 1)]
    assert len(pairs) == 4
    for pair in pairs:
        assert pair in allowed
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
def test_xarray_sel_data_array(sample_dataset):  # pylint: disable=invalid-name
    """``xarray_sel_iter`` also accepts a bare ``DataArray`` (a hard-to-reach branch).

    With combined chains, only that DataArray's own name should be yielded.
    """
    *_, dataset = sample_dataset
    names = {name for name, _, _ in xarray_sel_iter(dataset.mu, var_names=None, combined=True)}
    assert names == {"mu"}
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
class TestCoordsExceptions:
    """``get_coords`` must raise readable errors for malformed ``coords`` arguments."""

    # Accepted message tail for an unknown coordinate name: the test accepts
    # either the invalid-key message or the mapping-format message.
    _BAD_NAME_TAIL = (
        r"({'NOT_A_COORD_NAME'} are invalid coordinate keys"
        r"|should follow mapping format {coord_name:\[dim1, dim2\]})"
    )
    # Expected message when a coordinate value does not exist.
    _BAD_FORMAT = r"Coords should follow mapping format {coord_name:\[dim1, dim2\]}"

    # test coord exceptions on datasets
    def test_invalid_coord_name(self, sample_dataset):  # pylint: disable=invalid-name
        """Assert that nicer exception appears when user enters wrong coords name"""
        *_, dataset = sample_dataset
        with pytest.raises((KeyError, ValueError), match=r"Coords " + self._BAD_NAME_TAIL):
            get_coords(dataset, {"NOT_A_COORD_NAME": [1]})

    def test_invalid_coord_value(self, sample_dataset):  # pylint: disable=invalid-name
        """Assert that nicer exception appears when user enters wrong coords value"""
        *_, dataset = sample_dataset
        with pytest.raises(KeyError, match=self._BAD_FORMAT):
            get_coords(dataset, {"draw": [1234567]})

    def test_invalid_coord_structure(self, sample_dataset):  # pylint: disable=invalid-name
        """Assert that nicer exception appears when user enters wrong coords datatype"""
        *_, dataset = sample_dataset
        with pytest.raises(TypeError):
            get_coords(dataset, {"draw"})

    # test coord exceptions on dataset list
    def test_invalid_coord_name_list(self, sample_dataset):  # pylint: disable=invalid-name
        """Assert that nicer exception appears when user enters wrong coords name"""
        *_, dataset = sample_dataset
        with pytest.raises(
            (KeyError, ValueError), match=r"data\[1\]:.+Coords " + self._BAD_NAME_TAIL
        ):
            get_coords((dataset, dataset), ({"draw": [0, 1]}, {"NOT_A_COORD_NAME": [1]}))

    def test_invalid_coord_value_list(self, sample_dataset):  # pylint: disable=invalid-name
        """Assert that nicer exception appears when user enters wrong coords value"""
        *_, dataset = sample_dataset
        with pytest.raises(KeyError, match=r"data\[0\]:.+" + self._BAD_FORMAT):
            get_coords((dataset, dataset), ({"draw": [1234567]}, {"draw": [0, 1]}))
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
def test_filter_plotter_list():
    """Nothing is dropped when the plotter count is below ``plot.max_subplots``."""
    plotters = list(range(7))
    with rc_context({"plot.max_subplots": 10}):
        filtered = filter_plotters_list(plotters, "")
    assert filtered == plotters
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
def test_filter_plotter_list_warning():
    """Exceeding ``plot.max_subplots`` truncates the list and emits the given warning."""
    plotters = list(range(7))
    with rc_context({"plot.max_subplots": 5}), pytest.warns(UserWarning, match="test warning"):
        filtered = filter_plotters_list(plotters, "test warning")
    assert len(filtered) == 5
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
@pytest.mark.skipif(skip_tests, reason="test requires bokeh which is not installed")
def test_bokeh_import():
    """``get_plotting_function`` must resolve to the bokeh backend's ``plot_dist``."""
    from ...plots.backends.bokeh.distplot import plot_dist as bokeh_plot_dist

    resolved = get_plotting_function("plot_dist", "distplot", "bokeh")
    assert resolved is bokeh_plot_dist
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
@pytest.mark.parametrize(
    "params",
    [
        {
            "input": (
                {
                    "dashes": "-",
                },
                "scatter",
            ),
            # A 1-tuple, not a bare string: with a plain string, an ``in`` check
            # degrades into a substring test and would accept e.g. "line".
            "output": ("linestyle",),
        },
        {
            "input": (
                {"mfc": "blue", "c": "blue", "line_width": 2},
                "plot",
            ),
            "output": ("markerfacecolor", "color", "line_width"),
        },
        {"input": ({"ec": "blue", "fc": "black"}, "hist"), "output": ("edgecolor", "facecolor")},
        {
            "input": ({"edgecolors": "blue", "lw": 3}, "hlines"),
            "output": ("edgecolor", "linewidth"),
        },
    ],
)
def test_matplotlib_kwarg_dealiaser(params):
    """Check that aliased matplotlib kwargs are mapped to their canonical names.

    Each case pairs an input kwargs dict and plot ``kind`` with the exact set
    of keys expected after dealiasing. Keys that are not aliases (e.g.
    ``line_width`` for ``plot``) must pass through unchanged.
    """
    kwargs, kind = params["input"]
    dealiased = matplotlib_kwarg_dealiaser(kwargs, kind=kind)
    # Compare complete key sets, so an empty result or a dropped key cannot
    # pass vacuously as it could with a per-key membership loop.
    assert set(dealiased) == set(params["output"])
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
@pytest.mark.parametrize("c_values", ["#0000ff", "blue", [0, 0, 1]])
def test_vectorized_to_hex_scalar(c_values):
    """A single color given as hex, name, or RGB list converts to pure-blue hex."""
    assert vectorized_to_hex(c_values) == "#0000ff"
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
@pytest.mark.parametrize(
    "c_values", [["blue", "blue"], ["blue", "#0000ff"], np.array([[0, 0, 1], [0, 0, 1]])]
)
def test_vectorized_to_hex_array(c_values):
    """Every entry of a sequence of color specs converts to pure-blue hex."""
    converted = vectorized_to_hex(c_values)
    assert all(entry == "#0000ff" for entry in converted)
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
def test_mpl_dealiase_sel_kwargs():
    """Check mpl dealiase_sel_kwargs behaviour.

    An aliased prop (``lw``) must override the matching canonical kwarg
    (``linewidth``) with the value at the selected index. Props not present in
    kwargs are added, and kwargs not mentioned in props keep their values.
    """
    kwargs = {"linewidth": 3, "alpha": 0.4, "line_color": "red"}
    props = {"lw": [1, 2, 4, 5], "linestyle": ["-", "--", ":"]}
    res = dealiase_sel_kwargs(kwargs, props, 2)

    expected = {"linewidth": 4, "linestyle": ":", "alpha": 0.4, "line_color": "red"}
    for key, value in expected.items():
        assert key in res
        assert res[key] == value
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
@pytest.mark.skipif(skip_tests, reason="test requires bokeh which is not installed")
def test_bokeh_dealiase_sel_kwargs():
    """Check bokeh dealiase_sel_kwargs behaviour.

    Props must override existing kwargs with the value at the selected index.
    Props not present in kwargs are added, and kwargs not mentioned in props
    keep their values.
    """
    from ...plots.backends.bokeh import dealiase_sel_kwargs

    kwargs = {"line_width": 3, "line_alpha": 0.4, "line_color": "red"}
    props = {"line_width": [1, 2, 4, 5], "line_dash": ["dashed", "dashed", "dashed"]}
    res = dealiase_sel_kwargs(kwargs, props, 2)

    expected = {"line_width": 4, "line_dash": "dashed", "line_alpha": 0.4, "line_color": "red"}
    for key, value in expected.items():
        assert key in res
        assert res[key] == value
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
@pytest.mark.skipif(skip_tests, reason="test requires bokeh which is not installed")
def test_set_bokeh_circular_ticks_labels():
    """Assert the axes returned after placing ticks and tick labels for circular plots."""
    import bokeh.plotting as bkp

    tick_labels = ["0°", "45°", "90°", "135°", "180°", "225°", "270°", "315°"]
    blank_figure = bkp.figure(x_axis_type=None, y_axis_type=None)
    figure = set_bokeh_circular_ticks_labels(blank_figure, np.linspace(0, 1, 10), tick_labels)

    # Three renderers are expected. The third carries the label text, and the
    # first carries one ``start_angle`` entry per label.
    rends = figure.renderers
    assert len(rends) == 3
    assert rends[2].data_source.data["text"] == tick_labels
    assert len(rends[0].data_source.data["start_angle"]) == len(tick_labels)
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
def test_compute_ranks():
    """``compute_ranks`` ranks values jointly across all rows, for integer and float data."""
    cases = (
        (
            np.array([[5, 4, 1, 4, 0], [2, 8, 2, 1, 1]]),
            np.array([[9.0, 7.0, 3.0, 8.0, 1.0], [5.0, 10.0, 6.0, 2.0, 4.0]]),
        ),
        (
            np.array(
                [
                    [0.2644187, -1.3004813, -0.80428456, 1.01319068, 0.62631143],
                    [1.34498018, -0.13428933, -0.69855487, -0.9498981, -0.34074092],
                ]
            ),
            np.array([[7.0, 1.0, 3.0, 9.0, 8.0], [10.0, 6.0, 4.0, 2.0, 5.0]]),
        ),
    )
    for values, expected_ranks in cases:
        np.testing.assert_equal(compute_ranks(values), expected_ranks)
|