arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
|
@@ -1,376 +0,0 @@
|
|
|
1
|
-
"""Tests for arviz.utils."""
|
|
2
|
-
|
|
3
|
-
# pylint: disable=redefined-outer-name, no-member
|
|
4
|
-
from unittest.mock import Mock
|
|
5
|
-
|
|
6
|
-
import numpy as np
|
|
7
|
-
import pytest
|
|
8
|
-
import scipy.stats as st
|
|
9
|
-
|
|
10
|
-
from ...data import dict_to_dataset, from_dict, load_arviz_data
|
|
11
|
-
from ...stats.density_utils import _circular_mean, _normalize_angle, _find_hdi_contours
|
|
12
|
-
from ...utils import (
|
|
13
|
-
_stack,
|
|
14
|
-
_subset_list,
|
|
15
|
-
_var_names,
|
|
16
|
-
expand_dims,
|
|
17
|
-
flatten_inference_data_to_dict,
|
|
18
|
-
one_de,
|
|
19
|
-
two_de,
|
|
20
|
-
)
|
|
21
|
-
from ..helpers import RandomVariableTestClass
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
@pytest.fixture(scope="session")
def inference_data():
    """Provide the full ``centered_eight`` example InferenceData, loaded once per session."""
    return load_arviz_data("centered_eight")
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
@pytest.fixture(scope="session")
def data():
    """Provide only the posterior group of the ``centered_eight`` example dataset."""
    idata = load_arviz_data("centered_eight")
    posterior = idata.posterior
    return posterior
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
@pytest.mark.parametrize(
    "var_names_expected",
    [
        ("mu", ["mu"]),
        (None, None),
        (["mu", "tau"], ["mu", "tau"]),
        ("~mu", ["theta", "tau"]),
        (["~mu"], ["theta", "tau"]),
    ],
)
def test_var_names(var_names_expected, data):
    """Check ``_var_names`` for a single name, a list, ``None`` and ``~`` negation."""
    requested, wanted = var_names_expected
    resolved = _var_names(requested, data)
    assert resolved == wanted
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
def test_var_names_warning():
    """A variable literally named ``~mu`` is ambiguous with negation of ``mu`` and must warn."""
    posterior_draws = {
        "~mu": np.random.randn(2, 10),
        "mu": -np.random.randn(2, 10),  # pylint: disable=invalid-unary-operand-type
        "theta": np.random.randn(2, 10, 8),
    }
    dataset = from_dict(posterior=posterior_draws).posterior
    # The same list object is both the request and the expected answer.
    selection = ["~mu"]
    with pytest.warns(UserWarning):
        assert _var_names(selection, dataset) == selection
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
def test_var_names_key_error(data):
    """Requesting an unknown variable raises a KeyError that names the offender."""
    bad_selection = ("theta", "tau", "bad_var_name")
    with pytest.raises(KeyError, match="bad_var_name"):
        _var_names(bad_selection, data)
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
@pytest.mark.parametrize(
    "var_args",
    [
        (["ta"], ["beta1", "beta2", "theta"], "like"),
        (["~beta"], ["phi", "theta"], "like"),
        (["beta[0-9]+"], ["beta1", "beta2"], "regex"),
        (["^p"], ["phi"], "regex"),
        (["~^t"], ["beta1", "beta2", "phi"], "regex"),
    ],
)
def test_var_names_filter_multiple_input(var_args):
    """``like``/``regex`` filtering resolves names across a list of datasets."""
    draws = np.random.randn(10)
    first = dict_to_dataset({"beta1": draws, "beta2": draws, "phi": draws})
    second = dict_to_dataset({"beta1": draws, "beta2": draws, "theta": draws})
    selection, wanted, mode = var_args
    assert _var_names(selection, [first, second], mode) == wanted
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
@pytest.mark.parametrize(
    "var_args",
    [
        (["alpha", "beta"], ["alpha", "beta1", "beta2"], "like"),
        (["~beta"], ["alpha", "p1", "p2", "phi", "theta", "theta_t"], "like"),
        (["theta"], ["theta", "theta_t"], "like"),
        (["~theta"], ["alpha", "beta1", "beta2", "p1", "p2", "phi"], "like"),
        (["p"], ["alpha", "p1", "p2", "phi"], "like"),
        (["~p"], ["beta1", "beta2", "theta", "theta_t"], "like"),
        (["^bet"], ["beta1", "beta2"], "regex"),
        (["^p"], ["p1", "p2", "phi"], "regex"),
        (["~^p"], ["alpha", "beta1", "beta2", "theta", "theta_t"], "regex"),
        (["p[0-9]+"], ["p1", "p2"], "regex"),
        (["~p[0-9]+"], ["alpha", "beta1", "beta2", "phi", "theta", "theta_t"], "regex"),
    ],
)
def test_var_names_filter(var_args):
    """Test var_names filtering by partial name (``like``) or regular expression (``regex``)."""
    draws = np.random.randn(10)
    names = ("alpha", "beta1", "beta2", "p1", "p2", "phi", "theta", "theta_t")
    # Every variable shares the same draws; only the names matter for filtering.
    dataset = dict_to_dataset(dict.fromkeys(names, draws))
    selection, wanted, mode = var_args
    assert _var_names(selection, dataset, mode) == wanted
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
def test_nonstring_var_names():
    """Non-string variable keys pass through ``_var_names`` unchanged."""
    mu_var = RandomVariableTestClass("mu")
    dataset = dict_to_dataset({mu_var: np.random.randn(10)})
    assert _var_names([mu_var], dataset) == [mu_var]
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
def test_var_names_filter_invalid_argument():
    """An unsupported ``filter_vars`` mode raises ValueError with a descriptive message."""
    samples = np.random.randn(10)
    data = dict_to_dataset({"alpha": samples})
    msg = r"^\'filter_vars\' can only be None, \'like\', or \'regex\', got: 'foo'$"
    with pytest.raises(ValueError, match=msg):
        # No ``assert``: the call is expected to raise, so its return value is never checked.
        _var_names(["alpha"], data, filter_vars="foo")
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
def test_subset_list_negation_not_found():
    """Negating a name that is absent emits a warning and leaves the names untouched."""
    available = ["mu", "theta"]
    with pytest.warns(UserWarning, match=".+not.+found.+"):
        subset = _subset_list("~tau", available)
        assert subset == available
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
@pytest.fixture(scope="function")
def utils_with_numba_import_fail(monkeypatch):
    """Return ``arviz.utils`` with ``importlib.import_module`` patched to raise ImportError.

    This simulates an environment where numba cannot be imported. ``monkeypatch``
    undoes the patch when the requesting test finishes.
    """
    from ... import utils

    monkeypatch.setattr(utils.importlib, "import_module", Mock(side_effect=ImportError))
    return utils
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
def test_conditional_jit_decorator_no_numba(utils_with_numba_import_fail):
    """``conditional_jit`` leaves the function callable when numba cannot be imported.

    The numba-available counterpart is ``test_conditional_jit_numba_decorator``;
    a debugger or coverage tool shows that the two take different code paths.
    """
    patched_utils = utils_with_numba_import_fail

    @patched_utils.conditional_jit
    def func():
        return "Numba not used"

    assert func() == "Numba not used"
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
def test_conditional_vect_decorator_no_numba(utils_with_numba_import_fail):
    """``conditional_vect`` leaves the function callable when numba cannot be imported.

    The numba-available counterpart is ``test_conditional_vect_numba_decorator``;
    a debugger or coverage tool shows that the two take different code paths.
    """
    patched_utils = utils_with_numba_import_fail

    @patched_utils.conditional_vect
    def func():
        return "Numba not used"

    assert func() == "Numba not used"
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
def test_conditional_jit_numba_decorator():
    """``conditional_jit`` produces a working callable when numba is importable.

    Counterpart of ``test_conditional_jit_decorator_no_numba``; a debugger or
    coverage tool shows which code path was taken.
    """
    from ... import utils

    @utils.conditional_jit
    def always_true():
        return True

    assert always_true()
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
def test_conditional_vect_numba_decorator():
    """A ``conditional_vect``-decorated elementwise add matches NumPy addition.

    Counterpart of ``test_conditional_vect_decorator_no_numba``; a debugger or
    coverage tool shows which code path was taken.
    """
    from ... import utils

    @utils.conditional_vect
    def add(a_a, b_b):
        return a_a + b_b

    lhs = np.random.randn(10)
    rhs = np.random.randn(10)
    assert np.allclose(add(lhs, rhs), lhs + rhs)
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
def test_conditional_vect_numba_decorator_keyword(monkeypatch):
    """Keyword arguments given to ``conditional_vect`` are forwarded to ``numba.vectorize``."""
    from ... import utils

    # Make importlib return a fake numba whose ``vectorize`` records its kwargs and
    # immediately applies the decorated function, so the "decorated" object is a tuple.
    fake_numba = Mock()
    monkeypatch.setattr(utils.importlib, "import_module", lambda x: fake_numba)

    def fake_vectorize(**kwargs):
        """Stand-in for ``numba.vectorize``: returns ``(fn(), kwargs)`` when applied."""
        return lambda fn: (fn(), kwargs)

    fake_numba.vectorize = fake_vectorize

    @utils.conditional_vect(keyword_argument="A keyword argument")
    def placeholder_func():
        """This function does nothing"""
        return "output"

    # pylint: disable=unpacking-non-sequence
    function_results, wrapper_result = placeholder_func
    assert wrapper_result == {"keyword_argument": "A keyword argument"}
    assert function_results == "output"
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
def test_stack():
    """``_stack`` concatenates arrays along the first axis exactly like ``np.vstack``."""
    x = np.random.randn(10, 4, 6)
    y = np.random.randn(100, 4, 6)
    # Stacking along axis 0 requires matching trailing dimensions.
    assert x.shape[1:] == y.shape[1:]
    stacked = _stack(x, y)
    # Check the shape explicitly: np.allclose broadcasts and could hide a shape mismatch.
    assert np.shape(stacked) == (x.shape[0] + y.shape[0],) + x.shape[1:]
    assert np.allclose(np.vstack((x, y)), stacked)
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
@pytest.mark.parametrize("data", [np.random.randn(1000), np.random.randn(1000).tolist()])
def test_two_de(data):
    """``two_de`` matches ``np.atleast_2d`` for both ndarray and plain-list input."""
    reference = np.atleast_2d(data)
    assert np.allclose(two_de(data), reference)
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
@pytest.mark.parametrize("data", [np.random.randn(100), np.random.randn(100).tolist()])
def test_one_de(data):
    """``one_de`` matches ``np.atleast_1d`` for both ndarray and plain-list input."""
    reference = np.atleast_1d(data)
    assert np.allclose(one_de(data), reference)
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
@pytest.mark.parametrize("data", [np.random.randn(100), np.random.randn(100).tolist()])
def test_expand_dims(data):
    """``expand_dims`` matches ``np.expand_dims(data, 0)`` for ndarray and list input."""
    reference = np.expand_dims(data, 0)
    assert np.allclose(expand_dims(data), reference)
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
@pytest.mark.parametrize("var_names", [None, "mu", ["mu", "tau"]])
@pytest.mark.parametrize(
    "groups", [None, "posterior_groups", "prior_groups", ["posterior", "sample_stats"]]
)
@pytest.mark.parametrize("dimensions", [None, "draw", ["chain", "draw"]])
@pytest.mark.parametrize("group_info", [True, False])
@pytest.mark.parametrize(
    "var_name_format", [None, "brackets", "underscore", "cds", ((",", "[", "]"), ("_", ""))]
)
@pytest.mark.parametrize("index_origin", [None, 0, 1])
def test_flatten_inference_data_to_dict(
    inference_data, var_names, groups, dimensions, group_info, var_name_format, index_origin
):
    """Flattening InferenceData (sub)groups yields keys that carry group info only on request."""
    res_dict = flatten_inference_data_to_dict(
        data=inference_data,
        var_names=var_names,
        groups=groups,
        dimensions=dimensions,
        group_info=group_info,
        var_name_format=var_name_format,
        index_origin=index_origin,
    )
    keys = list(res_dict)

    def mentions(token):
        """True when any flattened key contains ``token`` as a substring."""
        return any(token in key for key in keys)

    assert res_dict
    assert "draw" in res_dict
    assert mentions("mu")

    prior_only = groups == "prior_groups"
    if group_info and prior_only:
        assert mentions("prior")
    elif group_info:
        assert mentions("posterior")
        if var_names is None:
            assert mentions("sample_stats")
    elif prior_only:
        assert not mentions("prior")
    else:
        assert not mentions("posterior")
        if var_names is None:
            assert not mentions("sample_stats")
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
@pytest.mark.parametrize("mean", [0, np.pi, 4 * np.pi, -2 * np.pi, -10 * np.pi])
def test_circular_mean_scipy(mean):
    """``_circular_mean`` agrees with ``scipy.stats.circmean`` on the interval [-pi, pi]."""
    samples = st.vonmises.rvs(loc=mean, kappa=1, size=1000)
    reference = st.circmean(samples, low=-np.pi, high=np.pi)
    np.testing.assert_almost_equal(_circular_mean(samples), reference)
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
@pytest.mark.parametrize("mean", [0, np.pi, 4 * np.pi, -2 * np.pi, -10 * np.pi])
def test_normalize_angle(mean):
    """``_normalize_angle`` maps angles into [-pi, pi] (zero-centered) or [0, 2*pi]."""
    samples = st.vonmises.rvs(loc=mean, kappa=1, size=1000)

    centered = _normalize_angle(samples, zero_centered=True)
    assert np.all((centered >= -np.pi) & (centered <= np.pi))

    shifted = _normalize_angle(samples, zero_centered=False)
    assert np.all((shifted >= 0) & (shifted <= 2 * np.pi))
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
@pytest.mark.parametrize("mean", [[0, 0], [1, 1]])
@pytest.mark.parametrize(
    "cov",
    [
        np.diag([1, 1]),
        np.diag([0.5, 0.5]),
        np.diag([0.25, 1]),
        np.array([[0.4, 0.2], [0.2, 0.8]]),
    ],
)
@pytest.mark.parametrize("contour_sigma", [np.array([1, 2, 3])])
def test_find_hdi_contours(mean, cov, contour_sigma):
    """Test `_find_hdi_contours()` against SciPy's multivariate normal distribution.

    For a bivariate normal, the density level enclosing probability
    ``p = 1 - exp(-sigma**2 / 2)`` equals the pdf at Mahalanobis distance ``sigma``.
    The reference level is the pdf evaluated ``sigma`` standard deviations from the
    mean along a principal axis. Every principal axis gives the same value.
    """
    prob_dist = st.multivariate_normal(mean, cov)

    # ``cov`` is symmetric, so ``eigh`` is the right decomposition: it guarantees real
    # eigenvalues and orthonormal eigenvectors, whereas the general ``eig`` does not.
    eigenvals, eigenvecs = np.linalg.eigh(cov)
    eigenvecs = eigenvecs.T  # rows are now the principal axes
    stdevs = np.sqrt(eigenvals)

    # Endpoints of the 7-sigma ellipse along both principal axes, in both directions,
    # used to bound the evaluation grid.
    extremes = np.empty((4, 2))
    for i in range(4):
        extremes[i] = mean + (-1) ** i * 7 * stdevs[i // 2] * eigenvecs[i // 2]
    x_min, y_min = np.amin(extremes, axis=0)
    x_max, y_max = np.amax(extremes, axis=0)

    # Evaluate the density on a 256x256 grid covering that region.
    x = np.linspace(x_min, x_max, 256)
    y = np.linspace(y_min, y_max, 256)
    grid = np.dstack(np.meshgrid(x, y))
    density = prob_dist.pdf(grid)

    contour_sp = np.empty(contour_sigma.shape)
    for idx, sigma in enumerate(contour_sigma):
        contour_sp[idx] = prob_dist.pdf(mean + sigma * stdevs[0] * eigenvecs[0])

    hdi_probs = 1 - np.exp(-0.5 * contour_sigma**2)
    contour_az = _find_hdi_contours(density, hdi_probs)

    np.testing.assert_allclose(contour_sp, contour_az, rtol=1e-2, atol=1e-4)
|
|
@@ -1,87 +0,0 @@
|
|
|
1
|
-
# pylint: disable=redefined-outer-name, no-member
|
|
2
|
-
"""Tests for arviz.utils."""
|
|
3
|
-
import importlib
|
|
4
|
-
from unittest.mock import Mock
|
|
5
|
-
|
|
6
|
-
import numpy as np
|
|
7
|
-
import pytest
|
|
8
|
-
|
|
9
|
-
from ...stats.stats_utils import stats_variance_2d as svar
|
|
10
|
-
from ...utils import Numba, _numba_var, numba_check
|
|
11
|
-
from ..helpers import importorskip
|
|
12
|
-
from .test_utils import utils_with_numba_import_fail # pylint: disable=unused-import
|
|
13
|
-
|
|
14
|
-
# Presumably skips this whole module when numba cannot be imported (``importorskip`` comes
# from the test helpers; its definition is not shown here — confirm).
importorskip("numba")
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
def test_utils_fixture(utils_with_numba_import_fail):
    """The fixture's patched ``importlib.import_module`` raises ImportError for numba."""
    # A plain ``import`` statement does not go through ``importlib.import_module``, so
    # this still succeeds; it fails only if numba is genuinely missing.
    import numba  # pylint: disable=unused-import,W0612

    patched_importlib = utils_with_numba_import_fail.importlib
    with pytest.raises(ImportError):
        patched_importlib.import_module("numba")
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
def test_conditional_jit_numba_decorator_keyword(monkeypatch):
    """Keyword arguments to ``conditional_jit`` reach ``numba.jit`` together with ``nopython=True``."""
    from ... import utils

    # Make importlib return a fake numba whose ``jit`` records the kwargs it receives.
    fake_numba = Mock()
    monkeypatch.setattr(utils.importlib, "import_module", lambda x: fake_numba)

    def fake_jit(**kwargs):
        """Stand-in for ``numba.jit``: the jitted callable returns ``(fn(), kwargs)``."""

        def decorate(fn):
            return lambda: (fn(), kwargs)

        return decorate

    fake_numba.jit = fake_jit

    @utils.conditional_jit(keyword_argument="A keyword argument")
    def placeholder_func():
        """This function does nothing"""
        return "output"

    # pylint: disable=unpacking-non-sequence
    function_results, wrapper_result = placeholder_func()
    assert wrapper_result == {"keyword_argument": "A keyword argument", "nopython": True}
    assert function_results == "output"
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
def test_numba_check():
    """``numba_check`` reports whether a numba module spec can be found."""
    has_numba = importlib.util.find_spec("numba") is not None
    assert has_numba == numba_check()
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
def test_numba_utils():
    """Toggle ``Numba.numba_flag`` off and back on, then check it is restored.

    The flag is read and written on the ``Numba`` class, so it is shared by every
    test. It is restored in ``finally`` so that a failing assertion cannot leave
    numba disabled for the rest of the session.
    """
    flag = Numba.numba_flag
    assert flag == numba_check()
    try:
        Numba.disable_numba()
        assert not Numba.numba_flag
        Numba.enable_numba()
        assert Numba.numba_flag
    finally:
        # Put the shared flag back to whatever it was before this test.
        if flag:
            Numba.enable_numba()
        else:
            Numba.disable_numba()
    assert flag == Numba.numba_flag
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
@pytest.mark.parametrize("axis", (0, 1))
@pytest.mark.parametrize("ddof", (0, 1))
def test_numba_var(axis, ddof):
    """``_numba_var`` gives the same variance with and without numba enabled.

    Numba is re-enabled in ``finally`` so that an exception raised while it is
    disabled cannot leak the disabled state into later tests.
    """
    flag = Numba.numba_flag
    data_1 = np.random.randn(100, 100)
    data_2 = np.random.rand(100)
    with_numba_1 = _numba_var(svar, np.var, data_1, axis=axis, ddof=ddof)
    with_numba_2 = _numba_var(svar, np.var, data_2, ddof=ddof)
    Numba.disable_numba()
    try:
        non_numba_1 = _numba_var(svar, np.var, data_1, axis=axis, ddof=ddof)
        non_numba_2 = _numba_var(svar, np.var, data_2, ddof=ddof)
    finally:
        Numba.enable_numba()
    assert flag == Numba.numba_flag
    assert np.allclose(with_numba_1, non_numba_1)
    assert np.allclose(with_numba_2, non_numba_2)
|
arviz/tests/conftest.py
DELETED
|
@@ -1,46 +0,0 @@
|
|
|
1
|
-
# pylint: disable=redefined-outer-name
|
|
2
|
-
"""Configuration for test suite."""
|
|
3
|
-
import logging
|
|
4
|
-
import os
|
|
5
|
-
|
|
6
|
-
import numpy as np
|
|
7
|
-
import pytest
|
|
8
|
-
|
|
9
|
-
# Module-level logger; the ``save_figs`` fixture uses it to report figure-directory handling.
_log = logging.getLogger(__name__)
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
@pytest.fixture(autouse=True)
def random_seed():
    """Seed NumPy's global random generator with 0 before every test for reproducibility."""
    seed = 0
    np.random.seed(seed)
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
def pytest_addoption(parser):
    """Register ``--save [DIR]``; a bare ``--save`` stores figures in ``test_images``."""
    option_kwargs = {
        "nargs": "?",
        "const": "test_images",
        "help": "Save images rendered by plot",
    }
    parser.addoption("--save", **option_kwargs)
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
@pytest.fixture(scope="session")
def save_figs(request):
    """Return the figure directory given via ``--save``, or ``None`` when not requested.

    When ``--save`` is given, the directory is created if it does not exist and every
    non-directory entry directly inside it is removed, so each run starts clean.
    Subdirectories are left untouched. A file that cannot be removed is logged and
    skipped.
    """
    fig_dir = request.config.getoption("--save")
    if fig_dir is None:
        return None

    _log.info("Saving generated images in %s", fig_dir)
    os.makedirs(fig_dir, exist_ok=True)
    # The directory may already have existed, so do not claim it was "created".
    _log.info("Directory %s is ready", fig_dir)

    for name in os.listdir(fig_dir):
        full_path = os.path.join(fig_dir, name)
        # Leave directories alone rather than attempting os.remove and logging a failure.
        if os.path.isdir(full_path):
            continue
        try:
            os.remove(full_path)
        except OSError:
            _log.info("Failed to remove %s", full_path)

    return fig_dir
|
|
@@ -1 +0,0 @@
|
|
|
1
|
-
"""Backend test suite."""
|
|
@@ -1,78 +0,0 @@
|
|
|
1
|
-
# pylint: disable=no-member, invalid-name, redefined-outer-name
|
|
2
|
-
import numpy as np
|
|
3
|
-
import pytest
|
|
4
|
-
|
|
5
|
-
from ...data.io_beanmachine import from_beanmachine # pylint: disable=wrong-import-position
|
|
6
|
-
from ..helpers import ( # pylint: disable=unused-import, wrong-import-position
|
|
7
|
-
chains,
|
|
8
|
-
draws,
|
|
9
|
-
eight_schools_params,
|
|
10
|
-
importorskip,
|
|
11
|
-
load_cached_models,
|
|
12
|
-
)
|
|
13
|
-
|
|
14
|
-
# NOTE: this unconditionally skips the whole module at import time, so the
# importorskip calls and every test below currently never run.
pytest.skip("Ignore beanmachine tests until it supports pytorch 2", allow_module_level=True)

# Skip all tests if beanmachine or pytorch not installed
torch = importorskip("torch")
bm = importorskip("beanmachine.ppl")
# NOTE(review): ``dist`` is not referenced by the tests in this module.
dist = torch.distributions
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
class TestDataBeanMachine:
    """Conversion tests for ``from_beanmachine`` on the cached eight-schools model."""

    @pytest.fixture(scope="class")
    def data(self, eight_schools_params, draws, chains):
        """Return a namespace class exposing the cached ``model``, ``prior`` and ``obj``."""

        # The class body reads the enclosing fixture arguments directly and unpacks the
        # cached "beanmachine" entry into class attributes. ``obj`` is later used as the
        # posterior samples; ``prior`` as prior-only samples.
        class Data:
            model, prior, obj = load_cached_models(
                eight_schools_params,
                draws,
                chains,
                "beanmachine",
            )["beanmachine"]

        return Data

    @pytest.fixture(scope="class")
    def predictions_data(self, data):
        """Simulate ``model.obs()`` from the posterior samples held in ``data.obj``."""
        posterior_samples = data.obj
        model = data.model
        predictions = bm.inference.predictive.simulate([model.obs()], posterior_samples)
        return predictions

    def get_inference_data(self, eight_schools_params, predictions_data):
        """Convert the simulated predictions with ``from_beanmachine``, adding school coords.

        Helper only (no ``test_`` prefix); used by ``test_inference_data``.
        """
        predictions = predictions_data
        return from_beanmachine(
            sampler=predictions,
            coords={
                "school": np.arange(eight_schools_params["J"]),
                "school_pred": np.arange(eight_schools_params["J"]),
            },
        )

    def test_inference_data(self, data, eight_schools_params, predictions_data):
        """mu/tau/eta land in ``posterior`` and obs lands in ``posterior_predictive``."""
        inference_data = self.get_inference_data(eight_schools_params, predictions_data)
        model = data.model
        mu = model.mu()
        tau = model.tau()
        eta = model.eta()
        obs = model.obs()

        assert mu in inference_data.posterior
        assert tau in inference_data.posterior
        assert eta in inference_data.posterior
        assert obs in inference_data.posterior_predictive

    def test_inference_data_has_log_likelihood_and_observed_data(self, data):
        """Converting posterior samples also yields ``log_likelihood`` and ``observed_data``."""
        idata = from_beanmachine(data.obj)
        obs = data.model.obs()

        assert obs in idata.log_likelihood
        assert obs in idata.observed_data

    def test_inference_data_no_posterior(self, data):
        """With prior samples only, obs is absent from ``posterior`` and no ``observed_data`` exists."""
        model = data.model
        # only prior
        inference_data = from_beanmachine(data.prior)
        assert not model.obs() in inference_data.posterior
        assert "observed_data" not in inference_data
|