arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
@@ -1,925 +0,0 @@
1
- # pylint: disable=redefined-outer-name, no-member
2
- from copy import deepcopy
3
-
4
- import numpy as np
5
- import pytest
6
- from numpy.testing import (
7
- assert_allclose,
8
- assert_array_almost_equal,
9
- assert_almost_equal,
10
- assert_array_equal,
11
- )
12
- from scipy.special import logsumexp
13
- from scipy.stats import linregress, norm, halfcauchy
14
- from xarray import DataArray, Dataset
15
- from xarray_einstats.stats import XrContinuousRV
16
-
17
- from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data, InferenceData
18
- from ...rcparams import rcParams
19
- from ...stats import (
20
- apply_test_function,
21
- bayes_factor,
22
- compare,
23
- ess,
24
- hdi,
25
- loo,
26
- loo_pit,
27
- psens,
28
- psislw,
29
- r2_score,
30
- summary,
31
- waic,
32
- weight_predictions,
33
- _calculate_ics,
34
- )
35
- from ...stats.stats import _gpinv
36
- from ...stats.stats_utils import get_log_likelihood
37
- from ..helpers import check_multiple_attrs, multidim_models # pylint: disable=unused-import
38
-
39
- rcParams["data.load"] = "eager"
40
-
41
-
42
- @pytest.fixture(scope="session")
43
- def centered_eight():
44
- centered_eight = load_arviz_data("centered_eight")
45
- return centered_eight
46
-
47
-
48
- @pytest.fixture(scope="session")
49
- def non_centered_eight():
50
- non_centered_eight = load_arviz_data("non_centered_eight")
51
- return non_centered_eight
52
-
53
-
54
- @pytest.fixture(scope="module")
55
- def multivariable_log_likelihood(centered_eight):
56
- centered_eight = centered_eight.copy()
57
- new_arr = DataArray(
58
- np.zeros(centered_eight.log_likelihood["obs"].values.shape),
59
- dims=["chain", "draw", "school"],
60
- coords=centered_eight.log_likelihood.coords,
61
- )
62
- centered_eight.log_likelihood["decoy"] = new_arr
63
- return centered_eight
64
-
65
-
66
- def test_hdp():
67
- normal_sample = np.random.randn(5000000)
68
- interval = hdi(normal_sample)
69
- assert_array_almost_equal(interval, [-1.88, 1.88], 2)
70
-
71
-
72
- def test_hdp_2darray():
73
- normal_sample = np.random.randn(12000, 5)
74
- msg = (
75
- r"hdi currently interprets 2d data as \(draw, shape\) but this will "
76
- r"change in a future release to \(chain, draw\) for coherence with other functions"
77
- )
78
- with pytest.warns(FutureWarning, match=msg):
79
- result = hdi(normal_sample)
80
- assert result.shape == (5, 2)
81
-
82
-
83
- def test_hdi_multidimension():
84
- normal_sample = np.random.randn(12000, 10, 3)
85
- result = hdi(normal_sample)
86
- assert result.shape == (3, 2)
87
-
88
-
89
- def test_hdi_idata(centered_eight):
90
- data = centered_eight.posterior
91
- result = hdi(data)
92
- assert isinstance(result, Dataset)
93
- assert dict(result.sizes) == {"school": 8, "hdi": 2}
94
-
95
- result = hdi(data, input_core_dims=[["chain"]])
96
- assert isinstance(result, Dataset)
97
- assert result.sizes == {"draw": 500, "hdi": 2, "school": 8}
98
-
99
-
100
- def test_hdi_idata_varnames(centered_eight):
101
- data = centered_eight.posterior
102
- result = hdi(data, var_names=["mu", "theta"])
103
- assert isinstance(result, Dataset)
104
- assert result.sizes == {"hdi": 2, "school": 8}
105
- assert list(result.data_vars.keys()) == ["mu", "theta"]
106
-
107
-
108
- def test_hdi_idata_group(centered_eight):
109
- result_posterior = hdi(centered_eight, group="posterior", var_names="mu")
110
- result_prior = hdi(centered_eight, group="prior", var_names="mu")
111
- assert result_prior.sizes == {"hdi": 2}
112
- range_posterior = result_posterior.mu.values[1] - result_posterior.mu.values[0]
113
- range_prior = result_prior.mu.values[1] - result_prior.mu.values[0]
114
- assert range_posterior < range_prior
115
-
116
-
117
- def test_hdi_coords(centered_eight):
118
- data = centered_eight.posterior
119
- result = hdi(data, coords={"chain": [0, 1, 3]}, input_core_dims=[["draw"]])
120
- assert_array_equal(result.coords["chain"], [0, 1, 3])
121
-
122
-
123
- def test_hdi_multimodal():
124
- normal_sample = np.concatenate(
125
- (np.random.normal(-4, 1, 2500000), np.random.normal(2, 0.5, 2500000))
126
- )
127
- intervals = hdi(normal_sample, multimodal=True)
128
- assert_array_almost_equal(intervals, [[-5.8, -2.2], [0.9, 3.1]], 1)
129
-
130
-
131
- def test_hdi_multimodal_multivars():
132
- size = 2500000
133
- var1 = np.concatenate((np.random.normal(-4, 1, size), np.random.normal(2, 0.5, size)))
134
- var2 = np.random.normal(8, 1, size * 2)
135
- sample = Dataset(
136
- {
137
- "var1": (("chain", "draw"), var1[np.newaxis, :]),
138
- "var2": (("chain", "draw"), var2[np.newaxis, :]),
139
- },
140
- coords={"chain": [0], "draw": np.arange(size * 2)},
141
- )
142
- intervals = hdi(sample, multimodal=True)
143
- assert_array_almost_equal(intervals.var1, [[-5.8, -2.2], [0.9, 3.1]], 1)
144
- assert_array_almost_equal(intervals.var2, [[6.1, 9.9], [np.nan, np.nan]], 1)
145
-
146
-
147
- def test_hdi_circular():
148
- normal_sample = np.random.vonmises(np.pi, 1, 5000000)
149
- interval = hdi(normal_sample, circular=True)
150
- assert_array_almost_equal(interval, [0.6, -0.6], 1)
151
-
152
-
153
- def test_hdi_bad_ci():
154
- normal_sample = np.random.randn(10)
155
- with pytest.raises(ValueError):
156
- hdi(normal_sample, hdi_prob=2)
157
-
158
-
159
- def test_hdi_skipna():
160
- normal_sample = np.random.randn(500)
161
- interval = hdi(normal_sample[10:])
162
- normal_sample[:10] = np.nan
163
- interval_ = hdi(normal_sample, skipna=True)
164
- assert_array_almost_equal(interval, interval_)
165
-
166
-
167
- def test_r2_score():
168
- x = np.linspace(0, 1, 100)
169
- y = np.random.normal(x, 1)
170
- y_pred = x + np.random.randn(300, 100)
171
- res = linregress(x, y)
172
- assert_allclose(res.rvalue**2, r2_score(y, y_pred).r2, 2)
173
-
174
-
175
- @pytest.mark.parametrize("method", ["stacking", "BB-pseudo-BMA", "pseudo-BMA"])
176
- @pytest.mark.parametrize("multidim", [True, False])
177
- def test_compare_same(centered_eight, multidim_models, method, multidim):
178
- if multidim:
179
- data_dict = {"first": multidim_models.model_1, "second": multidim_models.model_1}
180
- else:
181
- data_dict = {"first": centered_eight, "second": centered_eight}
182
-
183
- weight = compare(data_dict, method=method)["weight"].to_numpy()
184
- assert_allclose(weight[0], weight[1])
185
- assert_allclose(np.sum(weight), 1.0)
186
-
187
-
188
- def test_compare_unknown_ic_and_method(centered_eight, non_centered_eight):
189
- model_dict = {"centered": centered_eight, "non_centered": non_centered_eight}
190
- with pytest.raises(ValueError):
191
- compare(model_dict, ic="Unknown", method="stacking")
192
- with pytest.raises(ValueError):
193
- compare(model_dict, ic="loo", method="Unknown")
194
-
195
-
196
- @pytest.mark.parametrize("ic", ["loo", "waic"])
197
- @pytest.mark.parametrize("method", ["stacking", "BB-pseudo-BMA", "pseudo-BMA"])
198
- @pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
199
- def test_compare_different(centered_eight, non_centered_eight, ic, method, scale):
200
- model_dict = {"centered": centered_eight, "non_centered": non_centered_eight}
201
- weight = compare(model_dict, ic=ic, method=method, scale=scale)["weight"]
202
- assert weight["non_centered"] > weight["centered"]
203
- assert_allclose(np.sum(weight), 1.0)
204
-
205
-
206
- @pytest.mark.parametrize("ic", ["loo", "waic"])
207
- @pytest.mark.parametrize("method", ["stacking", "BB-pseudo-BMA", "pseudo-BMA"])
208
- def test_compare_different_multidim(multidim_models, ic, method):
209
- model_dict = {"model_1": multidim_models.model_1, "model_2": multidim_models.model_2}
210
- weight = compare(model_dict, ic=ic, method=method)["weight"]
211
-
212
- # this should hold because the same seed is always used
213
- assert weight["model_1"] > weight["model_2"]
214
- assert_allclose(np.sum(weight), 1.0)
215
-
216
-
217
- def test_compare_different_size(centered_eight, non_centered_eight):
218
- centered_eight = deepcopy(centered_eight)
219
- centered_eight.posterior = centered_eight.posterior.drop("Choate", "school")
220
- centered_eight.log_likelihood = centered_eight.log_likelihood.drop("Choate", "school")
221
- centered_eight.posterior_predictive = centered_eight.posterior_predictive.drop(
222
- "Choate", "school"
223
- )
224
- centered_eight.prior = centered_eight.prior.drop("Choate", "school")
225
- centered_eight.observed_data = centered_eight.observed_data.drop("Choate", "school")
226
- model_dict = {"centered": centered_eight, "non_centered": non_centered_eight}
227
- with pytest.raises(ValueError):
228
- compare(model_dict, ic="waic", method="stacking")
229
-
230
-
231
- @pytest.mark.parametrize("ic", ["loo", "waic"])
232
- def test_compare_multiple_obs(multivariable_log_likelihood, centered_eight, non_centered_eight, ic):
233
- compare_dict = {
234
- "centered_eight": centered_eight,
235
- "non_centered_eight": non_centered_eight,
236
- "problematic": multivariable_log_likelihood,
237
- }
238
- with pytest.raises(TypeError, match="several log likelihood arrays"):
239
- get_log_likelihood(compare_dict["problematic"])
240
- with pytest.raises(TypeError, match="error in ELPD computation"):
241
- compare(compare_dict, ic=ic)
242
- assert compare(compare_dict, ic=ic, var_name="obs") is not None
243
-
244
-
245
- @pytest.mark.parametrize("ic", ["loo", "waic"])
246
- def test_calculate_ics(centered_eight, non_centered_eight, ic):
247
- ic_func = loo if ic == "loo" else waic
248
- idata_dict = {"centered": centered_eight, "non_centered": non_centered_eight}
249
- elpddata_dict = {key: ic_func(value) for key, value in idata_dict.items()}
250
- mixed_dict = {"centered": idata_dict["centered"], "non_centered": elpddata_dict["non_centered"]}
251
- idata_out, _, _ = _calculate_ics(idata_dict, ic=ic)
252
- elpddata_out, _, _ = _calculate_ics(elpddata_dict, ic=ic)
253
- mixed_out, _, _ = _calculate_ics(mixed_dict, ic=ic)
254
- for model in idata_dict:
255
- ic_ = f"elpd_{ic}"
256
- assert idata_out[model][ic_] == elpddata_out[model][ic_]
257
- assert idata_out[model][ic_] == mixed_out[model][ic_]
258
- assert idata_out[model][f"p_{ic}"] == elpddata_out[model][f"p_{ic}"]
259
- assert idata_out[model][f"p_{ic}"] == mixed_out[model][f"p_{ic}"]
260
-
261
-
262
- def test_calculate_ics_ic_error(centered_eight, non_centered_eight):
263
- in_dict = {"centered": loo(centered_eight), "non_centered": waic(non_centered_eight)}
264
- with pytest.raises(ValueError, match="found both loo and waic"):
265
- _calculate_ics(in_dict)
266
-
267
-
268
- def test_calculate_ics_ic_override(centered_eight, non_centered_eight):
269
- in_dict = {"centered": centered_eight, "non_centered": waic(non_centered_eight)}
270
- with pytest.warns(UserWarning, match="precomputed elpddata: waic"):
271
- out_dict, _, ic = _calculate_ics(in_dict, ic="loo")
272
- assert ic == "waic"
273
- assert out_dict["centered"]["elpd_waic"] == waic(centered_eight)["elpd_waic"]
274
-
275
-
276
- def test_summary_ndarray():
277
- array = np.random.randn(4, 100, 2)
278
- summary_df = summary(array)
279
- assert summary_df.shape
280
-
281
-
282
- @pytest.mark.parametrize("var_names_expected", ((None, 10), ("mu", 1), (["mu", "tau"], 2)))
283
- def test_summary_var_names(centered_eight, var_names_expected):
284
- var_names, expected = var_names_expected
285
- summary_df = summary(centered_eight, var_names=var_names)
286
- assert len(summary_df.index) == expected
287
-
288
-
289
- @pytest.mark.parametrize("missing_groups", (None, "posterior", "prior"))
290
- def test_summary_groups(centered_eight, missing_groups):
291
- if missing_groups == "posterior":
292
- centered_eight = deepcopy(centered_eight)
293
- del centered_eight.posterior
294
- elif missing_groups == "prior":
295
- centered_eight = deepcopy(centered_eight)
296
- del centered_eight.posterior
297
- del centered_eight.prior
298
- if missing_groups == "prior":
299
- with pytest.warns(UserWarning):
300
- summary_df = summary(centered_eight)
301
- else:
302
- summary_df = summary(centered_eight)
303
- assert summary_df.shape
304
-
305
-
306
- def test_summary_group_argument(centered_eight):
307
- summary_df_posterior = summary(centered_eight, group="posterior")
308
- summary_df_prior = summary(centered_eight, group="prior")
309
- assert list(summary_df_posterior.index) != list(summary_df_prior.index)
310
-
311
-
312
- def test_summary_wrong_group(centered_eight):
313
- with pytest.raises(TypeError, match=r"InferenceData does not contain group: InvalidGroup"):
314
- summary(centered_eight, group="InvalidGroup")
315
-
316
-
317
- METRICS_NAMES = [
318
- "mean",
319
- "sd",
320
- "hdi_3%",
321
- "hdi_97%",
322
- "mcse_mean",
323
- "mcse_sd",
324
- "ess_bulk",
325
- "ess_tail",
326
- "r_hat",
327
- "median",
328
- "mad",
329
- "eti_3%",
330
- "eti_97%",
331
- "mcse_median",
332
- "ess_median",
333
- "ess_tail",
334
- "r_hat",
335
- ]
336
-
337
-
338
- @pytest.mark.parametrize(
339
- "params",
340
- (
341
- ("mean", "all", METRICS_NAMES[:9]),
342
- ("mean", "stats", METRICS_NAMES[:4]),
343
- ("mean", "diagnostics", METRICS_NAMES[4:9]),
344
- ("median", "all", METRICS_NAMES[9:17]),
345
- ("median", "stats", METRICS_NAMES[9:13]),
346
- ("median", "diagnostics", METRICS_NAMES[13:17]),
347
- ),
348
- )
349
- def test_summary_focus_kind(centered_eight, params):
350
- stat_focus, kind, metrics_names_ = params
351
- summary_df = summary(centered_eight, stat_focus=stat_focus, kind=kind)
352
- assert_array_equal(summary_df.columns, metrics_names_)
353
-
354
-
355
- def test_summary_wrong_focus(centered_eight):
356
- with pytest.raises(TypeError, match=r"Invalid format: 'WrongFocus'.*"):
357
- summary(centered_eight, stat_focus="WrongFocus")
358
-
359
-
360
- @pytest.mark.parametrize("fmt", ["wide", "long", "xarray"])
361
- def test_summary_fmt(centered_eight, fmt):
362
- assert summary(centered_eight, fmt=fmt) is not None
363
-
364
-
365
- def test_summary_labels():
366
- coords1 = list("abcd")
367
- coords2 = np.arange(1, 6)
368
- data = from_dict(
369
- {"a": np.random.randn(4, 100, 4, 5)},
370
- coords={"dim1": coords1, "dim2": coords2},
371
- dims={"a": ["dim1", "dim2"]},
372
- )
373
- az_summary = summary(data, fmt="wide")
374
- assert az_summary is not None
375
- column_order = []
376
- for coord1 in coords1:
377
- for coord2 in coords2:
378
- column_order.append(f"a[{coord1}, {coord2}]")
379
- for col1, col2 in zip(list(az_summary.index), column_order):
380
- assert col1 == col2
381
-
382
-
383
- @pytest.mark.parametrize(
384
- "stat_funcs", [[np.var], {"var": np.var, "var2": lambda x: np.var(x) ** 2}]
385
- )
386
- def test_summary_stat_func(centered_eight, stat_funcs):
387
- arviz_summary = summary(centered_eight, stat_funcs=stat_funcs)
388
- assert arviz_summary is not None
389
- assert hasattr(arviz_summary, "var")
390
-
391
-
392
- def test_summary_nan(centered_eight):
393
- centered_eight = deepcopy(centered_eight)
394
- centered_eight.posterior["theta"].loc[{"school": "Deerfield"}] = np.nan
395
- summary_xarray = summary(centered_eight)
396
- assert summary_xarray is not None
397
- assert summary_xarray.loc["theta[Deerfield]"].isnull().all()
398
- assert (
399
- summary_xarray.loc[[ix for ix in summary_xarray.index if ix != "theta[Deerfield]"]]
400
- .notnull()
401
- .all()
402
- .all()
403
- )
404
-
405
-
406
- def test_summary_skip_nan(centered_eight):
407
- centered_eight = deepcopy(centered_eight)
408
- centered_eight.posterior["theta"].loc[{"draw": slice(10), "school": "Deerfield"}] = np.nan
409
- summary_xarray = summary(centered_eight)
410
- theta_1 = summary_xarray.loc["theta[Deerfield]"].isnull()
411
- assert summary_xarray is not None
412
- assert ~theta_1[:4].all()
413
- assert theta_1[4:].all()
414
-
415
-
416
- @pytest.mark.parametrize("fmt", [1, "bad_fmt"])
417
- def test_summary_bad_fmt(centered_eight, fmt):
418
- with pytest.raises(TypeError, match="Invalid format"):
419
- summary(centered_eight, fmt=fmt)
420
-
421
-
422
- def test_summary_order_deprecation(centered_eight):
423
- with pytest.warns(DeprecationWarning, match="order"):
424
- summary(centered_eight, order="C")
425
-
426
-
427
- def test_summary_index_origin_deprecation(centered_eight):
428
- with pytest.warns(DeprecationWarning, match="index_origin"):
429
- summary(centered_eight, index_origin=1)
430
-
431
-
432
- @pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
433
- @pytest.mark.parametrize("multidim", (True, False))
434
- def test_waic(centered_eight, multidim_models, scale, multidim):
435
- """Test widely available information criterion calculation"""
436
- if multidim:
437
- assert waic(multidim_models.model_1, scale=scale) is not None
438
- waic_pointwise = waic(multidim_models.model_1, pointwise=True, scale=scale)
439
- else:
440
- assert waic(centered_eight, scale=scale) is not None
441
- waic_pointwise = waic(centered_eight, pointwise=True, scale=scale)
442
- assert waic_pointwise is not None
443
- assert "waic_i" in waic_pointwise
444
-
445
-
446
- def test_waic_bad(centered_eight):
447
- """Test widely available information criterion calculation"""
448
- centered_eight = deepcopy(centered_eight)
449
- delattr(centered_eight, "log_likelihood")
450
- with pytest.raises(TypeError):
451
- waic(centered_eight)
452
-
453
-
454
- def test_waic_bad_scale(centered_eight):
455
- """Test widely available information criterion calculation with bad scale."""
456
- with pytest.raises(TypeError):
457
- waic(centered_eight, scale="bad_value")
458
-
459
-
460
- def test_waic_warning(centered_eight):
461
- centered_eight = deepcopy(centered_eight)
462
- centered_eight.log_likelihood["obs"][:, :250, 1] = 10
463
- with pytest.warns(UserWarning):
464
- assert waic(centered_eight, pointwise=True) is not None
465
- # this should throw a warning, but due to numerical issues it fails
466
- centered_eight.log_likelihood["obs"][:, :, :] = 0
467
- with pytest.warns(UserWarning):
468
- assert waic(centered_eight, pointwise=True) is not None
469
-
470
-
471
- @pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
472
- def test_waic_print(centered_eight, scale):
473
- waic_data = repr(waic(centered_eight, scale=scale))
474
- waic_pointwise = repr(waic(centered_eight, scale=scale, pointwise=True))
475
- assert waic_data is not None
476
- assert waic_pointwise is not None
477
- assert waic_data == waic_pointwise
478
-
479
-
480
- @pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
481
- @pytest.mark.parametrize("multidim", (True, False))
482
- def test_loo(centered_eight, multidim_models, scale, multidim):
483
- """Test approximate leave one out criterion calculation"""
484
- if multidim:
485
- assert loo(multidim_models.model_1, scale=scale) is not None
486
- loo_pointwise = loo(multidim_models.model_1, pointwise=True, scale=scale)
487
- else:
488
- assert loo(centered_eight, scale=scale) is not None
489
- loo_pointwise = loo(centered_eight, pointwise=True, scale=scale)
490
- assert loo_pointwise is not None
491
- assert "loo_i" in loo_pointwise
492
- assert "pareto_k" in loo_pointwise
493
- assert "scale" in loo_pointwise
494
-
495
-
496
- def test_loo_one_chain(centered_eight):
497
- centered_eight = deepcopy(centered_eight)
498
- centered_eight.posterior = centered_eight.posterior.drop([1, 2, 3], "chain")
499
- centered_eight.sample_stats = centered_eight.sample_stats.drop([1, 2, 3], "chain")
500
- assert loo(centered_eight) is not None
501
-
502
-
503
- def test_loo_bad(centered_eight):
504
- with pytest.raises(TypeError):
505
- loo(np.random.randn(2, 10))
506
-
507
- centered_eight = deepcopy(centered_eight)
508
- delattr(centered_eight, "log_likelihood")
509
- with pytest.raises(TypeError):
510
- loo(centered_eight)
511
-
512
-
513
- def test_loo_bad_scale(centered_eight):
514
- """Test loo with bad scale value."""
515
- with pytest.raises(TypeError):
516
- loo(centered_eight, scale="bad_scale")
517
-
518
-
519
- def test_loo_bad_no_posterior_reff(centered_eight):
520
- loo(centered_eight, reff=None)
521
- centered_eight = deepcopy(centered_eight)
522
- del centered_eight.posterior
523
- with pytest.raises(TypeError):
524
- loo(centered_eight, reff=None)
525
- loo(centered_eight, reff=0.7)
526
-
527
-
528
- def test_loo_warning(centered_eight):
529
- centered_eight = deepcopy(centered_eight)
530
- # make one of the khats infinity
531
- centered_eight.log_likelihood["obs"][:, :, 1] = 10
532
- with pytest.warns(UserWarning) as records:
533
- assert loo(centered_eight, pointwise=True) is not None
534
- assert any("Estimated shape parameter" in str(record.message) for record in records)
535
-
536
- # make all of the khats infinity
537
- centered_eight.log_likelihood["obs"][:, :, :] = 1
538
- with pytest.warns(UserWarning) as records:
539
- assert loo(centered_eight, pointwise=True) is not None
540
- assert any("Estimated shape parameter" in str(record.message) for record in records)
541
-
542
-
543
- @pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"])
544
- def test_loo_print(centered_eight, scale):
545
- loo_data = repr(loo(centered_eight, scale=scale, pointwise=False))
546
- loo_pointwise = repr(loo(centered_eight, scale=scale, pointwise=True))
547
- assert loo_data is not None
548
- assert loo_pointwise is not None
549
- assert len(loo_data) < len(loo_pointwise)
550
-
551
-
552
- def test_psislw(centered_eight):
553
- pareto_k = loo(centered_eight, pointwise=True, reff=0.7)["pareto_k"]
554
- log_likelihood = get_log_likelihood(centered_eight)
555
- log_likelihood = log_likelihood.stack(__sample__=("chain", "draw"))
556
- assert_allclose(pareto_k, psislw(-log_likelihood, 0.7)[1])
557
-
558
-
559
- def test_psislw_smooths_for_low_k():
560
- # check that log-weights are smoothed even when k < 1/3
561
- # https://github.com/arviz-devs/arviz/issues/2010
562
- rng = np.random.default_rng(44)
563
- x = rng.normal(size=100)
564
- x_smoothed, k = psislw(x.copy())
565
- assert k < 1 / 3
566
- assert not np.allclose(x - logsumexp(x), x_smoothed)
567
-
568
-
569
- @pytest.mark.parametrize("probs", [True, False])
570
- @pytest.mark.parametrize("kappa", [-1, -0.5, 1e-30, 0.5, 1])
571
- @pytest.mark.parametrize("sigma", [0, 2])
572
- def test_gpinv(probs, kappa, sigma):
573
- if probs:
574
- probs = np.array([0.1, 0.1, 0.1, 0.2, 0.3])
575
- else:
576
- probs = np.array([-0.1, 0.1, 0.1, 0.2, 0.3])
577
- assert len(_gpinv(probs, kappa, sigma)) == len(probs)
578
-
579
-
580
- @pytest.mark.parametrize("func", [loo, waic])
581
- def test_multidimensional_log_likelihood(func):
582
- llm = np.random.rand(4, 23, 15, 2)
583
- ll1 = llm.reshape(4, 23, 15 * 2)
584
- statsm = Dataset(dict(log_likelihood=DataArray(llm, dims=["chain", "draw", "a", "b"])))
585
-
586
- stats1 = Dataset(dict(log_likelihood=DataArray(ll1, dims=["chain", "draw", "v"])))
587
-
588
- post = Dataset(dict(mu=DataArray(np.random.rand(4, 23, 2), dims=["chain", "draw", "v"])))
589
-
590
- dsm = convert_to_inference_data(statsm, group="sample_stats")
591
- ds1 = convert_to_inference_data(stats1, group="sample_stats")
592
- dsp = convert_to_inference_data(post, group="posterior")
593
-
594
- dsm = concat(dsp, dsm)
595
- ds1 = concat(dsp, ds1)
596
-
597
- frm = func(dsm)
598
- fr1 = func(ds1)
599
-
600
- assert all(
601
- fr1[key] == frm[key] for key in fr1.index if key not in {"loo_i", "waic_i", "pareto_k"}
602
- )
603
- assert_array_almost_equal(frm[:4], fr1[:4])
604
-
605
-
606
- @pytest.mark.parametrize(
607
- "args",
608
- [
609
- {"y": "obs"},
610
- {"y": "obs", "y_hat": "obs"},
611
- {"y": "arr", "y_hat": "obs"},
612
- {"y": "obs", "y_hat": "arr"},
613
- {"y": "arr", "y_hat": "arr"},
614
- {"y": "obs", "y_hat": "obs", "log_weights": "arr"},
615
- {"y": "arr", "y_hat": "obs", "log_weights": "arr"},
616
- {"y": "obs", "y_hat": "arr", "log_weights": "arr"},
617
- {"idata": False},
618
- ],
619
- )
620
- def test_loo_pit(centered_eight, args):
621
- y = args.get("y", None)
622
- y_hat = args.get("y_hat", None)
623
- log_weights = args.get("log_weights", None)
624
- y_arr = centered_eight.observed_data.obs
625
- y_hat_arr = centered_eight.posterior_predictive.obs.stack(__sample__=("chain", "draw"))
626
- log_like = get_log_likelihood(centered_eight).stack(__sample__=("chain", "draw"))
627
- n_samples = len(log_like.__sample__)
628
- ess_p = ess(centered_eight.posterior, method="mean")
629
- reff = np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
630
- log_weights_arr = psislw(-log_like, reff=reff)[0]
631
-
632
- if args.get("idata", True):
633
- if y == "arr":
634
- y = y_arr
635
- if y_hat == "arr":
636
- y_hat = y_hat_arr
637
- if log_weights == "arr":
638
- log_weights = log_weights_arr
639
- loo_pit_data = loo_pit(idata=centered_eight, y=y, y_hat=y_hat, log_weights=log_weights)
640
- else:
641
- loo_pit_data = loo_pit(idata=None, y=y_arr, y_hat=y_hat_arr, log_weights=log_weights_arr)
642
- assert np.all((loo_pit_data >= 0) & (loo_pit_data <= 1))
643
-
644
-
645
- @pytest.mark.parametrize(
646
- "args",
647
- [
648
- {"y": "y"},
649
- {"y": "y", "y_hat": "y"},
650
- {"y": "arr", "y_hat": "y"},
651
- {"y": "y", "y_hat": "arr"},
652
- {"y": "arr", "y_hat": "arr"},
653
- {"y": "y", "y_hat": "y", "log_weights": "arr"},
654
- {"y": "arr", "y_hat": "y", "log_weights": "arr"},
655
- {"y": "y", "y_hat": "arr", "log_weights": "arr"},
656
- {"idata": False},
657
- ],
658
- )
659
- def test_loo_pit_multidim(multidim_models, args):
660
- y = args.get("y", None)
661
- y_hat = args.get("y_hat", None)
662
- log_weights = args.get("log_weights", None)
663
- idata = multidim_models.model_1
664
- y_arr = idata.observed_data.y
665
- y_hat_arr = idata.posterior_predictive.y.stack(__sample__=("chain", "draw"))
666
- log_like = get_log_likelihood(idata).stack(__sample__=("chain", "draw"))
667
- n_samples = len(log_like.__sample__)
668
- ess_p = ess(idata.posterior, method="mean")
669
- reff = np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
670
- log_weights_arr = psislw(-log_like, reff=reff)[0]
671
-
672
- if args.get("idata", True):
673
- if y == "arr":
674
- y = y_arr
675
- if y_hat == "arr":
676
- y_hat = y_hat_arr
677
- if log_weights == "arr":
678
- log_weights = log_weights_arr
679
- loo_pit_data = loo_pit(idata=idata, y=y, y_hat=y_hat, log_weights=log_weights)
680
- else:
681
- loo_pit_data = loo_pit(idata=None, y=y_arr, y_hat=y_hat_arr, log_weights=log_weights_arr)
682
- assert np.all((loo_pit_data >= 0) & (loo_pit_data <= 1))
683
-
684
-
685
- def test_loo_pit_multi_lik():
686
- rng = np.random.default_rng(0)
687
- post_pred = rng.standard_normal(size=(4, 100, 10))
688
- obs = np.quantile(post_pred, np.linspace(0, 1, 10))
689
- obs[0] *= 0.9
690
- obs[-1] *= 1.1
691
- idata = from_dict(
692
- posterior={"a": np.random.randn(4, 100)},
693
- posterior_predictive={"y": post_pred},
694
- observed_data={"y": obs},
695
- log_likelihood={"y": -(post_pred**2), "decoy": np.zeros_like(post_pred)},
696
- )
697
- loo_pit_data = loo_pit(idata, y="y")
698
- assert np.all((loo_pit_data >= 0) & (loo_pit_data <= 1))
699
-
700
-
701
- @pytest.mark.parametrize("input_type", ["idataarray", "idatanone_ystr", "yarr_yhatnone"])
702
- def test_loo_pit_bad_input(centered_eight, input_type):
703
- """Test incompatible input combinations."""
704
- arr = np.random.random((8, 200))
705
- if input_type == "idataarray":
706
- with pytest.raises(ValueError, match=r"type InferenceData or None"):
707
- loo_pit(idata=arr, y="obs")
708
- elif input_type == "idatanone_ystr":
709
- with pytest.raises(ValueError, match=r"all 3.+must be array or DataArray"):
710
- loo_pit(idata=None, y="obs")
711
- elif input_type == "yarr_yhatnone":
712
- with pytest.raises(ValueError, match=r"y_hat.+None.+y.+str"):
713
- loo_pit(idata=centered_eight, y=arr, y_hat=None)
714
-
715
-
716
- @pytest.mark.parametrize("arg", ["y", "y_hat", "log_weights"])
717
- def test_loo_pit_bad_input_type(centered_eight, arg):
718
- """Test wrong input type (not None, str not DataArray."""
719
- kwargs = {"y": "obs", "y_hat": "obs", "log_weights": None}
720
- kwargs[arg] = 2 # use int instead of array-like
721
- with pytest.raises(ValueError, match=f"not {type(2)}"):
722
- loo_pit(idata=centered_eight, **kwargs)
723
-
724
-
725
- @pytest.mark.parametrize("incompatibility", ["y-y_hat1", "y-y_hat2", "y_hat-log_weights"])
726
- def test_loo_pit_bad_input_shape(incompatibility):
727
- """Test shape incompatibilities."""
728
- y = np.random.random(8)
729
- y_hat = np.random.random((8, 200))
730
- log_weights = np.random.random((8, 200))
731
- if incompatibility == "y-y_hat1":
732
- with pytest.raises(ValueError, match="1 more dimension"):
733
- loo_pit(y=y, y_hat=y_hat[None, :], log_weights=log_weights)
734
- elif incompatibility == "y-y_hat2":
735
- with pytest.raises(ValueError, match="y has shape"):
736
- loo_pit(y=y, y_hat=y_hat[1:3, :], log_weights=log_weights)
737
- elif incompatibility == "y_hat-log_weights":
738
- with pytest.raises(ValueError, match="must have the same shape"):
739
- loo_pit(y=y, y_hat=y_hat[:, :100], log_weights=log_weights)
740
-
741
-
742
- @pytest.mark.parametrize("pointwise", [True, False])
743
- @pytest.mark.parametrize("inplace", [True, False])
744
- @pytest.mark.parametrize(
745
- "kwargs",
746
- [
747
- {},
748
- {"group": "posterior_predictive", "var_names": {"posterior_predictive": "obs"}},
749
- {"group": "observed_data", "var_names": {"both": "obs"}, "out_data_shape": "shape"},
750
- {"var_names": {"both": "obs", "posterior": ["theta", "mu"]}},
751
- {"group": "observed_data", "out_name_data": "T_name"},
752
- ],
753
- )
754
- def test_apply_test_function(centered_eight, pointwise, inplace, kwargs):
755
- """Test some usual call cases of apply_test_function"""
756
- centered_eight = deepcopy(centered_eight)
757
- group = kwargs.get("group", "both")
758
- var_names = kwargs.get("var_names", None)
759
- out_data_shape = kwargs.get("out_data_shape", None)
760
- out_pp_shape = kwargs.get("out_pp_shape", None)
761
- out_name_data = kwargs.get("out_name_data", "T")
762
- if out_data_shape == "shape":
763
- out_data_shape = (8,) if pointwise else ()
764
- if out_pp_shape == "shape":
765
- out_pp_shape = (4, 500, 8) if pointwise else (4, 500)
766
- idata = deepcopy(centered_eight)
767
- idata_out = apply_test_function(
768
- idata,
769
- lambda y, theta: np.mean(y),
770
- group=group,
771
- var_names=var_names,
772
- pointwise=pointwise,
773
- out_name_data=out_name_data,
774
- out_data_shape=out_data_shape,
775
- out_pp_shape=out_pp_shape,
776
- )
777
- if inplace:
778
- assert idata is idata_out
779
-
780
- if group == "both":
781
- test_dict = {"observed_data": ["T"], "posterior_predictive": ["T"]}
782
- else:
783
- test_dict = {group: [kwargs.get("out_name_data", "T")]}
784
-
785
- fails = check_multiple_attrs(test_dict, idata_out)
786
- assert not fails
787
-
788
-
789
- def test_apply_test_function_bad_group(centered_eight):
790
- """Test error when group is an invalid name."""
791
- with pytest.raises(ValueError, match="Invalid group argument"):
792
- apply_test_function(centered_eight, lambda y, theta: y, group="bad_group")
793
-
794
-
795
- def test_apply_test_function_missing_group():
796
- """Test error when InferenceData object is missing a required group.
797
-
798
- The function cannot work if group="both" but InferenceData object has no
799
- posterior_predictive group.
800
- """
801
- idata = from_dict(
802
- posterior={"a": np.random.random((4, 500, 30))}, observed_data={"y": np.random.random(30)}
803
- )
804
- with pytest.raises(ValueError, match="must have posterior_predictive"):
805
- apply_test_function(idata, lambda y, theta: np.mean, group="both")
806
-
807
-
808
- def test_apply_test_function_should_overwrite_error(centered_eight):
809
- """Test error when overwrite=False but out_name is already a present variable."""
810
- with pytest.raises(ValueError, match="Should overwrite"):
811
- apply_test_function(centered_eight, lambda y, theta: y, out_name_data="obs")
812
-
813
-
814
- def test_weight_predictions():
815
- idata0 = from_dict(
816
- posterior_predictive={"a": np.random.normal(-1, 1, 1000)}, observed_data={"a": [1]}
817
- )
818
- idata1 = from_dict(
819
- posterior_predictive={"a": np.random.normal(1, 1, 1000)}, observed_data={"a": [1]}
820
- )
821
-
822
- new = weight_predictions([idata0, idata1])
823
- assert (
824
- idata1.posterior_predictive.mean()
825
- > new.posterior_predictive.mean()
826
- > idata0.posterior_predictive.mean()
827
- )
828
- assert "posterior_predictive" in new
829
- assert "observed_data" in new
830
-
831
- new = weight_predictions([idata0, idata1], weights=[0.5, 0.5])
832
- assert_almost_equal(new.posterior_predictive["a"].mean(), 0, decimal=1)
833
- new = weight_predictions([idata0, idata1], weights=[0.9, 0.1])
834
- assert_almost_equal(new.posterior_predictive["a"].mean(), -0.8, decimal=1)
835
-
836
-
837
- @pytest.fixture(scope="module")
838
- def psens_data():
839
- non_centered_eight = load_arviz_data("non_centered_eight")
840
- post = non_centered_eight.posterior
841
- log_prior = {
842
- "mu": XrContinuousRV(norm, 0, 5).logpdf(post["mu"]),
843
- "tau": XrContinuousRV(halfcauchy, scale=5).logpdf(post["tau"]),
844
- "theta_t": XrContinuousRV(norm, 0, 1).logpdf(post["theta_t"]),
845
- }
846
- non_centered_eight.add_groups({"log_prior": log_prior})
847
- return non_centered_eight
848
-
849
-
850
- @pytest.mark.parametrize("component", ("prior", "likelihood"))
851
- def test_priorsens_global(psens_data, component):
852
- result = psens(psens_data, component=component)
853
- assert "mu" in result
854
- assert "theta" in result
855
- assert "school" in result.theta_t.dims
856
-
857
-
858
- def test_priorsens_var_names(psens_data):
859
- result1 = psens(
860
- psens_data, component="prior", component_var_names=["mu", "tau"], var_names=["mu", "tau"]
861
- )
862
- result2 = psens(psens_data, component="prior", var_names=["mu", "tau"])
863
- for result in (result1, result2):
864
- assert "theta" not in result
865
- assert "mu" in result
866
- assert "tau" in result
867
- assert not np.isclose(result1.mu, result2.mu)
868
-
869
-
870
- def test_priorsens_coords(psens_data):
871
- result = psens(psens_data, component="likelihood", component_coords={"school": "Choate"})
872
- assert "mu" in result
873
- assert "theta" in result
874
- assert "school" in result.theta_t.dims
875
-
876
-
877
- def test_bayes_factor():
878
- idata = from_dict(
879
- posterior={"a": np.random.normal(1, 0.5, 5000)}, prior={"a": np.random.normal(0, 1, 5000)}
880
- )
881
- bf_dict0 = bayes_factor(idata, var_name="a", ref_val=0)
882
- bf_dict1 = bayes_factor(idata, prior=np.random.normal(0, 10, 5000), var_name="a", ref_val=0)
883
- assert bf_dict0["BF10"] > bf_dict0["BF01"]
884
- assert bf_dict1["BF10"] < bf_dict1["BF01"]
885
-
886
-
887
- def test_compare_sorting_consistency():
888
- chains, draws = 4, 1000
889
-
890
- # Model 1 - good fit
891
- log_lik1 = np.random.normal(-2, 1, size=(chains, draws))
892
- posterior1 = Dataset(
893
- {"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))},
894
- coords={"chain": range(chains), "draw": range(draws)},
895
- )
896
- log_like1 = Dataset(
897
- {"y": (("chain", "draw"), log_lik1)},
898
- coords={"chain": range(chains), "draw": range(draws)},
899
- )
900
- data1 = InferenceData(posterior=posterior1, log_likelihood=log_like1)
901
-
902
- # Model 2 - poor fit (higher variance)
903
- log_lik2 = np.random.normal(-5, 2, size=(chains, draws))
904
- posterior2 = Dataset(
905
- {"theta": (("chain", "draw"), np.random.normal(0, 1, size=(chains, draws)))},
906
- coords={"chain": range(chains), "draw": range(draws)},
907
- )
908
- log_like2 = Dataset(
909
- {"y": (("chain", "draw"), log_lik2)},
910
- coords={"chain": range(chains), "draw": range(draws)},
911
- )
912
- data2 = InferenceData(posterior=posterior2, log_likelihood=log_like2)
913
-
914
- # Compare models in different orders
915
- comp_dict1 = {"M1": data1, "M2": data2}
916
- comp_dict2 = {"M2": data2, "M1": data1}
917
-
918
- comparison1 = compare(comp_dict1, method="bb-pseudo-bma")
919
- comparison2 = compare(comp_dict2, method="bb-pseudo-bma")
920
-
921
- assert comparison1.index.tolist() == comparison2.index.tolist()
922
-
923
- se1 = comparison1["se"].values
924
- se2 = comparison2["se"].values
925
- np.testing.assert_array_almost_equal(se1, se2)