arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -357
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.1.dist-info/METADATA +0 -263
  184. arviz-0.23.1.dist-info/RECORD +0 -183
  185. arviz-0.23.1.dist-info/top_level.txt +0 -1
arviz/tests/base_tests/test_diagnostics_numba.py
@@ -1,87 +0,0 @@
- """Test Diagnostic methods"""
-
- # pylint: disable=redefined-outer-name, no-member, too-many-public-methods
- import numpy as np
- import pytest
-
- from ...data import load_arviz_data
- from ...rcparams import rcParams
- from ...stats import bfmi, mcse, rhat
- from ...stats.diagnostics import _mc_error, ks_summary
- from ...utils import Numba
- from ..helpers import importorskip
- from .test_diagnostics import data  # pylint: disable=unused-import
-
- importorskip("numba")
-
- rcParams["data.load"] = "eager"
-
-
- def test_numba_bfmi():
-     """Numba test for bfmi."""
-     state = Numba.numba_flag
-     school = load_arviz_data("centered_eight")
-     data_md = np.random.rand(100, 100, 10)
-     Numba.disable_numba()
-     non_numba = bfmi(school.posterior["mu"].values)
-     non_numba_md = bfmi(data_md)
-     Numba.enable_numba()
-     with_numba = bfmi(school.posterior["mu"].values)
-     with_numba_md = bfmi(data_md)
-     assert np.allclose(non_numba_md, with_numba_md)
-     assert np.allclose(with_numba, non_numba)
-     assert state == Numba.numba_flag
-
-
- @pytest.mark.parametrize("method", ("rank", "split", "folded", "z_scale", "identity"))
- def test_numba_rhat(method):
-     """Numba test for mcse."""
-     state = Numba.numba_flag
-     school = np.random.rand(100, 100)
-     Numba.disable_numba()
-     non_numba = rhat(school, method=method)
-     Numba.enable_numba()
-     with_numba = rhat(school, method=method)
-     assert np.allclose(with_numba, non_numba)
-     assert Numba.numba_flag == state
-
-
- @pytest.mark.parametrize("method", ("mean", "sd", "quantile"))
- def test_numba_mcse(method, prob=None):
-     """Numba test for mcse."""
-     state = Numba.numba_flag
-     school = np.random.rand(100, 100)
-     if method == "quantile":
-         prob = 0.80
-     Numba.disable_numba()
-     non_numba = mcse(school, method=method, prob=prob)
-     Numba.enable_numba()
-     with_numba = mcse(school, method=method, prob=prob)
-     assert np.allclose(with_numba, non_numba)
-     assert Numba.numba_flag == state
-
-
- def test_ks_summary_numba():
-     """Numba test for ks_summary."""
-     state = Numba.numba_flag
-     data = np.random.randn(100, 100)
-     Numba.disable_numba()
-     non_numba = (ks_summary(data)["Count"]).values
-     Numba.enable_numba()
-     with_numba = (ks_summary(data)["Count"]).values
-     assert np.allclose(non_numba, with_numba)
-     assert Numba.numba_flag == state
-
-
- @pytest.mark.parametrize("batches", (1, 20))
- @pytest.mark.parametrize("circular", (True, False))
- def test_mcse_error_numba(batches, circular):
-     """Numba test for mcse_error."""
-     data = np.random.randn(100, 100)
-     state = Numba.numba_flag
-     Numba.disable_numba()
-     non_numba = _mc_error(data, batches=batches, circular=circular)
-     Numba.enable_numba()
-     with_numba = _mc_error(data, batches=batches, circular=circular)
-     assert np.allclose(non_numba, with_numba)
-     assert state == Numba.numba_flag
arviz/tests/base_tests/test_helpers.py
@@ -1,18 +0,0 @@
- import pytest
- from _pytest.outcomes import Skipped
-
- from ..helpers import importorskip
-
-
- def test_importorskip_local(monkeypatch):
-     """Test ``importorskip`` run on local machine with non-existent module, which should skip."""
-     monkeypatch.delenv("ARVIZ_REQUIRE_ALL_DEPS", raising=False)
-     with pytest.raises(Skipped):
-         importorskip("non-existent-function")
-
-
- def test_importorskip_ci(monkeypatch):
-     """Test ``importorskip`` run on CI machine with non-existent module, which should fail."""
-     monkeypatch.setenv("ARVIZ_REQUIRE_ALL_DEPS", 1)
-     with pytest.raises(ModuleNotFoundError):
-         importorskip("non-existent-function")
arviz/tests/base_tests/test_labels.py
@@ -1,69 +0,0 @@
- """Tests for labeller classes."""
-
- import pytest
-
- from ...labels import (
-     BaseLabeller,
-     DimCoordLabeller,
-     DimIdxLabeller,
-     IdxLabeller,
-     MapLabeller,
-     NoModelLabeller,
-     NoVarLabeller,
- )
-
-
- class Data:
-     def __init__(self):
-         self.sel = {
-             "instrument": "a",
-             "experiment": 3,
-         }
-         self.isel = {
-             "instrument": 0,
-             "experiment": 4,
-         }
-
-
- @pytest.fixture
- def multidim_sels():
-     return Data()
-
-
- class Labellers:
-     def __init__(self):
-         self.labellers = {
-             "BaseLabeller": BaseLabeller(),
-             "DimCoordLabeller": DimCoordLabeller(),
-             "IdxLabeller": IdxLabeller(),
-             "DimIdxLabeller": DimIdxLabeller(),
-             "MapLabeller": MapLabeller(),
-             "NoVarLabeller": NoVarLabeller(),
-             "NoModelLabeller": NoModelLabeller(),
-         }
-
-
- @pytest.fixture
- def labellers():
-     return Labellers()
-
-
- @pytest.mark.parametrize(
-     "args",
-     [
-         ("BaseLabeller", "theta\na, 3"),
-         ("DimCoordLabeller", "theta\ninstrument: a, experiment: 3"),
-         ("IdxLabeller", "theta\n0, 4"),
-         ("DimIdxLabeller", "theta\ninstrument#0, experiment#4"),
-         ("MapLabeller", "theta\na, 3"),
-         ("NoVarLabeller", "a, 3"),
-         ("NoModelLabeller", "theta\na, 3"),
-     ],
- )
- class TestLabellers:
-     # pylint: disable=redefined-outer-name
-     def test_make_label_vert(self, args, multidim_sels, labellers):
-         name, expected_label = args
-         labeller_arg = labellers.labellers[name]
-         label = labeller_arg.make_label_vert("theta", multidim_sels.sel, multidim_sels.isel)
-         assert label == expected_label
arviz/tests/base_tests/test_plot_utils.py
@@ -1,342 +0,0 @@
- # pylint: disable=redefined-outer-name
- import importlib
- import os
-
- import numpy as np
- import pytest
- import xarray as xr
-
- from ...data import from_dict
- from ...plots.backends.matplotlib import dealiase_sel_kwargs, matplotlib_kwarg_dealiaser
- from ...plots.plot_utils import (
-     compute_ranks,
-     filter_plotters_list,
-     format_sig_figs,
-     get_plotting_function,
-     make_2d,
-     set_bokeh_circular_ticks_labels,
-     vectorized_to_hex,
- )
- from ...rcparams import rc_context
- from ...sel_utils import xarray_sel_iter, xarray_to_ndarray
- from ...stats.density_utils import get_bins
- from ...utils import get_coords
-
- # Check if Bokeh is installed
- bokeh_installed = importlib.util.find_spec("bokeh") is not None  # pylint: disable=invalid-name
- skip_tests = (not bokeh_installed) and ("ARVIZ_REQUIRE_ALL_DEPS" not in os.environ)
-
-
- @pytest.mark.parametrize(
-     "value, default, expected",
-     [
-         (123.456, 2, 3),
-         (-123.456, 3, 3),
-         (-123.456, 4, 4),
-         (12.3456, 2, 2),
-         (1.23456, 2, 2),
-         (0.123456, 2, 2),
-     ],
- )
- def test_format_sig_figs(value, default, expected):
-     assert format_sig_figs(value, default=default) == expected
-
-
- @pytest.fixture(scope="function")
- def sample_dataset():
-     mu = np.arange(1, 7).reshape(2, 3)
-     tau = np.arange(7, 13).reshape(2, 3)
-
-     chain = [0, 1]
-     draws = [0, 1, 2]
-
-     data = xr.Dataset(
-         {"mu": (["chain", "draw"], mu), "tau": (["chain", "draw"], tau)},
-         coords={"draw": draws, "chain": chain},
-     )
-
-     return mu, tau, data
-
-
- def test_make_2d():
-     """Touches code that is hard to reach."""
-     assert len(make_2d(np.array([2, 3, 4])).shape) == 2
-
-
- def test_get_bins():
-     """Touches code that is hard to reach."""
-     assert get_bins(np.array([1, 2, 3, 100])) is not None
-
-
- def test_dataset_to_numpy_not_combined(sample_dataset):  # pylint: disable=invalid-name
-     mu, tau, data = sample_dataset
-     var_names, data = xarray_to_ndarray(data, combined=False)
-
-     # 2 vars x 2 chains
-     assert len(var_names) == 4
-     mu_tau = np.concatenate((mu, tau), axis=0)
-     tau_mu = np.concatenate((tau, mu), axis=0)
-     deqmt = data == mu_tau
-     deqtm = data == tau_mu
-     assert deqmt.all() or deqtm.all()
-
-
- def test_dataset_to_numpy_combined(sample_dataset):
-     mu, tau, data = sample_dataset
-     var_names, data = xarray_to_ndarray(data, combined=True)
-
-     assert len(var_names) == 2
-     assert (data[var_names.index("mu")] == mu.reshape(1, 6)).all()
-     assert (data[var_names.index("tau")] == tau.reshape(1, 6)).all()
-
-
- def test_xarray_sel_iter_ordering():
-     """Assert that coordinate names stay the provided order"""
-     coords = list("dcba")
-     data = from_dict(  # pylint: disable=no-member
-         {"x": np.random.randn(1, 100, len(coords))},
-         coords={"in_order": coords},
-         dims={"x": ["in_order"]},
-     ).posterior
-
-     coord_names = [sel["in_order"] for _, sel, _ in xarray_sel_iter(data)]
-     assert coord_names == coords
-
-
- def test_xarray_sel_iter_ordering_combined(sample_dataset):  # pylint: disable=invalid-name
-     """Assert that varname order stays consistent when chains are combined"""
-     _, _, data = sample_dataset
-     var_names = [var for (var, _, _) in xarray_sel_iter(data, var_names=None, combined=True)]
-     assert set(var_names) == {"mu", "tau"}
-
-
- def test_xarray_sel_iter_ordering_uncombined(sample_dataset):  # pylint: disable=invalid-name
-     """Assert that varname order stays consistent when chains are not combined"""
-     _, _, data = sample_dataset
-     var_names = [(var, selection) for (var, selection, _) in xarray_sel_iter(data, var_names=None)]
-
-     assert len(var_names) == 4
-     for var_name in var_names:
-         assert var_name in [
-             ("mu", {"chain": 0}),
-             ("mu", {"chain": 1}),
-             ("tau", {"chain": 0}),
-             ("tau", {"chain": 1}),
-         ]
-
-
- def test_xarray_sel_data_array(sample_dataset):  # pylint: disable=invalid-name
-     """Assert that varname order stays consistent when chains are combined
-
-     Touches code that is hard to reach.
-     """
-     _, _, data = sample_dataset
-     var_names = [var for (var, _, _) in xarray_sel_iter(data.mu, var_names=None, combined=True)]
-     assert set(var_names) == {"mu"}
-
-
- class TestCoordsExceptions:
-     # test coord exceptions on datasets
-     def test_invalid_coord_name(self, sample_dataset):  # pylint: disable=invalid-name
-         """Assert that nicer exception appears when user enters wrong coords name"""
-         _, _, data = sample_dataset
-         coords = {"NOT_A_COORD_NAME": [1]}
-
-         with pytest.raises(
-             (KeyError, ValueError),
-             match=(
-                 r"Coords "
-                 r"({'NOT_A_COORD_NAME'} are invalid coordinate keys"
-                 r"|should follow mapping format {coord_name:\[dim1, dim2\]})"
-             ),
-         ):
-             get_coords(data, coords)
-
-     def test_invalid_coord_value(self, sample_dataset):  # pylint: disable=invalid-name
-         """Assert that nicer exception appears when user enters wrong coords value"""
-         _, _, data = sample_dataset
-         coords = {"draw": [1234567]}
-
-         with pytest.raises(
-             KeyError, match=r"Coords should follow mapping format {coord_name:\[dim1, dim2\]}"
-         ):
-             get_coords(data, coords)
-
-     def test_invalid_coord_structure(self, sample_dataset):  # pylint: disable=invalid-name
-         """Assert that nicer exception appears when user enters wrong coords datatype"""
-         _, _, data = sample_dataset
-         coords = {"draw"}
-
-         with pytest.raises(TypeError):
-             get_coords(data, coords)
-
-     # test coord exceptions on dataset list
-     def test_invalid_coord_name_list(self, sample_dataset):  # pylint: disable=invalid-name
-         """Assert that nicer exception appears when user enters wrong coords name"""
-         _, _, data = sample_dataset
-         coords = {"NOT_A_COORD_NAME": [1]}
-
-         with pytest.raises(
-             (KeyError, ValueError),
-             match=(
-                 r"data\[1\]:.+Coords "
-                 r"({'NOT_A_COORD_NAME'} are invalid coordinate keys"
-                 r"|should follow mapping format {coord_name:\[dim1, dim2\]})"
-             ),
-         ):
-             get_coords((data, data), ({"draw": [0, 1]}, coords))
-
-     def test_invalid_coord_value_list(self, sample_dataset):  # pylint: disable=invalid-name
-         """Assert that nicer exception appears when user enters wrong coords value"""
-         _, _, data = sample_dataset
-         coords = {"draw": [1234567]}
-
-         with pytest.raises(
-             KeyError,
-             match=r"data\[0\]:.+Coords should follow mapping format {coord_name:\[dim1, dim2\]}",
-         ):
-             get_coords((data, data), (coords, {"draw": [0, 1]}))
-
-
- def test_filter_plotter_list():
-     plotters = list(range(7))
-     with rc_context({"plot.max_subplots": 10}):
-         plotters_filtered = filter_plotters_list(plotters, "")
-     assert plotters == plotters_filtered
-
-
- def test_filter_plotter_list_warning():
-     plotters = list(range(7))
-     with rc_context({"plot.max_subplots": 5}):
-         with pytest.warns(UserWarning, match="test warning"):
-             plotters_filtered = filter_plotters_list(plotters, "test warning")
-     assert len(plotters_filtered) == 5
-
-
- @pytest.mark.skipif(skip_tests, reason="test requires bokeh which is not installed")
- def test_bokeh_import():
-     """Tests that correct method is returned on bokeh import"""
-     plot = get_plotting_function("plot_dist", "distplot", "bokeh")
-
-     from ...plots.backends.bokeh.distplot import plot_dist
-
-     assert plot is plot_dist
-
-
- @pytest.mark.parametrize(
-     "params",
-     [
-         {
-             "input": (
-                 {
-                     "dashes": "-",
-                 },
-                 "scatter",
-             ),
-             "output": "linestyle",
-         },
-         {
-             "input": (
-                 {"mfc": "blue", "c": "blue", "line_width": 2},
-                 "plot",
-             ),
-             "output": ("markerfacecolor", "color", "line_width"),
-         },
-         {"input": ({"ec": "blue", "fc": "black"}, "hist"), "output": ("edgecolor", "facecolor")},
-         {
-             "input": ({"edgecolors": "blue", "lw": 3}, "hlines"),
-             "output": ("edgecolor", "linewidth"),
-         },
-     ],
- )
- def test_matplotlib_kwarg_dealiaser(params):
-     dealiased = matplotlib_kwarg_dealiaser(params["input"][0], kind=params["input"][1])
-     for returned in dealiased:
-         assert returned in params["output"]
-
-
- @pytest.mark.parametrize("c_values", ["#0000ff", "blue", [0, 0, 1]])
- def test_vectorized_to_hex_scalar(c_values):
-     output = vectorized_to_hex(c_values)
-     assert output == "#0000ff"
-
-
- @pytest.mark.parametrize(
-     "c_values", [["blue", "blue"], ["blue", "#0000ff"], np.array([[0, 0, 1], [0, 0, 1]])]
- )
- def test_vectorized_to_hex_array(c_values):
-     output = vectorized_to_hex(c_values)
-     assert np.all([item == "#0000ff" for item in output])
-
-
- def test_mpl_dealiase_sel_kwargs():
-     """Check mpl dealiase_sel_kwargs behaviour.
-
-     Makes sure kwargs are overwritten when necessary even with alias involved and that
-     they are not modified when not included in props.
-     """
-     kwargs = {"linewidth": 3, "alpha": 0.4, "line_color": "red"}
-     props = {"lw": [1, 2, 4, 5], "linestyle": ["-", "--", ":"]}
-     res = dealiase_sel_kwargs(kwargs, props, 2)
-     assert "linewidth" in res
-     assert res["linewidth"] == 4
-     assert "linestyle" in res
-     assert res["linestyle"] == ":"
-     assert "alpha" in res
-     assert res["alpha"] == 0.4
-     assert "line_color" in res
-     assert res["line_color"] == "red"
-
-
- @pytest.mark.skipif(skip_tests, reason="test requires bokeh which is not installed")
- def test_bokeh_dealiase_sel_kwargs():
-     """Check bokeh dealiase_sel_kwargs behaviour.
-
-     Makes sure kwargs are overwritten when necessary even with alias involved and that
-     they are not modified when not included in props.
-     """
-     from ...plots.backends.bokeh import dealiase_sel_kwargs
-
-     kwargs = {"line_width": 3, "line_alpha": 0.4, "line_color": "red"}
-     props = {"line_width": [1, 2, 4, 5], "line_dash": ["dashed", "dashed", "dashed"]}
-     res = dealiase_sel_kwargs(kwargs, props, 2)
-     assert "line_width" in res
-     assert res["line_width"] == 4
-     assert "line_dash" in res
-     assert res["line_dash"] == "dashed"
-     assert "line_alpha" in res
-     assert res["line_alpha"] == 0.4
-     assert "line_color" in res
-     assert res["line_color"] == "red"
-
-
- @pytest.mark.skipif(skip_tests, reason="test requires bokeh which is not installed")
- def test_set_bokeh_circular_ticks_labels():
-     """Assert the axes returned after placing ticks and tick labels for circular plots."""
-     import bokeh.plotting as bkp
-
-     ax = bkp.figure(x_axis_type=None, y_axis_type=None)
-     hist = np.linspace(0, 1, 10)
-     labels = ["0°", "45°", "90°", "135°", "180°", "225°", "270°", "315°"]
-     ax = set_bokeh_circular_ticks_labels(ax, hist, labels)
-     renderers = ax.renderers
-     assert len(renderers) == 3
-     assert renderers[2].data_source.data["text"] == labels
-     assert len(renderers[0].data_source.data["start_angle"]) == len(labels)
-
-
- def test_compute_ranks():
-     pois_data = np.array([[5, 4, 1, 4, 0], [2, 8, 2, 1, 1]])
-     expected = np.array([[9.0, 7.0, 3.0, 8.0, 1.0], [5.0, 10.0, 6.0, 2.0, 4.0]])
-     ranks = compute_ranks(pois_data)
-     np.testing.assert_equal(ranks, expected)
-
-     norm_data = np.array(
-         [
-             [0.2644187, -1.3004813, -0.80428456, 1.01319068, 0.62631143],
-             [1.34498018, -0.13428933, -0.69855487, -0.9498981, -0.34074092],
-         ]
-     )
-     expected = np.array([[7.0, 1.0, 3.0, 9.0, 8.0], [10.0, 6.0, 4.0, 2.0, 5.0]])
-     ranks = compute_ranks(norm_data)
-     np.testing.assert_equal(ranks, expected)