arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
|
@@ -1,143 +0,0 @@
|
|
|
1
|
-
# pylint: disable=redefined-outer-name
|
|
2
|
-
import os
|
|
3
|
-
from collections.abc import MutableMapping
|
|
4
|
-
from tempfile import TemporaryDirectory
|
|
5
|
-
from typing import Mapping
|
|
6
|
-
|
|
7
|
-
import numpy as np
|
|
8
|
-
import pytest
|
|
9
|
-
|
|
10
|
-
from ... import InferenceData, from_dict
|
|
11
|
-
from ... import to_zarr, from_zarr
|
|
12
|
-
|
|
13
|
-
from ..helpers import ( # pylint: disable=unused-import
|
|
14
|
-
chains,
|
|
15
|
-
check_multiple_attrs,
|
|
16
|
-
draws,
|
|
17
|
-
eight_schools_params,
|
|
18
|
-
importorskip,
|
|
19
|
-
)
|
|
20
|
-
|
|
21
|
-
zarr = importorskip("zarr") # pylint: disable=invalid-name
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
class TestDataZarr:
    """Round-trip tests for saving InferenceData to zarr and loading it back."""

    @pytest.fixture(scope="class")
    def data(self, draws, chains):
        """Provide fake eight-schools draws shared by every test in this class."""

        class Data:
            # fake 8-school output
            shapes: Mapping[str, list] = {"mu": [], "tau": [], "eta": [8], "theta": [8]}
            obj = {key: np.random.randn(chains, draws, *shape) for key, shape in shapes.items()}

        return Data

    def get_inference_data(self, data, eight_schools_params, fill_attrs):
        """Build an InferenceData whose sampled groups all reuse ``data.obj``.

        ``attrs={"test": 1}`` is attached only when ``fill_attrs`` is true.
        """
        return from_dict(
            posterior=data.obj,
            posterior_predictive=data.obj,
            sample_stats=data.obj,
            prior=data.obj,
            prior_predictive=data.obj,
            sample_stats_prior=data.obj,
            observed_data=eight_schools_params,
            coords={"school": np.arange(8)},
            dims={"theta": ["school"], "eta": ["school"]},
            attrs={"test": 1} if fill_attrs else None,
        )

    @staticmethod
    def _expected_groups():
        """Return a fresh mapping of group name -> variables that group must contain."""
        groups = {
            name: ["eta", "theta", "mu", "tau"]
            for name in (
                "posterior",
                "posterior_predictive",
                "sample_stats",
                "prior",
                "prior_predictive",
                "sample_stats_prior",
            )
        }
        groups["observed_data"] = ["J", "y", "sigma"]
        return groups

    @staticmethod
    def _assert_zarr_dir_written(path):
        """Assert that ``path`` is a directory holding at least one entry."""
        # os.path.getsize() on a directory reports the size of the directory entry
        # itself (typically 0 on Windows), not what was written into it, so inspect
        # the directory contents instead.
        assert os.path.isdir(path)
        assert os.listdir(path)

    @pytest.mark.parametrize("store", [0, 1, 2])
    @pytest.mark.parametrize("fill_attrs", [True, False])
    def test_io_method(self, data, eight_schools_params, store, fill_attrs):
        """Round-trip through ``InferenceData.to_zarr`` / ``InferenceData.from_zarr``.

        ``store`` selects the target:
        0 passes ``store=None`` and expects ``to_zarr`` to return a MutableMapping store,
        1 writes to a directory path,
        2 writes to an explicit ``zarr.storage.DirectoryStore``.
        """
        # create InferenceData and check it has been properly created
        inference_data = self.get_inference_data(data, eight_schools_params, fill_attrs)
        test_dict = self._expected_groups()
        fails = check_multiple_attrs(test_dict, inference_data)
        assert not fails

        if fill_attrs:
            assert inference_data.attrs["test"] == 1
        else:
            assert "test" not in inference_data.attrs

        with TemporaryDirectory(prefix="arviz_tests_") as tmp_dir:
            filepath = os.path.join(tmp_dir, "zarr")

            # Keep the parametrized selector (``store``) separate from the store object.
            zarr_store = None
            if store == 0:
                zarr_store = inference_data.to_zarr(store=None)
                assert isinstance(zarr_store, MutableMapping)
            elif store == 1:
                inference_data.to_zarr(store=filepath)
                self._assert_zarr_dir_written(filepath)
            elif store == 2:
                zarr_store = zarr.storage.DirectoryStore(filepath)
                inference_data.to_zarr(store=zarr_store)
                self._assert_zarr_dir_written(filepath)

            if isinstance(zarr_store, MutableMapping):
                inference_data2 = InferenceData.from_zarr(zarr_store)
            else:
                inference_data2 = InferenceData.from_zarr(filepath)

            # Everything in dict still available in inference_data2 ?
            fails = check_multiple_attrs(test_dict, inference_data2)
            assert not fails

            if fill_attrs:
                assert inference_data2.attrs["test"] == 1
            else:
                assert "test" not in inference_data2.attrs

    def test_io_function(self, data, eight_schools_params):
        """Round-trip through the module-level ``to_zarr`` / ``from_zarr`` functions."""
        # create InferenceData and check it has been properly created
        inference_data = self.get_inference_data(data, eight_schools_params, fill_attrs=True)
        test_dict = self._expected_groups()
        fails = check_multiple_attrs(test_dict, inference_data)
        assert not fails

        assert inference_data.attrs["test"] == 1

        with TemporaryDirectory(prefix="arviz_tests_") as tmp_dir:
            filepath = os.path.join(tmp_dir, "zarr")

            to_zarr(inference_data, store=filepath)
            self._assert_zarr_dir_written(filepath)

            inference_data2 = from_zarr(filepath)

            # Everything in dict still available in inference_data2 ?
            fails = check_multiple_attrs(test_dict, inference_data2)
            assert not fails

            assert inference_data2.attrs["test"] == 1
|
|
@@ -1,511 +0,0 @@
|
|
|
1
|
-
"""Test Diagnostic methods"""
|
|
2
|
-
|
|
3
|
-
# pylint: disable=redefined-outer-name, no-member, too-many-public-methods
|
|
4
|
-
import os
|
|
5
|
-
|
|
6
|
-
import numpy as np
|
|
7
|
-
import packaging
|
|
8
|
-
import pandas as pd
|
|
9
|
-
import pytest
|
|
10
|
-
import scipy
|
|
11
|
-
from numpy.testing import assert_almost_equal
|
|
12
|
-
|
|
13
|
-
from ...data import from_cmdstan, load_arviz_data
|
|
14
|
-
from ...rcparams import rcParams
|
|
15
|
-
from ...sel_utils import xarray_var_iter
|
|
16
|
-
from ...stats import bfmi, ess, mcse, rhat
|
|
17
|
-
from ...stats.diagnostics import (
|
|
18
|
-
_ess,
|
|
19
|
-
_ess_quantile,
|
|
20
|
-
_mc_error,
|
|
21
|
-
_mcse_quantile,
|
|
22
|
-
_multichain_statistics,
|
|
23
|
-
_rhat,
|
|
24
|
-
_rhat_rank,
|
|
25
|
-
_split_chains,
|
|
26
|
-
_z_scale,
|
|
27
|
-
ks_summary,
|
|
28
|
-
)
|
|
29
|
-
|
|
30
|
-
# R-hat acceptance threshold used only by these tests. For real analyses a
# tighter bound of roughly 1.01-1.05 is advisable; see the discussion at
# https://github.com/stan-dev/rstan/pull/618
GOOD_RHAT = 1.1

# Switch data loading to "eager" when this module is imported. This mutates
# the shared rcParams object, so the setting is not confined to this module.
rcParams["data.load"] = "eager"
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
@pytest.fixture(scope="session")
def data():
    """Return the posterior group of the ``centered_eight`` example (built once per session)."""
    return load_arviz_data("centered_eight").posterior
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
class TestDiagnostics:
|
|
44
|
-
def test_bfmi(self):
    """BFMI of a short, evenly spaced energy trace equals a hand-checked value."""
    energy_trace = np.array([1, 2, 3, 4])
    expected = 0.6
    assert_almost_equal(bfmi(energy_trace), expected)
|
|
47
|
-
|
|
48
|
-
def test_bfmi_dataset(self):
    """BFMI computed from a full InferenceData yields only nonzero values."""
    centered = load_arviz_data("centered_eight")
    values = bfmi(centered)
    assert values.all()
|
|
51
|
-
|
|
52
|
-
def test_bfmi_dataset_bad(self):
    """With ``energy`` removed from sample_stats, bfmi raises TypeError."""
    centered = load_arviz_data("centered_eight")
    del centered.sample_stats["energy"]
    with pytest.raises(TypeError):
        bfmi(centered)
|
|
57
|
-
|
|
58
|
-
def test_bfmi_correctly_transposed(self):
    """Transposing the ``energy`` variable leaves the BFMI result unchanged."""
    centered = load_arviz_data("centered_eight")
    before = bfmi(centered)
    energy = centered.sample_stats["energy"]
    centered.sample_stats["energy"] = energy.T
    after = bfmi(centered)
    assert_almost_equal(before, after)
|
|
64
|
-
|
|
65
|
-
def test_deterministic(self):
    """
    Test algorithm against posterior (R) convergence functions.

    posterior: https://github.com/stan-dev/posterior
    R code:
    ```
    library("posterior")
    data2 <- read.csv("blocker.2.csv", comment.char = "#")
    data1 <- read.csv("blocker.1.csv", comment.char = "#")
    output <- matrix(ncol=17, nrow=length(names(data1))-4)
    j = 0
    for (i in 1:length(names(data1))) {
        name = names(data1)[i]
        ary = matrix(c(data1[,name], data2[,name]), 1000, 2)
        if (!endsWith(name, "__"))
            j <- j + 1
            output[j,] <- c(
                posterior::rhat(ary),
                posterior::rhat_basic(ary, FALSE),
                posterior::ess_bulk(ary),
                posterior::ess_tail(ary),
                posterior::ess_mean(ary),
                posterior::ess_sd(ary),
                posterior::ess_median(ary),
                posterior::ess_basic(ary, FALSE),
                posterior::ess_quantile(ary, 0.01),
                posterior::ess_quantile(ary, 0.1),
                posterior::ess_quantile(ary, 0.3),
                posterior::mcse_mean(ary),
                posterior::mcse_sd(ary),
                posterior::mcse_median(ary),
                posterior::mcse_quantile(ary, prob=0.01),
                posterior::mcse_quantile(ary, prob=0.1),
                posterior::mcse_quantile(ary, prob=0.3))
    }
    df = data.frame(output, row.names = names(data1)[5:ncol(data1)])
    colnames(df) <- c("rhat_rank",
                      "rhat_raw",
                      "ess_bulk",
                      "ess_tail",
                      "ess_mean",
                      "ess_sd",
                      "ess_median",
                      "ess_raw",
                      "ess_quantile01",
                      "ess_quantile10",
                      "ess_quantile30",
                      "mcse_mean",
                      "mcse_sd",
                      "mcse_median",
                      "mcse_quantile01",
                      "mcse_quantile10",
                      "mcse_quantile30")
    write.csv(df, "reference_posterior.csv")
    ```
    Reference file:

    Created: 2024-12-20
    System: Ubuntu 24.04.1 LTS
    R version 4.4.2 (2024-10-31)
    posterior version from https://github.com/stan-dev/posterior/pull/388
    (after release 1.6.0 but before the fixes in the PR were released).
    """
    # Locate the Stan CSV inputs and the R reference table stored next to the tests.
    here = os.path.dirname(os.path.abspath(__file__))
    data_directory = os.path.join(here, "..", "saved_models")
    path = os.path.join(data_directory, "stan_diagnostics", "blocker.[0-9].csv")
    posterior = from_cmdstan(path)
    reference_path = os.path.join(data_directory, "stan_diagnostics", "reference_posterior.csv")
    reference = (
        pd.read_csv(reference_path, index_col=0, float_precision="high")
        .sort_index(axis=1)
        .sort_index(axis=0)
    )
    # test arviz functions
    funcs = {
        "rhat_rank": lambda x: rhat(x, method="rank"),
        "rhat_raw": lambda x: rhat(x, method="identity"),
        "ess_bulk": lambda x: ess(x, method="bulk"),
        "ess_tail": lambda x: ess(x, method="tail"),
        "ess_mean": lambda x: ess(x, method="mean"),
        "ess_sd": lambda x: ess(x, method="sd"),
        "ess_median": lambda x: ess(x, method="median"),
        "ess_raw": lambda x: ess(x, method="identity"),
        "ess_quantile01": lambda x: ess(x, method="quantile", prob=0.01),
        "ess_quantile10": lambda x: ess(x, method="quantile", prob=0.1),
        "ess_quantile30": lambda x: ess(x, method="quantile", prob=0.3),
        "mcse_mean": lambda x: mcse(x, method="mean"),
        "mcse_sd": lambda x: mcse(x, method="sd"),
        "mcse_median": lambda x: mcse(x, method="median"),
        "mcse_quantile01": lambda x: mcse(x, method="quantile", prob=0.01),
        "mcse_quantile10": lambda x: mcse(x, method="quantile", prob=0.1),
        "mcse_quantile30": lambda x: mcse(x, method="quantile", prob=0.3),
    }
    results = {}
    for key, coord_dict, _, vals in xarray_var_iter(posterior.posterior, combined=True):
        if coord_dict:
            # Build "name.<index + 1>" labels so rows line up with the reference
            # table's index (1-based element numbering).
            key = f"{key}.{list(coord_dict.values())[0] + 1}"
        results[key] = {func_name: func(vals) for func_name, func in funcs.items()}
    arviz_data = pd.DataFrame.from_dict(results).T.sort_index(axis=1).sort_index(axis=0)

    # check column names
    assert set(arviz_data.columns) == set(reference.columns)

    # check parameter names
    assert set(arviz_data.index) == set(reference.index)

    abs_diff = abs(reference - arviz_data)

    # Show the worst-case differences with pytest's '-s' flag. Use a scoped pandas
    # option: np.set_printoptions would not affect this pandas output and would leak
    # numpy's print precision into every later test.
    with pd.option_context("display.precision", 16):
        print(abs_diff.max())

    # test absolute accuracy
    assert (abs_diff.values < 1e-8).all(None)
|
|
179
|
-
|
|
180
|
-
@pytest.mark.parametrize("method", ("rank", "split", "folded", "z_scale", "identity"))
@pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"]))
def test_rhat(self, data, var_names, method):
    """Confirm R-hat statistic is close to 1 for a large
    number of samples. Also checks the correct shape"""
    rhat_data = rhat(data, var_names=var_names, method=method)
    for r_hat in rhat_data.data_vars.values():
        values = r_hat.values
        # Every value must lie strictly inside (1 / GOOD_RHAT, GOOD_RHAT). Joining the
        # two bounds with ``|`` would accept any non-NaN number, since every real
        # value is above the lower bound or below the upper one.
        assert ((1 / GOOD_RHAT < values) & (values < GOOD_RHAT)).all()

    # In None case check that all varnames from rhat_data match input data
    if var_names is None:
        assert list(rhat_data.data_vars) == list(data.data_vars)
|
|
192
|
-
|
|
193
|
-
@pytest.mark.parametrize("method", ("rank", "split", "folded", "z_scale", "identity"))
def test_rhat_nan(self, method):
    """Confirm R-hat statistic returns nan."""
    samples = np.random.randn(4, 100)
    samples[0, 0] = np.nan  # pylint: disable=unsupported-assignment-operation
    rhat_value = rhat(samples, method=method)
    assert np.isnan(rhat_value)
    if method == "rank":
        # Exercise the private helper on the same NaN-containing samples. Passing
        # the already-NaN result (as before) made this check vacuous.
        assert np.isnan(_rhat(samples))
|
|
202
|
-
|
|
203
|
-
@pytest.mark.parametrize("method", ("rank", "split", "folded", "z_scale", "identity"))
@pytest.mark.parametrize("chain", (None, 1, 2))
@pytest.mark.parametrize("draw", (1, 2, 3, 4))
def test_rhat_shape(self, method, chain, draw):
    """R-hat is NaN for a single chain or fewer than four draws, and not NaN otherwise."""
    samples = np.random.randn(draw) if chain is None else np.random.randn(chain, draw)
    rhat_value = rhat(samples, method=method)
    expect_nan = chain in (None, 1) or draw < 4
    if expect_nan:
        assert np.isnan(rhat_value)
    else:
        assert not np.isnan(rhat_value)
|
|
215
|
-
|
|
216
|
-
def test_rhat_bad(self):
    """R-hat falls outside (1 / GOOD_RHAT, GOOD_RHAT) when two chains have very different means."""
    shifted_chain = 20 + np.random.randn(1, 100)
    centered_chain = np.random.randn(1, 100)
    r_hat = rhat(np.vstack([shifted_chain, centered_chain]))
    assert r_hat < 1 / GOOD_RHAT or r_hat > GOOD_RHAT
|
|
221
|
-
|
|
222
|
-
def test_rhat_bad_method(self):
    """An unrecognised ``method`` name is rejected with TypeError."""
    samples = np.random.randn(2, 300)
    with pytest.raises(TypeError):
        rhat(samples, method="wrong_method")
|
|
225
|
-
|
|
226
|
-
def test_rhat_ndarray(self):
    """A three-dimensional ndarray input makes rhat raise TypeError."""
    three_dim = np.random.randn(2, 300, 10)
    with pytest.raises(TypeError):
        rhat(three_dim)
|
|
229
|
-
|
|
230
|
-
@pytest.mark.parametrize(
    "method",
    (
        "bulk",
        "tail",
        "quantile",
        "local",
        "mean",
        "sd",
        "median",
        "mad",
        "z_scale",
        "folded",
        "identity",
    ),
)
@pytest.mark.parametrize("relative", (True, False))
def test_effective_sample_size_array(self, data, method, relative):
    """ESS of 4 x 100 standard-normal draws lies in a plausible range for every method.

    The bounds are 100 and 800 draws, or those values divided by the 400 total
    draws when ``relative`` is true.
    """
    n_low = 100 / 400 if relative else 100
    n_high = 800 / 400 if relative else 800

    def assert_plausible(ess_value):
        assert ess_value > n_low
        assert ess_value < n_high

    if method in ("quantile", "tail"):
        # Dataset input: the array bounds above were sized for 4 x 100 draws, and
        # asserting a Dataset-vs-float comparison only tests that the Dataset is
        # non-empty. Check structurally that a result came back for every variable.
        ess_data = ess(data, method=method, prob=0.34, relative=relative)
        assert set(ess_data.data_vars) == set(data.data_vars)
        ess_hat = ess(np.random.randn(4, 100), method=method, prob=0.34, relative=relative)
        if method == "tail":
            assert_plausible(ess_hat)
            ess_hat = ess(np.random.randn(4, 100), method=method, relative=relative)
            assert_plausible(ess_hat)
            ess_hat = ess(
                np.random.randn(4, 100), method=method, prob=(0.2, 0.8), relative=relative
            )
    elif method == "local":
        ess_hat = ess(
            np.random.randn(4, 100), method=method, prob=(0.2, 0.3), relative=relative
        )
    else:
        ess_hat = ess(np.random.randn(4, 100), method=method, relative=relative)
    assert_plausible(ess_hat)
|
|
269
|
-
|
|
270
|
-
@pytest.mark.parametrize(
    "method",
    (
        "bulk",
        "tail",
        "quantile",
        "local",
        "mean",
        "sd",
        "median",
        "mad",
        "z_scale",
        "folded",
        "identity",
    ),
)
@pytest.mark.parametrize("relative", (True, False))
@pytest.mark.parametrize("chain", (None, 1, 2))
@pytest.mark.parametrize("draw", (1, 2, 3, 4))
@pytest.mark.parametrize("use_nan", (True, False))
def test_effective_sample_size_nan(self, method, relative, chain, draw, use_nan):
    """ESS is NaN with fewer than four draws or when a NaN is present, and finite-checked otherwise."""
    shape = (draw,) if chain is None else (chain, draw)
    samples = np.random.randn(*shape)
    if use_nan:
        samples[0] = np.nan

    # Quantile-type methods need ``prob``; "local" needs a (low, high) pair.
    extra = {}
    if method in ("quantile", "tail"):
        extra["prob"] = 0.34
    elif method == "local":
        extra["prob"] = (0.2, 0.3)
    ess_value = ess(samples, method=method, relative=relative, **extra)

    if draw < 4 or use_nan:
        assert np.isnan(ess_value)
    else:
        assert not np.isnan(ess_value)

    # Exercise the private helper for just one parameter combination.
    run_private_check = method == "bulk" and not relative and chain is None and draw == 4
    if run_private_check:
        if use_nan:
            assert np.isnan(_ess(samples))
        else:
            assert not np.isnan(_ess(samples))
|
|
310
|
-
|
|
311
|
-
@pytest.mark.parametrize("relative", (True, False))
def test_effective_sample_size_missing_prob(self, relative):
    """Quantile and local ESS called without ``prob`` raise TypeError."""
    calls = (
        lambda: ess(np.random.randn(4, 100), method="quantile", relative=relative),
        lambda: _ess_quantile(np.random.randn(4, 100), prob=None, relative=relative),
        lambda: ess(np.random.randn(4, 100), method="local", relative=relative),
    )
    for call in calls:
        with pytest.raises(TypeError):
            call()
|
|
319
|
-
|
|
320
|
-
@pytest.mark.parametrize("relative", (True, False))
def test_effective_sample_size_too_many_probs(self, relative):
    """Passing three probabilities to local ESS raises ValueError."""
    three_probs = [0.1, 0.2, 0.9]
    with pytest.raises(ValueError):
        ess(np.random.randn(4, 100), method="local", prob=three_probs, relative=relative)
|
|
324
|
-
|
|
325
|
-
def test_effective_sample_size_constant(self):
    """For constant draws the ESS equals the total number of draws (4 * 100)."""
    constant_draws = np.ones((4, 100))
    assert ess(constant_draws) == constant_draws.size
|
|
327
|
-
|
|
328
|
-
def test_effective_sample_size_bad_method(self):
    """An unrecognised ESS ``method`` raises TypeError."""
    samples = np.random.randn(4, 100)
    with pytest.raises(TypeError):
        ess(samples, method="wrong_method")
|
|
331
|
-
|
|
332
|
-
def test_effective_sample_size_ndarray(self):
    """A three-dimensional ndarray input makes ess raise TypeError."""
    three_dim = np.random.randn(2, 300, 10)
    with pytest.raises(TypeError):
        ess(three_dim)
|
|
335
|
-
|
|
336
|
-
@pytest.mark.parametrize(
    "method",
    (
        "bulk",
        "tail",
        "quantile",
        "local",
        "mean",
        "sd",
        "median",
        "mad",
        "z_scale",
        "folded",
        "identity",
    ),
)
@pytest.mark.parametrize("relative", (True, False))
@pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"]))
def test_effective_sample_size_dataset(self, data, method, var_names, relative):
    """ESS of ``mu`` in the example posterior exceeds 100 draws (or that share of all draws)."""
    total_draws = data.chain.size * data.draw.size
    n_low = 100 / total_draws if relative else 100

    options = {"var_names": var_names, "method": method, "relative": relative}
    if method in ("quantile", "tail"):
        options["prob"] = 0.34
    elif method == "local":
        options["prob"] = (0.2, 0.3)
    ess_hat = ess(data, **options)

    # This might break if the data is regenerated
    assert np.all(ess_hat.mu.values > n_low)
|
|
365
|
-
|
|
366
|
-
@pytest.mark.parametrize("mcse_method", ("mean", "sd", "median", "quantile"))
def test_mcse_array(self, mcse_method):
    """MCSE of 4 x 100 random draws is truthy (nonzero) for every method."""
    extra = {"prob": 0.34} if mcse_method == "quantile" else {}
    mcse_hat = mcse(np.random.randn(4, 100), method=mcse_method, **extra)
    assert mcse_hat
|
|
373
|
-
|
|
374
|
-
def test_mcse_ndarray(self):
    """A three-dimensional ndarray input makes mcse raise TypeError."""
    three_dim = np.random.randn(2, 300, 10)
    with pytest.raises(TypeError):
        mcse(three_dim)
|
|
377
|
-
|
|
378
|
-
@pytest.mark.parametrize("mcse_method", ("mean", "sd", "median", "quantile"))
@pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"]))
def test_mcse_dataset(self, data, mcse_method, var_names):
    """MCSE of the example posterior returns a truthy result for every method."""
    extra = {"prob": 0.34} if mcse_method == "quantile" else {}
    mcse_hat = mcse(data, var_names=var_names, method=mcse_method, **extra)
    # NOTE(review): for an xarray Dataset this truthiness check presumably only
    # verifies that some variables were returned, not their values -- confirm.
    assert mcse_hat  # This might break if the data is regenerated
|
|
386
|
-
|
|
387
|
-
@pytest.mark.parametrize("mcse_method", ("mean", "sd", "median", "quantile"))
@pytest.mark.parametrize("chain", (None, 1, 2))
@pytest.mark.parametrize("draw", (1, 2, 3, 4))
@pytest.mark.parametrize("use_nan", (True, False))
def test_mcse_nan(self, mcse_method, chain, draw, use_nan):
    """MCSE is NaN with fewer than four draws or when a NaN is present, and not NaN otherwise."""
    shape = (draw,) if chain is None else (chain, draw)
    samples = np.random.randn(*shape)
    if use_nan:
        samples[0] = np.nan

    extra = {"prob": 0.34} if mcse_method == "quantile" else {}
    mcse_hat = mcse(samples, method=mcse_method, **extra)

    expect_nan = draw < 4 or use_nan
    if expect_nan:
        assert np.isnan(mcse_hat)
    else:
        assert not np.isnan(mcse_hat)
|
|
403
|
-
|
|
404
|
-
@pytest.mark.parametrize("method", ("wrong_method", "quantile"))
def test_mcse_bad_method(self, data, method):
    """An unknown method, or "quantile" with ``prob=None``, raises TypeError."""
    options = {"method": method, "prob": None}
    with pytest.raises(TypeError):
        mcse(data, **options)
|
|
408
|
-
|
|
409
|
-
@pytest.mark.parametrize("draws", (3, 4, 100))
@pytest.mark.parametrize("chains", (None, 1, 2))
def test_multichain_summary_array(self, draws, chains):
    """Test multichain statistics against the individual diagnostic functions.

    The test assumes ``_multichain_statistics`` returns, in order: mcse of
    the mean, mcse of the sd, bulk ESS, tail ESS and rank-normalized R-hat.
    """
    shape = (draws,) if chains is None else (chains, draws)
    ary = np.random.randn(*shape)

    individual = (
        mcse(ary, method="mean"),
        mcse(ary, method="sd"),
        ess(ary, method="bulk"),
        ess(ary, method="tail"),
        _rhat_rank(ary),
    )
    # Explicit 5-way unpacking also checks the number of returned statistics.
    mcse_mean_c, mcse_sd_c, ess_bulk_c, ess_tail_c, rhat_c = _multichain_statistics(ary)
    combined = (mcse_mean_c, mcse_sd_c, ess_bulk_c, ess_tail_c, rhat_c)

    if draws == 3:
        # With only 3 draws every statistic is expected to be NaN.
        assert np.isnan(individual).all()
        assert np.isnan(combined).all()
        return

    for single, multi in zip(individual[:4], combined[:4]):
        assert_almost_equal(single, multi)

    rhat_single, rhat_multi = individual[4], combined[4]
    if chains in (None, 1):
        # A single chain (or 1-D input) is expected to give NaN R-hat.
        assert np.isnan(rhat_single)
        assert np.isnan(rhat_multi)
    else:
        assert round(rhat_single, 3) == round(rhat_multi, 3)
|
|
459
|
-
|
|
460
|
-
def test_ks_summary(self):
    """Check ``ks_summary`` on hand-made Pareto tail indices rather than psislw output.

    Each call is expected to emit a ``UserWarning`` and return a non-None summary.
    """
    fake_tail_indices = (
        np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2]),
        np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.6]),
    )
    for pareto_tail_indices in fake_tail_indices:
        with pytest.warns(UserWarning):
            summary = ks_summary(pareto_tail_indices)
        assert summary is not None
|
|
470
|
-
|
|
471
|
-
@pytest.mark.parametrize("size", [100, 101])
@pytest.mark.parametrize("batches", [1, 2, 3, 5, 7])
@pytest.mark.parametrize("ndim", [1, 2, 3])
@pytest.mark.parametrize("circular", [False, True])
def test_mc_error(self, size, batches, ndim, circular):
    """``_mc_error`` returns a result for every batch count and circular setting."""
    # squeeze() turns the (size, 1) case into a 1-D array of length size.
    samples = np.random.randn(size, ndim).squeeze()  # pylint: disable=no-member
    result = _mc_error(samples, batches=batches, circular=circular)
    assert result is not None
|
|
478
|
-
|
|
479
|
-
@pytest.mark.parametrize("size", [100, 101])
@pytest.mark.parametrize("ndim", [1, 2, 3])
def test_mc_error_nan(self, size, ndim):
    """A NaN in the first sample is expected to make ``_mc_error`` return NaN."""
    samples = np.random.randn(size, ndim).squeeze()  # pylint: disable=no-member
    # For ndim > 1 this sets the whole first row, so every column holds a NaN.
    samples[0] = np.nan
    nan_mask = np.isnan(_mc_error(samples))
    # Multi-column input is checked element-wise; 1-D input is checked directly.
    assert (nan_mask.all() if ndim != 1 else nan_mask)
|
|
488
|
-
|
|
489
|
-
@pytest.mark.parametrize("func", ("_mcse_quantile", "_z_scale"))
def test_nan_behaviour(self, func):
    """Check how ``_mcse_quantile`` and ``_z_scale`` react to a single NaN entry.

    ``_mcse_quantile`` is expected to return all-NaN output. The expected
    ``_z_scale`` output depends on the installed SciPy version: below
    1.10.0.dev0 it must contain no NaN at all, otherwise exactly one NaN.
    """
    data = np.random.randn(100, 4)
    data[0, 0] = np.nan  # pylint: disable=unsupported-assignment-operation
    if func == "_mcse_quantile":
        assert np.isnan(_mcse_quantile(data, 0.5)).all(None)
        return

    z_scaled = _z_scale(data)
    scipy_before_1_10 = packaging.version.parse(scipy.__version__) < packaging.version.parse(
        "1.10.0.dev0"
    )
    if scipy_before_1_10:
        # "no NaN anywhere" already implies "not all NaN", so a single check suffices.
        assert not np.isnan(z_scaled).any(None)
    else:
        assert np.isnan(z_scaled).sum() == 1
|
|
500
|
-
|
|
501
|
-
@pytest.mark.parametrize("chains", (None, 1, 2, 3))
@pytest.mark.parametrize("draws", (2, 3, 100, 101))
def test_split_chain_dims(self, chains, draws):
    """``_split_chains`` doubles the chain count and floor-halves the draw count."""
    data = np.random.randn(draws) if chains is None else np.random.randn(chains, draws)
    split_data = _split_chains(data)
    # The expected shape treats 1-D input as a single chain.
    chains = 1 if chains is None else chains
    assert split_data.shape == (chains * 2, draws // 2)
|