arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. arviz/__init__.py +52 -357
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.1.dist-info/METADATA +0 -263
  184. arviz-0.23.1.dist-info/RECORD +0 -183
  185. arviz-0.23.1.dist-info/top_level.txt +0 -1
@@ -1,166 +0,0 @@
1
- import pytest
2
-
3
- import numpy as np
4
- import scipy.stats
5
- from ...stats.ecdf_utils import (
6
- compute_ecdf,
7
- ecdf_confidence_band,
8
- _get_ecdf_points,
9
- _simulate_ecdf,
10
- _get_pointwise_confidence_band,
11
- )
12
-
13
- try:
14
- import numba # pylint: disable=unused-import
15
-
16
- numba_options = [True, False] # pylint: disable=invalid-name
17
- except ImportError:
18
- numba_options = [False] # pylint: disable=invalid-name
19
-
20
-
21
- def test_compute_ecdf():
22
- """Test compute_ecdf function."""
23
- sample = np.array([1, 2, 3, 3, 4, 5])
24
- eval_points = np.arange(0, 7, 0.1)
25
- ecdf_expected = (sample[:, None] <= eval_points).mean(axis=0)
26
- assert np.allclose(compute_ecdf(sample, eval_points), ecdf_expected)
27
- assert np.allclose(compute_ecdf(sample / 2 + 10, eval_points / 2 + 10), ecdf_expected)
28
-
29
-
30
- @pytest.mark.parametrize("difference", [True, False])
31
- def test_get_ecdf_points(difference):
32
- """Test _get_ecdf_points."""
33
- # if first point already outside support, no need to insert it
34
- sample = np.array([1, 2, 3, 3, 4, 5, 5])
35
- eval_points = np.arange(-1, 7, 0.1)
36
- x, y = _get_ecdf_points(sample, eval_points, difference)
37
- assert np.array_equal(x, eval_points)
38
- assert np.array_equal(y, compute_ecdf(sample, eval_points))
39
-
40
- # if first point is inside support, insert it if not in difference mode
41
- eval_points = np.arange(1, 6, 0.1)
42
- x, y = _get_ecdf_points(sample, eval_points, difference)
43
- assert len(x) == len(eval_points) + 1 - difference
44
- assert len(y) == len(eval_points) + 1 - difference
45
-
46
- # if not in difference mode, first point should be (eval_points[0], 0)
47
- if not difference:
48
- assert x[0] == eval_points[0]
49
- assert y[0] == 0
50
- assert np.allclose(x[1:], eval_points)
51
- assert np.allclose(y[1:], compute_ecdf(sample, eval_points))
52
- assert x[-1] == eval_points[-1]
53
- assert y[-1] == 1
54
-
55
-
56
- @pytest.mark.parametrize(
57
- "dist", [scipy.stats.norm(3, 10), scipy.stats.binom(10, 0.5)], ids=["continuous", "discrete"]
58
- )
59
- @pytest.mark.parametrize("seed", [32, 87])
60
- def test_simulate_ecdf(dist, seed):
61
- """Test _simulate_ecdf."""
62
- ndraws = 1000
63
- eval_points = np.arange(0, 1, 0.1)
64
-
65
- rvs = dist.rvs
66
-
67
- random_state = np.random.default_rng(seed)
68
- ecdf = _simulate_ecdf(ndraws, eval_points, rvs, random_state=random_state)
69
- random_state = np.random.default_rng(seed)
70
- ecdf_expected = compute_ecdf(np.sort(rvs(ndraws, random_state=random_state)), eval_points)
71
-
72
- assert np.allclose(ecdf, ecdf_expected)
73
-
74
-
75
- @pytest.mark.parametrize("prob", [0.8, 0.9])
76
- @pytest.mark.parametrize(
77
- "dist", [scipy.stats.norm(3, 10), scipy.stats.poisson(100)], ids=["continuous", "discrete"]
78
- )
79
- @pytest.mark.parametrize("ndraws", [10_000])
80
- def test_get_pointwise_confidence_band(dist, prob, ndraws, num_trials=1_000, seed=57):
81
- """Test _get_pointwise_confidence_band."""
82
- eval_points = np.linspace(*dist.interval(0.99), 10)
83
- cdf_at_eval_points = dist.cdf(eval_points)
84
-
85
- ecdf_lower, ecdf_upper = _get_pointwise_confidence_band(prob, ndraws, cdf_at_eval_points)
86
-
87
- # check basic properties
88
- assert np.all(ecdf_lower >= 0)
89
- assert np.all(ecdf_upper <= 1)
90
- assert np.all(ecdf_lower <= ecdf_upper)
91
-
92
- # use simulation to estimate lower and upper bounds on pointwise probability
93
- in_interval = []
94
- random_state = np.random.default_rng(seed)
95
- for _ in range(num_trials):
96
- ecdf = _simulate_ecdf(ndraws, eval_points, dist.rvs, random_state=random_state)
97
- in_interval.append((ecdf_lower <= ecdf) & (ecdf < ecdf_upper))
98
- asymptotic_dist = scipy.stats.norm(
99
- np.mean(in_interval, axis=0), scipy.stats.sem(in_interval, axis=0)
100
- )
101
- prob_lower, prob_upper = asymptotic_dist.interval(0.999)
102
-
103
- # check target probability within all bounds
104
- assert np.all(prob_lower <= prob)
105
- assert np.all(prob <= prob_upper)
106
-
107
-
108
- @pytest.mark.parametrize("prob", [0.8, 0.9])
109
- @pytest.mark.parametrize(
110
- "dist, rvs",
111
- [
112
- (scipy.stats.norm(3, 10), scipy.stats.norm(3, 10).rvs),
113
- (scipy.stats.norm(3, 10), None),
114
- (scipy.stats.poisson(100), scipy.stats.poisson(100).rvs),
115
- ],
116
- ids=["continuous", "continuous default rvs", "discrete"],
117
- )
118
- @pytest.mark.parametrize("ndraws", [10_000])
119
- @pytest.mark.parametrize("method", ["pointwise", "optimized", "simulated"])
120
- @pytest.mark.parametrize("use_numba", numba_options)
121
- def test_ecdf_confidence_band(
122
- dist, rvs, prob, ndraws, method, use_numba, num_trials=1_000, seed=57
123
- ):
124
- """Test test_ecdf_confidence_band."""
125
- if use_numba and method != "optimized":
126
- pytest.skip("Numba only used in optimized method")
127
-
128
- eval_points = np.linspace(*dist.interval(0.99), 10)
129
- cdf_at_eval_points = dist.cdf(eval_points)
130
- random_state = np.random.default_rng(seed)
131
-
132
- ecdf_lower, ecdf_upper = ecdf_confidence_band(
133
- ndraws,
134
- eval_points,
135
- cdf_at_eval_points,
136
- prob=prob,
137
- rvs=rvs,
138
- random_state=random_state,
139
- method=method,
140
- )
141
-
142
- if method == "pointwise":
143
- # these values tested elsewhere, we just make sure they're the same
144
- ecdf_lower_pointwise, ecdf_upper_pointwise = _get_pointwise_confidence_band(
145
- prob, ndraws, cdf_at_eval_points
146
- )
147
- assert np.array_equal(ecdf_lower, ecdf_lower_pointwise)
148
- assert np.array_equal(ecdf_upper, ecdf_upper_pointwise)
149
- return
150
-
151
- # check basic properties
152
- assert np.all(ecdf_lower >= 0)
153
- assert np.all(ecdf_upper <= 1)
154
- assert np.all(ecdf_lower <= ecdf_upper)
155
-
156
- # use simulation to estimate lower and upper bounds on simultaneous probability
157
- in_envelope = []
158
- random_state = np.random.default_rng(seed)
159
- for _ in range(num_trials):
160
- ecdf = _simulate_ecdf(ndraws, eval_points, dist.rvs, random_state=random_state)
161
- in_envelope.append(np.all(ecdf_lower <= ecdf) & np.all(ecdf < ecdf_upper))
162
- asymptotic_dist = scipy.stats.norm(np.mean(in_envelope), scipy.stats.sem(in_envelope))
163
- prob_lower, prob_upper = asymptotic_dist.interval(0.999)
164
-
165
- # check target probability within bounds
166
- assert prob_lower <= prob <= prob_upper
@@ -1,45 +0,0 @@
1
- # pylint: disable=redefined-outer-name, no-member
2
- import numpy as np
3
- import pytest
4
-
5
- from ...rcparams import rcParams
6
- from ...stats import r2_score, summary
7
- from ...utils import Numba
8
- from ..helpers import ( # pylint: disable=unused-import
9
- check_multiple_attrs,
10
- importorskip,
11
- multidim_models,
12
- )
13
- from .test_stats import centered_eight, non_centered_eight # pylint: disable=unused-import
14
-
15
- numba = importorskip("numba")
16
-
17
- rcParams["data.load"] = "eager"
18
-
19
-
20
- @pytest.mark.parametrize("circ_var_names", [["mu"], None])
21
- def test_summary_circ_vars(centered_eight, circ_var_names):
22
- assert summary(centered_eight, circ_var_names=circ_var_names) is not None
23
- state = Numba.numba_flag
24
- Numba.disable_numba()
25
- assert summary(centered_eight, circ_var_names=circ_var_names) is not NotImplementedError
26
- Numba.enable_numba()
27
- assert state == Numba.numba_flag
28
-
29
-
30
- def test_numba_stats():
31
- """Numba test for r2_score"""
32
- state = Numba.numba_flag # Store the current state of Numba
33
- set_1 = np.random.randn(100, 100)
34
- set_2 = np.random.randn(100, 100)
35
- set_3 = np.random.rand(100)
36
- set_4 = np.random.rand(100)
37
- Numba.disable_numba()
38
- non_numba = r2_score(set_1, set_2)
39
- non_numba_one_dimensional = r2_score(set_3, set_4)
40
- Numba.enable_numba()
41
- with_numba = r2_score(set_1, set_2)
42
- with_numba_one_dimensional = r2_score(set_3, set_4)
43
- assert state == Numba.numba_flag # Ensure that initial state = final state
44
- assert np.allclose(non_numba, with_numba)
45
- assert np.allclose(non_numba_one_dimensional, with_numba_one_dimensional)
@@ -1,384 +0,0 @@
1
- """Tests for stats_utils."""
2
-
3
- # pylint: disable=no-member
4
- import numpy as np
5
- import pytest
6
- from numpy.testing import assert_array_almost_equal
7
- from scipy.special import logsumexp
8
- from scipy.stats import circstd
9
-
10
- from ...data import from_dict, load_arviz_data
11
- from ...stats.density_utils import histogram
12
- from ...stats.stats_utils import (
13
- ELPDData,
14
- _angle,
15
- _circfunc,
16
- _circular_standard_deviation,
17
- _sqrt,
18
- get_log_likelihood,
19
- )
20
- from ...stats.stats_utils import logsumexp as _logsumexp
21
- from ...stats.stats_utils import make_ufunc, not_valid, stats_variance_2d, wrap_xarray_ufunc
22
-
23
-
24
- @pytest.mark.parametrize("ary_dtype", [np.float64, np.float32, np.int32, np.int64])
25
- @pytest.mark.parametrize("axis", [None, 0, 1, (-2, -1)])
26
- @pytest.mark.parametrize("b", [None, 0, 1 / 100, 1 / 101])
27
- @pytest.mark.parametrize("keepdims", [True, False])
28
- def test_logsumexp_b(ary_dtype, axis, b, keepdims):
29
- """Test ArviZ implementation of logsumexp.
30
-
31
- Test also compares against Scipy implementation.
32
- Case where b=None, they are equal. (N=len(ary))
33
- Second case where b=x, and x is 1/(number of elements), they are almost equal.
34
-
35
- Test tests against b parameter.
36
- """
37
- ary = np.random.randn(100, 101).astype(ary_dtype) # pylint: disable=no-member
38
- assert _logsumexp(ary=ary, axis=axis, b=b, keepdims=keepdims, copy=True) is not None
39
- ary = ary.copy()
40
- assert _logsumexp(ary=ary, axis=axis, b=b, keepdims=keepdims, copy=False) is not None
41
- out = np.empty(5)
42
- assert _logsumexp(ary=np.random.randn(10, 5), axis=0, out=out) is not None
43
-
44
- # Scipy implementation
45
- scipy_results = logsumexp(ary, b=b, axis=axis, keepdims=keepdims)
46
- arviz_results = _logsumexp(ary, b=b, axis=axis, keepdims=keepdims)
47
-
48
- assert_array_almost_equal(scipy_results, arviz_results)
49
-
50
-
51
- @pytest.mark.parametrize("ary_dtype", [np.float64, np.float32, np.int32, np.int64])
52
- @pytest.mark.parametrize("axis", [None, 0, 1, (-2, -1)])
53
- @pytest.mark.parametrize("b_inv", [None, 0, 100, 101])
54
- @pytest.mark.parametrize("keepdims", [True, False])
55
- def test_logsumexp_b_inv(ary_dtype, axis, b_inv, keepdims):
56
- """Test ArviZ implementation of logsumexp.
57
-
58
- Test also compares against Scipy implementation.
59
- Case where b=None, they are equal. (N=len(ary))
60
- Second case where b=x, and x is 1/(number of elements), they are almost equal.
61
-
62
- Test tests against b_inv parameter.
63
- """
64
- ary = np.random.randn(100, 101).astype(ary_dtype) # pylint: disable=no-member
65
- assert _logsumexp(ary=ary, axis=axis, b_inv=b_inv, keepdims=keepdims, copy=True) is not None
66
- ary = ary.copy()
67
- assert _logsumexp(ary=ary, axis=axis, b_inv=b_inv, keepdims=keepdims, copy=False) is not None
68
- out = np.empty(5)
69
- assert _logsumexp(ary=np.random.randn(10, 5), axis=0, out=out) is not None
70
-
71
- if b_inv != 0:
72
- # Scipy implementation when b_inv != 0
73
- b_scipy = 1 / b_inv if b_inv is not None else None
74
- scipy_results = logsumexp(ary, b=b_scipy, axis=axis, keepdims=keepdims)
75
- arviz_results = _logsumexp(ary, b_inv=b_inv, axis=axis, keepdims=keepdims)
76
-
77
- assert_array_almost_equal(scipy_results, arviz_results)
78
-
79
-
80
- @pytest.mark.parametrize("quantile", ((0.5,), (0.5, 0.1)))
81
- @pytest.mark.parametrize("arg", (True, False))
82
- def test_wrap_ufunc_output(quantile, arg):
83
- ary = np.random.randn(4, 100)
84
- n_output = len(quantile)
85
- if arg:
86
- res = wrap_xarray_ufunc(
87
- np.quantile, ary, ufunc_kwargs={"n_output": n_output}, func_args=(quantile,)
88
- )
89
- elif n_output == 1:
90
- res = wrap_xarray_ufunc(np.quantile, ary, func_kwargs={"q": quantile})
91
- else:
92
- res = wrap_xarray_ufunc(
93
- np.quantile, ary, ufunc_kwargs={"n_output": n_output}, func_kwargs={"q": quantile}
94
- )
95
- if n_output == 1:
96
- assert not isinstance(res, tuple)
97
- else:
98
- assert isinstance(res, tuple)
99
- assert len(res) == n_output
100
-
101
-
102
- @pytest.mark.parametrize("out_shape", ((1, 2), (1, 2, 3), (2, 3, 4, 5)))
103
- @pytest.mark.parametrize("input_dim", ((4, 100), (4, 100, 3), (4, 100, 4, 5)))
104
- def test_wrap_ufunc_out_shape(out_shape, input_dim):
105
- func = lambda x: np.random.rand(*out_shape)
106
- ary = np.ones(input_dim)
107
- res = wrap_xarray_ufunc(
108
- func, ary, func_kwargs={"out_shape": out_shape}, ufunc_kwargs={"n_dims": 1}
109
- )
110
- assert res.shape == (*ary.shape[:-1], *out_shape)
111
-
112
-
113
- def test_wrap_ufunc_out_shape_multi_input():
114
- out_shape = (2, 4)
115
- func = lambda x, y: np.random.rand(*out_shape)
116
- ary1 = np.ones((4, 100))
117
- ary2 = np.ones((4, 5))
118
- res = wrap_xarray_ufunc(
119
- func, ary1, ary2, func_kwargs={"out_shape": out_shape}, ufunc_kwargs={"n_dims": 1}
120
- )
121
- assert res.shape == (*ary1.shape[:-1], *out_shape)
122
-
123
-
124
- def test_wrap_ufunc_out_shape_multi_output_same():
125
- func = lambda x: (np.random.rand(1, 2), np.random.rand(1, 2))
126
- ary = np.ones((4, 100))
127
- res1, res2 = wrap_xarray_ufunc(
128
- func,
129
- ary,
130
- func_kwargs={"out_shape": ((1, 2), (1, 2))},
131
- ufunc_kwargs={"n_dims": 1, "n_output": 2},
132
- )
133
- assert res1.shape == (*ary.shape[:-1], 1, 2)
134
- assert res2.shape == (*ary.shape[:-1], 1, 2)
135
-
136
-
137
- def test_wrap_ufunc_out_shape_multi_output_diff():
138
- func = lambda x: (np.random.rand(5, 3), np.random.rand(10, 4))
139
- ary = np.ones((4, 100))
140
- res1, res2 = wrap_xarray_ufunc(
141
- func,
142
- ary,
143
- func_kwargs={"out_shape": ((5, 3), (10, 4))},
144
- ufunc_kwargs={"n_dims": 1, "n_output": 2},
145
- )
146
- assert res1.shape == (*ary.shape[:-1], 5, 3)
147
- assert res2.shape == (*ary.shape[:-1], 10, 4)
148
-
149
-
150
- @pytest.mark.parametrize("n_output", (1, 2, 3))
151
- def test_make_ufunc(n_output):
152
- if n_output == 3:
153
- func = lambda x: (np.mean(x), np.mean(x), np.mean(x))
154
- elif n_output == 2:
155
- func = lambda x: (np.mean(x), np.mean(x))
156
- else:
157
- func = np.mean
158
- ufunc = make_ufunc(func, n_dims=1, n_output=n_output)
159
- ary = np.ones((4, 100))
160
- res = ufunc(ary)
161
- if n_output > 1:
162
- assert all(len(res_i) == 4 for res_i in res)
163
- assert all((res_i == 1).all() for res_i in res)
164
- else:
165
- assert len(res) == 4
166
- assert (res == 1).all()
167
-
168
-
169
- @pytest.mark.parametrize("n_output", (1, 2, 3))
170
- def test_make_ufunc_out(n_output):
171
- if n_output == 3:
172
- func = lambda x: (np.mean(x), np.mean(x), np.mean(x))
173
- res = (np.empty((4,)), np.empty((4,)), np.empty((4,)))
174
- elif n_output == 2:
175
- func = lambda x: (np.mean(x), np.mean(x))
176
- res = (np.empty((4,)), np.empty((4,)))
177
- else:
178
- func = np.mean
179
- res = np.empty((4,))
180
- ufunc = make_ufunc(func, n_dims=1, n_output=n_output)
181
- ary = np.ones((4, 100))
182
- ufunc(ary, out=res)
183
- if n_output > 1:
184
- assert all(len(res_i) == 4 for res_i in res)
185
- assert all((res_i == 1).all() for res_i in res)
186
- else:
187
- assert len(res) == 4
188
- assert (res == 1).all()
189
-
190
-
191
- def test_make_ufunc_bad_ndim():
192
- with pytest.raises(TypeError):
193
- make_ufunc(np.mean, n_dims=0)
194
-
195
-
196
- @pytest.mark.parametrize("n_output", (1, 2, 3))
197
- def test_make_ufunc_out_bad(n_output):
198
- if n_output == 3:
199
- func = lambda x: (np.mean(x), np.mean(x), np.mean(x))
200
- res = (np.empty((100,)), np.empty((100,)))
201
- elif n_output == 2:
202
- func = lambda x: (np.mean(x), np.mean(x))
203
- res = np.empty((100,))
204
- else:
205
- func = np.mean
206
- res = np.empty((100,))
207
- ufunc = make_ufunc(func, n_dims=1, n_output=n_output)
208
- ary = np.ones((4, 100))
209
- with pytest.raises(TypeError):
210
- ufunc(ary, out=res)
211
-
212
-
213
- @pytest.mark.parametrize("how", ("all", "any"))
214
- def test_nan(how):
215
- assert not not_valid(np.ones(10), check_shape=False, nan_kwargs=dict(how=how))
216
- if how == "any":
217
- assert not_valid(
218
- np.concatenate((np.random.randn(100), np.full(2, np.nan))),
219
- check_shape=False,
220
- nan_kwargs=dict(how=how),
221
- )
222
- else:
223
- assert not not_valid(
224
- np.concatenate((np.random.randn(100), np.full(2, np.nan))),
225
- check_shape=False,
226
- nan_kwargs=dict(how=how),
227
- )
228
- assert not_valid(np.full(10, np.nan), check_shape=False, nan_kwargs=dict(how=how))
229
-
230
-
231
- @pytest.mark.parametrize("axis", (-1, 0, 1))
232
- def test_nan_axis(axis):
233
- data = np.random.randn(4, 100)
234
- data[0, 0] = np.nan # pylint: disable=unsupported-assignment-operation
235
- axis_ = (len(data.shape) + axis) if axis < 0 else axis
236
- assert not_valid(data, check_shape=False, nan_kwargs=dict(how="any"))
237
- assert not_valid(data, check_shape=False, nan_kwargs=dict(how="any", axis=axis)).any()
238
- assert not not_valid(data, check_shape=False, nan_kwargs=dict(how="any", axis=axis)).all()
239
- assert not_valid(data, check_shape=False, nan_kwargs=dict(how="any", axis=axis)).shape == tuple(
240
- dim for ax, dim in enumerate(data.shape) if ax != axis_
241
- )
242
-
243
-
244
- def test_valid_shape():
245
- assert not not_valid(
246
- np.ones((2, 200)), check_nan=False, shape_kwargs=dict(min_chains=2, min_draws=100)
247
- )
248
- assert not not_valid(
249
- np.ones((200, 2)), check_nan=False, shape_kwargs=dict(min_chains=100, min_draws=2)
250
- )
251
- assert not_valid(
252
- np.ones((10, 10)), check_nan=False, shape_kwargs=dict(min_chains=2, min_draws=100)
253
- )
254
- assert not_valid(
255
- np.ones((10, 10)), check_nan=False, shape_kwargs=dict(min_chains=100, min_draws=2)
256
- )
257
-
258
-
259
- def test_get_log_likelihood():
260
- idata = from_dict(
261
- log_likelihood={
262
- "y1": np.random.normal(size=(4, 100, 6)),
263
- "y2": np.random.normal(size=(4, 100, 8)),
264
- }
265
- )
266
- lik1 = get_log_likelihood(idata, "y1")
267
- lik2 = get_log_likelihood(idata, "y2")
268
- assert lik1.shape == (4, 100, 6)
269
- assert lik2.shape == (4, 100, 8)
270
-
271
-
272
- def test_get_log_likelihood_warning():
273
- idata = from_dict(
274
- sample_stats={
275
- "log_likelihood": np.random.normal(size=(4, 100, 6)),
276
- }
277
- )
278
- with pytest.warns(DeprecationWarning):
279
- get_log_likelihood(idata)
280
-
281
-
282
- def test_get_log_likelihood_no_var_name():
283
- idata = from_dict(
284
- log_likelihood={
285
- "y1": np.random.normal(size=(4, 100, 6)),
286
- "y2": np.random.normal(size=(4, 100, 8)),
287
- }
288
- )
289
- with pytest.raises(TypeError, match="Found several"):
290
- get_log_likelihood(idata)
291
-
292
-
293
- def test_get_log_likelihood_no_group():
294
- idata = from_dict(
295
- posterior={
296
- "a": np.random.normal(size=(4, 100)),
297
- "b": np.random.normal(size=(4, 100)),
298
- }
299
- )
300
- with pytest.raises(TypeError, match="log likelihood not found"):
301
- get_log_likelihood(idata)
302
-
303
-
304
- def test_elpd_data_error():
305
- with pytest.raises(IndexError):
306
- repr(ELPDData(data=[0, 1, 2], index=["not IC", "se", "p"]))
307
-
308
-
309
- def test_stats_variance_1d():
310
- """Test for stats_variance_1d."""
311
- data = np.random.rand(1000000)
312
- assert np.allclose(np.var(data), stats_variance_2d(data))
313
- assert np.allclose(np.var(data, ddof=1), stats_variance_2d(data, ddof=1))
314
-
315
-
316
- def test_stats_variance_2d():
317
- """Test for stats_variance_2d."""
318
- data_1 = np.random.randn(1000, 1000)
319
- data_2 = np.random.randn(1000000)
320
- school = load_arviz_data("centered_eight").posterior["mu"].values
321
- n_school = load_arviz_data("non_centered_eight").posterior["mu"].values
322
- assert np.allclose(np.var(school, ddof=1, axis=1), stats_variance_2d(school, ddof=1, axis=1))
323
- assert np.allclose(np.var(school, ddof=1, axis=0), stats_variance_2d(school, ddof=1, axis=0))
324
- assert np.allclose(
325
- np.var(n_school, ddof=1, axis=1), stats_variance_2d(n_school, ddof=1, axis=1)
326
- )
327
- assert np.allclose(
328
- np.var(n_school, ddof=1, axis=0), stats_variance_2d(n_school, ddof=1, axis=0)
329
- )
330
- assert np.allclose(np.var(data_2), stats_variance_2d(data_2))
331
- assert np.allclose(np.var(data_2, ddof=1), stats_variance_2d(data_2, ddof=1))
332
- assert np.allclose(np.var(data_1, axis=0), stats_variance_2d(data_1, axis=0))
333
- assert np.allclose(np.var(data_1, axis=1), stats_variance_2d(data_1, axis=1))
334
- assert np.allclose(np.var(data_1, axis=0, ddof=1), stats_variance_2d(data_1, axis=0, ddof=1))
335
- assert np.allclose(np.var(data_1, axis=1, ddof=1), stats_variance_2d(data_1, axis=1, ddof=1))
336
-
337
-
338
- def test_variance_bad_data():
339
- """Test for variance when the data range is extremely wide."""
340
- data = np.array([1e20, 200e-10, 1e-17, 432e9, 2500432, 23e5, 16e-7])
341
- assert np.allclose(stats_variance_2d(data), np.var(data))
342
- assert np.allclose(stats_variance_2d(data, ddof=1), np.var(data, ddof=1))
343
- assert not np.allclose(stats_variance_2d(data), np.var(data, ddof=1))
344
-
345
-
346
- def test_histogram():
347
- school = load_arviz_data("non_centered_eight").posterior["mu"].values
348
- k_count_az, k_dens_az, _ = histogram(school, bins=np.asarray([-np.inf, 0.5, 0.7, 1, np.inf]))
349
- k_dens_np, *_ = np.histogram(school, bins=[-np.inf, 0.5, 0.7, 1, np.inf], density=True)
350
- k_count_np, *_ = np.histogram(school, bins=[-np.inf, 0.5, 0.7, 1, np.inf], density=False)
351
- assert np.allclose(k_count_az, k_count_np)
352
- assert np.allclose(k_dens_az, k_dens_np)
353
-
354
-
355
- def test_sqrt():
356
- x = np.random.rand(100)
357
- y = np.random.rand(100)
358
- assert np.allclose(_sqrt(x, y), np.sqrt(x + y))
359
-
360
-
361
- def test_angle():
362
- x = np.random.randn(100)
363
- high = 8
364
- low = 4
365
- res = (x - low) * 2 * np.pi / (high - low)
366
- assert np.allclose(_angle(x, low, high, np.pi), res)
367
-
368
-
369
- def test_circfunc():
370
- school = load_arviz_data("centered_eight").posterior["mu"].values
371
- a_a = _circfunc(school, 8, 4, skipna=False)
372
- assert np.allclose(a_a, _angle(school, 4, 8, np.pi))
373
-
374
-
375
- @pytest.mark.parametrize(
376
- "data", (np.random.randn(100), np.random.randn(100, 100), np.random.randn(100, 100, 100))
377
- )
378
- def test_circular_standard_deviation_1d(data):
379
- high = 8
380
- low = 4
381
- assert np.allclose(
382
- _circular_standard_deviation(data, high=high, low=low),
383
- circstd(data, high=high, low=low),
384
- )