arviz 0.16.1__py3-none-any.whl → 0.17.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to their respective registries. It is provided for informational purposes only.
- arviz/__init__.py +1 -1
- arviz/data/inference_data.py +34 -7
- arviz/data/io_beanmachine.py +6 -1
- arviz/data/io_cmdstanpy.py +439 -50
- arviz/data/io_pyjags.py +5 -2
- arviz/data/io_pystan.py +1 -2
- arviz/labels.py +2 -0
- arviz/plots/backends/bokeh/bpvplot.py +7 -2
- arviz/plots/backends/bokeh/compareplot.py +7 -4
- arviz/plots/backends/bokeh/densityplot.py +0 -1
- arviz/plots/backends/bokeh/distplot.py +0 -2
- arviz/plots/backends/bokeh/forestplot.py +3 -5
- arviz/plots/backends/bokeh/kdeplot.py +0 -2
- arviz/plots/backends/bokeh/pairplot.py +0 -4
- arviz/plots/backends/matplotlib/bfplot.py +0 -1
- arviz/plots/backends/matplotlib/bpvplot.py +3 -3
- arviz/plots/backends/matplotlib/compareplot.py +1 -1
- arviz/plots/backends/matplotlib/dotplot.py +1 -1
- arviz/plots/backends/matplotlib/forestplot.py +2 -4
- arviz/plots/backends/matplotlib/kdeplot.py +0 -1
- arviz/plots/backends/matplotlib/khatplot.py +0 -1
- arviz/plots/backends/matplotlib/lmplot.py +4 -5
- arviz/plots/backends/matplotlib/pairplot.py +0 -1
- arviz/plots/backends/matplotlib/ppcplot.py +8 -5
- arviz/plots/backends/matplotlib/traceplot.py +1 -2
- arviz/plots/bfplot.py +7 -6
- arviz/plots/bpvplot.py +7 -2
- arviz/plots/compareplot.py +2 -2
- arviz/plots/ecdfplot.py +37 -112
- arviz/plots/elpdplot.py +1 -1
- arviz/plots/essplot.py +2 -2
- arviz/plots/kdeplot.py +0 -1
- arviz/plots/pairplot.py +1 -1
- arviz/plots/plot_utils.py +0 -1
- arviz/plots/ppcplot.py +51 -45
- arviz/plots/separationplot.py +0 -1
- arviz/stats/__init__.py +2 -0
- arviz/stats/density_utils.py +2 -2
- arviz/stats/diagnostics.py +2 -3
- arviz/stats/ecdf_utils.py +165 -0
- arviz/stats/stats.py +241 -38
- arviz/stats/stats_utils.py +36 -7
- arviz/tests/base_tests/test_data.py +73 -5
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1
- arviz/tests/base_tests/test_plots_matplotlib.py +24 -1
- arviz/tests/base_tests/test_stats.py +43 -1
- arviz/tests/base_tests/test_stats_ecdf_utils.py +153 -0
- arviz/tests/base_tests/test_stats_utils.py +3 -3
- arviz/tests/external_tests/test_data_beanmachine.py +2 -0
- arviz/tests/external_tests/test_data_numpyro.py +3 -3
- arviz/tests/external_tests/test_data_pyjags.py +3 -1
- arviz/tests/external_tests/test_data_pyro.py +3 -3
- arviz/tests/helpers.py +8 -8
- arviz/utils.py +15 -7
- arviz/wrappers/wrap_pymc.py +1 -1
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/METADATA +16 -15
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/RECORD +60 -58
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/WHEEL +1 -1
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/LICENSE +0 -0
- {arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/top_level.txt +0 -0
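Before the file-by-file hunks, the headline addition is worth flagging: a new `arviz/stats/ecdf_utils.py` module (+165 lines) that backs the rewritten ECDF plotting code (note the +37 -112 churn in `arviz/plots/ecdfplot.py`), exercised by the new test file reproduced below. As orientation only, a minimal usage sketch — assuming the released 0.17 `plot_ecdf` API with its `cdf` and `confidence_bands` arguments; this code is not part of the diff itself:

import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import arviz as az

# Compare a sample's ECDF against a reference distribution; in 0.17 the
# confidence band is computed by the new stats/ecdf_utils.py module.
rng = np.random.default_rng(0)
sample = rng.normal(3, 10, size=1000)
az.plot_ecdf(sample, cdf=stats.norm(3, 10).cdf, confidence_bands=True)
plt.show()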
arviz/tests/base_tests/test_stats_ecdf_utils.py
ADDED

@@ -0,0 +1,153 @@
+import pytest
+
+import numpy as np
+import scipy.stats
+from ...stats.ecdf_utils import (
+    compute_ecdf,
+    ecdf_confidence_band,
+    _get_ecdf_points,
+    _simulate_ecdf,
+    _get_pointwise_confidence_band,
+)
+
+
+def test_compute_ecdf():
+    """Test compute_ecdf function."""
+    sample = np.array([1, 2, 3, 3, 4, 5])
+    eval_points = np.arange(0, 7, 0.1)
+    ecdf_expected = (sample[:, None] <= eval_points).mean(axis=0)
+    assert np.allclose(compute_ecdf(sample, eval_points), ecdf_expected)
+    assert np.allclose(compute_ecdf(sample / 2 + 10, eval_points / 2 + 10), ecdf_expected)
+
+
+@pytest.mark.parametrize("difference", [True, False])
+def test_get_ecdf_points(difference):
+    """Test _get_ecdf_points."""
+    # if first point already outside support, no need to insert it
+    sample = np.array([1, 2, 3, 3, 4, 5, 5])
+    eval_points = np.arange(-1, 7, 0.1)
+    x, y = _get_ecdf_points(sample, eval_points, difference)
+    assert np.array_equal(x, eval_points)
+    assert np.array_equal(y, compute_ecdf(sample, eval_points))
+
+    # if first point is inside support, insert it if not in difference mode
+    eval_points = np.arange(1, 6, 0.1)
+    x, y = _get_ecdf_points(sample, eval_points, difference)
+    assert len(x) == len(eval_points) + 1 - difference
+    assert len(y) == len(eval_points) + 1 - difference
+
+    # if not in difference mode, first point should be (eval_points[0], 0)
+    if not difference:
+        assert x[0] == eval_points[0]
+        assert y[0] == 0
+        assert np.allclose(x[1:], eval_points)
+        assert np.allclose(y[1:], compute_ecdf(sample, eval_points))
+        assert x[-1] == eval_points[-1]
+        assert y[-1] == 1
+
+
+@pytest.mark.parametrize(
+    "dist", [scipy.stats.norm(3, 10), scipy.stats.binom(10, 0.5)], ids=["continuous", "discrete"]
+)
+@pytest.mark.parametrize("seed", [32, 87])
+def test_simulate_ecdf(dist, seed):
+    """Test _simulate_ecdf."""
+    ndraws = 1000
+    eval_points = np.arange(0, 1, 0.1)
+
+    rvs = dist.rvs
+
+    random_state = np.random.default_rng(seed)
+    ecdf = _simulate_ecdf(ndraws, eval_points, rvs, random_state=random_state)
+    random_state = np.random.default_rng(seed)
+    ecdf_expected = compute_ecdf(np.sort(rvs(ndraws, random_state=random_state)), eval_points)
+
+    assert np.allclose(ecdf, ecdf_expected)
+
+
+@pytest.mark.parametrize("prob", [0.8, 0.9])
+@pytest.mark.parametrize(
+    "dist", [scipy.stats.norm(3, 10), scipy.stats.poisson(100)], ids=["continuous", "discrete"]
+)
+@pytest.mark.parametrize("ndraws", [10_000])
+def test_get_pointwise_confidence_band(dist, prob, ndraws, num_trials=1_000, seed=57):
+    """Test _get_pointwise_confidence_band."""
+    eval_points = np.linspace(*dist.interval(0.99), 10)
+    cdf_at_eval_points = dist.cdf(eval_points)
+
+    ecdf_lower, ecdf_upper = _get_pointwise_confidence_band(prob, ndraws, cdf_at_eval_points)
+
+    # check basic properties
+    assert np.all(ecdf_lower >= 0)
+    assert np.all(ecdf_upper <= 1)
+    assert np.all(ecdf_lower <= ecdf_upper)
+
+    # use simulation to estimate lower and upper bounds on pointwise probability
+    in_interval = []
+    random_state = np.random.default_rng(seed)
+    for _ in range(num_trials):
+        ecdf = _simulate_ecdf(ndraws, eval_points, dist.rvs, random_state=random_state)
+        in_interval.append((ecdf_lower <= ecdf) & (ecdf < ecdf_upper))
+    asymptotic_dist = scipy.stats.norm(
+        np.mean(in_interval, axis=0), scipy.stats.sem(in_interval, axis=0)
+    )
+    prob_lower, prob_upper = asymptotic_dist.interval(0.999)
+
+    # check target probability within all bounds
+    assert np.all(prob_lower <= prob)
+    assert np.all(prob <= prob_upper)
+
+
+@pytest.mark.parametrize("prob", [0.8, 0.9])
+@pytest.mark.parametrize(
+    "dist, rvs",
+    [
+        (scipy.stats.norm(3, 10), scipy.stats.norm(3, 10).rvs),
+        (scipy.stats.norm(3, 10), None),
+        (scipy.stats.poisson(100), scipy.stats.poisson(100).rvs),
+    ],
+    ids=["continuous", "continuous default rvs", "discrete"],
+)
+@pytest.mark.parametrize("ndraws", [10_000])
+@pytest.mark.parametrize("method", ["pointwise", "simulated"])
+def test_ecdf_confidence_band(dist, rvs, prob, ndraws, method, num_trials=1_000, seed=57):
+    """Test ecdf_confidence_band."""
+    eval_points = np.linspace(*dist.interval(0.99), 10)
+    cdf_at_eval_points = dist.cdf(eval_points)
+    random_state = np.random.default_rng(seed)
+
+    ecdf_lower, ecdf_upper = ecdf_confidence_band(
+        ndraws,
+        eval_points,
+        cdf_at_eval_points,
+        prob=prob,
+        rvs=rvs,
+        random_state=random_state,
+        method=method,
+    )
+
+    if method == "pointwise":
+        # these values tested elsewhere, we just make sure they're the same
+        ecdf_lower_pointwise, ecdf_upper_pointwise = _get_pointwise_confidence_band(
+            prob, ndraws, cdf_at_eval_points
+        )
+        assert np.array_equal(ecdf_lower, ecdf_lower_pointwise)
+        assert np.array_equal(ecdf_upper, ecdf_upper_pointwise)
+        return
+
+    # check basic properties
+    assert np.all(ecdf_lower >= 0)
+    assert np.all(ecdf_upper <= 1)
+    assert np.all(ecdf_lower <= ecdf_upper)
+
+    # use simulation to estimate lower and upper bounds on simultaneous probability
+    in_envelope = []
+    random_state = np.random.default_rng(seed)
+    for _ in range(num_trials):
+        ecdf = _simulate_ecdf(ndraws, eval_points, dist.rvs, random_state=random_state)
+        in_envelope.append(np.all(ecdf_lower <= ecdf) & np.all(ecdf < ecdf_upper))
+    asymptotic_dist = scipy.stats.norm(np.mean(in_envelope), scipy.stats.sem(in_envelope))
+    prob_lower, prob_upper = asymptotic_dist.interval(0.999)
+
+    # check target probability within bounds
+    assert prob_lower <= prob <= prob_upper
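For readers skimming these tests: a pointwise band constrains the ECDF separately at each evaluation point, while the simulated band is calibrated so the entire ECDF stays inside the envelope with probability `prob`. The pointwise construction can be sketched from binomial quantiles — an illustration of the idea the tests verify, not necessarily the exact code in `ecdf_utils.py`:

import numpy as np
from scipy.stats import binom

def pointwise_band(prob, ndraws, cdf_at_eval_points):
    """Sketch: pointwise ECDF confidence band from binomial quantiles."""
    # At a fixed point t, ndraws * ECDF(t) ~ Binomial(ndraws, CDF(t)) for an
    # i.i.d. sample, so central binomial quantiles give pointwise coverage.
    lower, upper = binom.interval(prob, ndraws, cdf_at_eval_points)
    return lower / ndraws, upper / ndraws

# e.g. a 90% pointwise band for 10_000 draws at three evaluation points
low, high = pointwise_band(0.9, 10_000, np.array([0.1, 0.5, 0.9]))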
arviz/tests/base_tests/test_stats_utils.py
CHANGED

@@ -344,9 +344,9 @@ def test_variance_bad_data():
 
 def test_histogram():
     school = load_arviz_data("non_centered_eight").posterior["mu"].values
-    k_count_az, k_dens_az, _ = histogram(school, bins=np.asarray([-np.
-    k_dens_np, *_ = np.histogram(school, bins=[-np.
-    k_count_np, *_ = np.histogram(school, bins=[-np.
+    k_count_az, k_dens_az, _ = histogram(school, bins=np.asarray([-np.inf, 0.5, 0.7, 1, np.inf]))
+    k_dens_np, *_ = np.histogram(school, bins=[-np.inf, 0.5, 0.7, 1, np.inf], density=True)
+    k_count_np, *_ = np.histogram(school, bins=[-np.inf, 0.5, 0.7, 1, np.inf], density=False)
     assert np.allclose(k_count_az, k_count_np)
     assert np.allclose(k_dens_az, k_dens_np)
 
arviz/tests/external_tests/test_data_beanmachine.py
CHANGED

@@ -11,6 +11,8 @@ from ..helpers import (  # pylint: disable=unused-import, wrong-import-position
     load_cached_models,
 )
 
+pytest.skip("Ignore beanmachine tests until it supports pytorch 2", allow_module_level=True)
+
 # Skip all tests if beanmachine or pytorch not installed
 torch = importorskip("torch")
 bm = importorskip("beanmachine.ppl")
arviz/tests/external_tests/test_data_numpyro.py
CHANGED

@@ -101,8 +101,8 @@ class TestDataNumPyro:
         assert not fails
 
         # test dims
-        dims = inference_data.posterior_predictive.dims["school"]
-        pred_dims = inference_data.predictions.dims["school_pred"]
+        dims = inference_data.posterior_predictive.sizes["school"]
+        pred_dims = inference_data.predictions.sizes["school_pred"]
         assert dims == 8
         assert pred_dims == 8
 

@@ -240,7 +240,7 @@ class TestDataNumPyro:
     def test_inference_data_num_chains(self, predictions_data, chains):
         predictions = predictions_data
         inference_data = from_numpyro(predictions=predictions, num_chains=chains)
-        nchains = inference_data.predictions.dims["chain"]
+        nchains = inference_data.predictions.sizes["chain"]
         assert nchains == chains
 
     @pytest.mark.parametrize("nchains", [1, 2])
arviz/tests/external_tests/test_data_pyjags.py
CHANGED

@@ -1,4 +1,5 @@
 # pylint: disable=no-member, invalid-name, redefined-outer-name, unused-import
+import sys
 import typing as tp
 
 import numpy as np

@@ -12,6 +13,8 @@ from ...data.io_pyjags import (
 )
 from ..helpers import check_multiple_attrs, eight_schools_params
 
+pytest.skip("Uses deprecated numpy C-api", allow_module_level=True)
+
 PYJAGS_POSTERIOR_DICT = {
     "b": np.random.randn(3, 10, 3),
     "int": np.random.randn(1, 10, 3),

@@ -40,7 +43,6 @@ def verify_equality_of_numpy_values_dictionaries(
     return True
 
 
-@pytest.mark.skip("Crashes Python")
 class TestDataPyJAGSWithoutEstimation:
     def test_convert_pyjags_samples_dictionary_to_arviz_samples_dictionary(self):
         arviz_samples_dict_from_pyjags_samples_dict = _convert_pyjags_dict_to_arviz_dict(
arviz/tests/external_tests/test_data_pyro.py
CHANGED

@@ -83,8 +83,8 @@ class TestDataPyro:
         assert not fails
 
         # test dims
-        dims = inference_data.posterior_predictive.dims["school"]
-        pred_dims = inference_data.predictions.dims["school_pred"]
+        dims = inference_data.posterior_predictive.sizes["school"]
+        pred_dims = inference_data.predictions.sizes["school_pred"]
         assert dims == 8
         assert pred_dims == 8
 

@@ -225,7 +225,7 @@ class TestDataPyro:
     def test_inference_data_num_chains(self, predictions_data, chains):
         predictions = predictions_data
         inference_data = from_pyro(predictions=predictions, num_chains=chains)
-        nchains = inference_data.predictions.dims["chain"]
+        nchains = inference_data.predictions.sizes["chain"]
         assert nchains == chains
 
     @pytest.mark.parametrize("log_likelihood", [True, False])
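The recurring `.dims[...]` → `.sizes[...]` edit in the numpyro and pyro tests above tracks an upstream xarray deprecation: treating `Dataset.dims` as a name-to-length mapping now warns, and `.sizes` is the forward-compatible spelling. A minimal illustration with plain xarray (not arviz-specific):

import numpy as np
import xarray as xr

ds = xr.Dataset({"y_hat": (("chain", "draw", "school"), np.zeros((4, 100, 8)))})

# deprecated mapping-style lookup: ds.dims["school"]
n_schools = ds.sizes["school"]  # spelling used throughout the 0.17.1 tests
assert n_schools == 8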
arviz/tests/helpers.py
CHANGED

@@ -432,18 +432,18 @@ def pystan_noncentered_schools(data, draws, chains):
     schools_code = """
         data {
             int<lower=0> J;
-            real y[J];
-            real<lower=0> sigma[J];
+            array[J] real y;
+            array[J] real<lower=0> sigma;
         }
 
         parameters {
             real mu;
             real<lower=0> tau;
-            real eta[J];
+            array[J] real eta;
         }
 
         transformed parameters {
-            real theta[J];
+            array[J] real theta;
             for (j in 1:J)
                 theta[j] = mu + tau * eta[j];
         }

@@ -456,8 +456,8 @@ def pystan_noncentered_schools(data, draws, chains):
         }
 
         generated quantities {
-            real log_lik[J];
-            real y_hat[J];
+            array[J] real log_lik;
+            array[J] real y_hat;
             for (j in 1:J) {
                 log_lik[j] = normal_lpdf(y[j] | theta[j], sigma[j]);
                 y_hat[j] = normal_rng(theta[j], sigma[j]);

@@ -487,7 +487,7 @@ def pystan_noncentered_schools(data, draws, chains):
 
 
 def bm_schools_model(data, draws, chains):
-    import beanmachine.ppl as bm
+    import beanmachine.ppl as bm  # pylint: disable=import-error
     import torch
     import torch.distributions as dist
 

@@ -552,7 +552,7 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None):
         ("emcee", emcee_schools_model),
         ("pyro", pyro_noncentered_schools),
         ("numpyro", numpyro_schools_model),
-        ("beanmachine", bm_schools_model),
+        # ("beanmachine", bm_schools_model),  # ignore beanmachine until it supports torch>=2
     )
     data_directory = os.path.join(here, "saved_models")
     models = {}
arviz/utils.py
CHANGED

@@ -21,7 +21,7 @@ def _check_tilde_start(x):
     return bool(isinstance(x, str) and x.startswith("~"))
 
 
-def _var_names(var_names, data, filter_vars=None):
+def _var_names(var_names, data, filter_vars=None, errors="raise"):
     """Handle var_names input across arviz.
 
     Parameters

@@ -34,6 +34,8 @@ def _var_names(var_names, data, filter_vars=None):
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        `pandas.filter`.
+    errors: {"raise", "ignore"}, optional, default="raise"
+        Select either to raise or ignore the invalid names.
 
     Returns
     -------

@@ -44,6 +46,9 @@ def _var_names(var_names, data, filter_vars=None):
            f"'filter_vars' can only be None, 'like', or 'regex', got: '{filter_vars}'"
        )
 
+    if errors not in {"raise", "ignore"}:
+        raise ValueError(f"'errors' can only be 'raise', or 'ignore', got: '{errors}'")
+
     if var_names is not None:
         if isinstance(data, (list, tuple)):
             all_vars = []

@@ -66,14 +71,16 @@
            )
 
        try:
-            var_names = _subset_list(var_names, all_vars, filter_items=filter_vars, warn=False)
+            var_names = _subset_list(
+                var_names, all_vars, filter_items=filter_vars, warn=False, errors=errors
+            )
        except KeyError as err:
            msg = " ".join(("var names:", f"{err}", "in dataset"))
            raise KeyError(msg) from err
     return var_names
 
 
-def _subset_list(subset, whole_list, filter_items=None, warn=True):
+def _subset_list(subset, whole_list, filter_items=None, warn=True, errors="raise"):
     """Handle list subsetting (var_names, groups...) across arviz.
 
     Parameters

@@ -87,6 +94,8 @@ def _subset_list(subset, whole_list, filter_items=None, warn=True):
        names. If "like", interpret `subset` as substrings of the elements in
        `whole_list`. If "regex", interpret `subset` as regular expressions to match
        elements in `whole_list`. A la `pandas.filter`.
+    errors: {"raise", "ignore"}, optional, default="raise"
+        Select either to raise or ignore the invalid names.
 
     Returns
     -------

@@ -95,7 +104,6 @@ def _subset_list(subset, whole_list, filter_items=None, warn=True):
        and ``filter_items``.
     """
     if subset is not None:
-
         if isinstance(subset, str):
             subset = [subset]
 

@@ -142,7 +150,7 @@
            subset = [item for item in whole_list for name in subset if re.search(name, item)]
 
        existing_items = np.isin(subset, whole_list)
-        if not np.all(existing_items):
+        if not np.all(existing_items) and (errors == "raise"):
            raise KeyError(f"{np.array(subset)[~existing_items]} are not present")
 
     return subset

@@ -660,7 +668,8 @@ def _load_static_files():
     Clone from xarray.core.formatted_html_template.
     """
     return [
-        importlib.resources.files("arviz").joinpath(fname).read_text() for fname in STATIC_FILES
+        importlib.resources.files("arviz").joinpath(fname).read_text(encoding="utf-8")
+        for fname in STATIC_FILES
     ]
 
 

@@ -745,7 +754,6 @@ def conditional_dask(func):
 
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
-
        if not Dask.dask_flag:
            return func(*args, **kwargs)
        user_kwargs = kwargs.pop("dask_kwargs", None)
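To make the `errors` plumbing above concrete: `_var_names` forwards the flag to `_subset_list`, which now raises only when `errors="raise"`. A small behavioral sketch based on the hunks shown — these are private helpers, so treat the signatures as internal and subject to change:

from arviz.utils import _subset_list

names = ["mu", "tau", "theta"]

# Default: unknown names raise, as before 0.17.1.
try:
    _subset_list(["mu", "sigma"], names)
except KeyError as err:
    print("raised:", err)

# New: errors="ignore" suppresses the KeyError; per the diff, the requested
# subset is returned as-is and invalid names are left for the caller.
print(_subset_list(["mu", "sigma"], names, errors="ignore"))  # ['mu', 'sigma']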
arviz/wrappers/wrap_pymc.py
CHANGED

@@ -2,6 +2,7 @@
 """Base class for PyMC interface wrappers."""
 from .base import SamplingWrapper
 
+
 # pylint: disable=abstract-method
 class PyMCSamplingWrapper(SamplingWrapper):
     """PyMC (4.0+) sampling wrapper base class.

@@ -21,7 +22,6 @@ class PyMCSamplingWrapper(SamplingWrapper):
        import pymc  # pylint: disable=import-error
 
        with self.model:
-
            pymc.set_data(modified_observed_data)
            idata = pymc.sample(
                **self.sample_kwargs,
{arviz-0.16.1.dist-info → arviz-0.17.1.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: arviz
-Version: 0.16.1
+Version: 0.17.1
 Summary: Exploratory analysis of Bayesian models
 Home-page: http://github.com/arviz-devs/arviz
 Author: ArviZ Developers

@@ -15,30 +15,31 @@ Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
 Classifier: Topic :: Scientific/Engineering
 Classifier: Topic :: Scientific/Engineering :: Visualization
 Classifier: Topic :: Scientific/Engineering :: Mathematics
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
-Requires-Dist: setuptools
-Requires-Dist: matplotlib
-Requires-Dist: numpy
-Requires-Dist: scipy
+Requires-Dist: setuptools >=60.0.0
+Requires-Dist: matplotlib >=3.5
+Requires-Dist: numpy <2.0,>=1.22.0
+Requires-Dist: scipy >=1.8.0
 Requires-Dist: packaging
-Requires-Dist: pandas
-Requires-Dist: xarray
-Requires-Dist: h5netcdf
-Requires-Dist: typing-extensions
-Requires-Dist: xarray-einstats
+Requires-Dist: pandas >=1.4.0
+Requires-Dist: xarray >=0.21.0
+Requires-Dist: h5netcdf >=1.0.2
+Requires-Dist: typing-extensions >=4.1.0
+Requires-Dist: xarray-einstats >=0.3
 Provides-Extra: all
 Requires-Dist: numba ; extra == 'all'
 Requires-Dist: netcdf4 ; extra == 'all'
-Requires-Dist: bokeh
+Requires-Dist: bokeh <3.0,>=1.4.0 ; extra == 'all'
 Requires-Dist: contourpy ; extra == 'all'
 Requires-Dist: ujson ; extra == 'all'
 Requires-Dist: dask[distributed] ; extra == 'all'
-Requires-Dist: zarr
+Requires-Dist: zarr >=2.5.0 ; extra == 'all'
 Requires-Dist: xarray-datatree ; extra == 'all'

@@ -52,8 +53,7 @@ Requires-Dist: xarray-datatree ; extra == 'all'
 [](https://doi.org/10.21105/joss.01143) [](https://doi.org/10.5281/zenodo.2540945)
 [](https://numfocus.org)
 
-ArviZ (pronounced "AR-_vees_") is a Python package for exploratory analysis of Bayesian models.
-Includes functions for posterior analysis, data storage, model checking, comparison and diagnostics.
+ArviZ (pronounced "AR-_vees_") is a Python package for exploratory analysis of Bayesian models. It includes functions for posterior analysis, data storage, model checking, comparison and diagnostics.
 
 ### ArviZ in other languages
 ArviZ also has a Julia wrapper available [ArviZ.jl](https://julia.arviz.org/).

@@ -202,9 +202,10 @@ python setup.py install
 
 <a href="https://python.arviz.org/en/latest/examples/index.html">And more...</a>
 </div>
+
 ## Dependencies
 
-ArviZ is tested on Python 3.
+ArviZ is tested on Python 3.10, 3.11 and 3.12, and depends on NumPy, SciPy, xarray, and Matplotlib.
 
 
 ## Citation