arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
@@ -1,2197 +0,0 @@
1
- """Tests use the default backend."""
2
-
3
- # pylint: disable=redefined-outer-name,too-many-lines
4
- import os
5
- import re
6
- from copy import deepcopy
7
-
8
- import matplotlib.pyplot as plt
9
- import numpy as np
10
- import pytest
11
- import xarray as xr
12
- from matplotlib import animation
13
- from pandas import DataFrame
14
- from scipy.stats import gaussian_kde, norm
15
-
16
- from ...data import from_dict, load_arviz_data
17
- from ...labels import MapLabeller
18
- from ...plots import (
19
- plot_autocorr,
20
- plot_bf,
21
- plot_bpv,
22
- plot_compare,
23
- plot_density,
24
- plot_dist,
25
- plot_dist_comparison,
26
- plot_dot,
27
- plot_ecdf,
28
- plot_elpd,
29
- plot_energy,
30
- plot_ess,
31
- plot_forest,
32
- plot_hdi,
33
- plot_kde,
34
- plot_khat,
35
- plot_lm,
36
- plot_loo_pit,
37
- plot_mcse,
38
- plot_pair,
39
- plot_parallel,
40
- plot_posterior,
41
- plot_ppc,
42
- plot_rank,
43
- plot_separation,
44
- plot_trace,
45
- plot_ts,
46
- plot_violin,
47
- )
48
- from ...plots.dotplot import wilkinson_algorithm
49
- from ...plots.plot_utils import plot_point_interval
50
- from ...rcparams import rc_context, rcParams
51
- from ...stats import compare, hdi, loo, waic
52
- from ...stats.density_utils import kde as _kde
53
- from ...utils import BehaviourChangeWarning, _cov
54
- from ..helpers import ( # pylint: disable=unused-import
55
- RandomVariableTestClass,
56
- create_model,
57
- create_multidimensional_model,
58
- does_not_warn,
59
- eight_schools_params,
60
- models,
61
- multidim_models,
62
- )
63
-
64
- rcParams["data.load"] = "eager"
65
-
66
-
67
- @pytest.fixture(scope="function", autouse=True)
68
- def clean_plots(request, save_figs):
69
- """Close plots after each test, optionally save if --save is specified during test invocation"""
70
-
71
- def fin():
72
- if save_figs is not None:
73
- plt.savefig(f"{os.path.join(save_figs, request.node.name)}.png")
74
- plt.close("all")
75
-
76
- request.addfinalizer(fin)
77
-
78
-
79
- @pytest.fixture(scope="module")
80
- def data(eight_schools_params):
81
- data = eight_schools_params
82
- return data
83
-
84
-
85
- @pytest.fixture(scope="module")
86
- def df_trace():
87
- return DataFrame({"a": np.random.poisson(2.3, 100)})
88
-
89
-
90
- @pytest.fixture(scope="module")
91
- def discrete_model():
92
- """Simple fixture for random discrete model"""
93
- return {"x": np.random.randint(10, size=100), "y": np.random.randint(10, size=100)}
94
-
95
-
96
- @pytest.fixture(scope="module")
97
- def discrete_multidim_model():
98
- """Simple fixture for random discrete model"""
99
- idata = from_dict(
100
- {"x": np.random.randint(10, size=(2, 50, 3)), "y": np.random.randint(10, size=(2, 50))},
101
- dims={"x": ["school"]},
102
- )
103
- return idata
104
-
105
-
106
- @pytest.fixture(scope="module")
107
- def continuous_model():
108
- """Simple fixture for random continuous model"""
109
- return {"x": np.random.beta(2, 5, size=100), "y": np.random.beta(2, 5, size=100)}
110
-
111
-
112
- @pytest.fixture(scope="function")
113
- def fig_ax():
114
- fig, ax = plt.subplots(1, 1)
115
- return fig, ax
116
-
117
-
118
- @pytest.fixture(scope="module")
119
- def data_random():
120
- return np.random.randint(1, 100, size=20)
121
-
122
-
123
- @pytest.fixture(scope="module")
124
- def data_list():
125
- return list(range(11, 31))
126
-
127
-
128
- @pytest.mark.parametrize(
129
- "kwargs",
130
- [
131
- {"point_estimate": "mean"},
132
- {"point_estimate": "median"},
133
- {"hdi_prob": 0.94},
134
- {"hdi_prob": 1},
135
- {"outline": True},
136
- {"colors": ["g", "b", "r", "y"]},
137
- {"colors": "k"},
138
- {"hdi_markers": ["v"]},
139
- {"shade": 1},
140
- {"transform": lambda x: x + 1},
141
- {"ax": plt.subplots(6, 3)[1]},
142
- ],
143
- )
144
- def test_plot_density_float(models, kwargs):
145
- obj = [getattr(models, model_fit) for model_fit in ["model_1", "model_2"]]
146
- axes = plot_density(obj, **kwargs)
147
- assert axes.shape == (6, 3)
148
-
149
-
150
- def test_plot_density_discrete(discrete_model):
151
- axes = plot_density(discrete_model, shade=0.9)
152
- assert axes.size == 2
153
-
154
-
155
- def test_plot_density_no_subset():
156
- """Test plot_density works when variables are not subset of one another (#1093)."""
157
- model_ab = from_dict(
158
- {
159
- "a": np.random.normal(size=200),
160
- "b": np.random.normal(size=200),
161
- }
162
- )
163
- model_bc = from_dict(
164
- {
165
- "b": np.random.normal(size=200),
166
- "c": np.random.normal(size=200),
167
- }
168
- )
169
- axes = plot_density([model_ab, model_bc])
170
- assert axes.size == 3
171
-
172
-
173
- def test_plot_density_nonstring_varnames():
174
- """Test plot_density works when variables are not strings."""
175
- rv1 = RandomVariableTestClass("a")
176
- rv2 = RandomVariableTestClass("b")
177
- rv3 = RandomVariableTestClass("c")
178
- model_ab = from_dict(
179
- {
180
- rv1: np.random.normal(size=200),
181
- rv2: np.random.normal(size=200),
182
- }
183
- )
184
- model_bc = from_dict(
185
- {
186
- rv2: np.random.normal(size=200),
187
- rv3: np.random.normal(size=200),
188
- }
189
- )
190
- axes = plot_density([model_ab, model_bc])
191
- assert axes.size == 3
192
-
193
-
194
- def test_plot_density_bad_kwargs(models):
195
- obj = [getattr(models, model_fit) for model_fit in ["model_1", "model_2"]]
196
- with pytest.raises(ValueError):
197
- plot_density(obj, point_estimate="bad_value")
198
-
199
- with pytest.raises(ValueError):
200
- plot_density(obj, data_labels=[f"bad_value_{i}" for i in range(len(obj) + 10)])
201
-
202
- with pytest.raises(ValueError):
203
- plot_density(obj, hdi_prob=2)
204
-
205
- with pytest.raises(ValueError):
206
- plot_density(obj, filter_vars="bad_value")
207
-
208
-
209
- def test_plot_density_discrete_combinedims(discrete_model):
210
- axes = plot_density(discrete_model, combine_dims={"school"}, shade=0.9)
211
- assert axes.size == 2
212
-
213
-
214
- @pytest.mark.parametrize(
215
- "kwargs",
216
- [
217
- {},
218
- {"y_hat_line": True},
219
- {"expected_events": True},
220
- {"y_hat_line_kwargs": {"linestyle": "dotted"}},
221
- {"exp_events_kwargs": {"marker": "o"}},
222
- ],
223
- )
224
- def test_plot_separation(kwargs):
225
- idata = load_arviz_data("classification10d")
226
- ax = plot_separation(idata=idata, y="outcome", **kwargs)
227
- assert ax
228
-
229
-
230
- @pytest.mark.parametrize(
231
- "kwargs",
232
- [
233
- {},
234
- {"var_names": "mu"},
235
- {"var_names": ["mu", "tau"]},
236
- {"combined": True},
237
- {"compact": True},
238
- {"combined": True, "compact": True, "legend": True},
239
- {"divergences": "top", "legend": True},
240
- {"divergences": False},
241
- {"kind": "rank_vlines"},
242
- {"kind": "rank_bars"},
243
- {"lines": [("mu", {}, [1, 2])]},
244
- {"lines": [("mu", {}, 8)]},
245
- {"circ_var_names": ["mu"]},
246
- {"circ_var_names": ["mu"], "circ_var_units": "degrees"},
247
- ],
248
- )
249
- def test_plot_trace(models, kwargs):
250
- axes = plot_trace(models.model_1, **kwargs)
251
- assert axes.shape
252
-
253
-
254
- @pytest.mark.parametrize(
255
- "compact",
256
- [True, False],
257
- )
258
- @pytest.mark.parametrize(
259
- "combined",
260
- [True, False],
261
- )
262
- def test_plot_trace_legend(compact, combined):
263
- idata = load_arviz_data("rugby")
264
- axes = plot_trace(
265
- idata, var_names=["home", "atts_star"], compact=compact, combined=combined, legend=True
266
- )
267
- assert axes[0, 1].get_legend()
268
- compact_legend = axes[1, 0].get_legend()
269
- if compact:
270
- assert axes.shape == (2, 2)
271
- assert compact_legend
272
- else:
273
- assert axes.shape == (7, 2)
274
- assert not compact_legend
275
-
276
-
277
- def test_plot_trace_discrete(discrete_model):
278
- axes = plot_trace(discrete_model)
279
- assert axes.shape
280
-
281
-
282
- def test_plot_trace_max_subplots_warning(models):
283
- with pytest.warns(UserWarning):
284
- with rc_context(rc={"plot.max_subplots": 6}):
285
- axes = plot_trace(models.model_1)
286
- assert axes.shape == (3, 2)
287
-
288
-
289
- def test_plot_dist_comparison_warning(models):
290
- with pytest.warns(UserWarning):
291
- with rc_context(rc={"plot.max_subplots": 6}):
292
- axes = plot_dist_comparison(models.model_1)
293
- assert axes.shape == (2, 3)
294
-
295
-
296
- @pytest.mark.parametrize("kwargs", [{"var_names": ["mu", "tau"], "lines": [("hey", {}, [1])]}])
297
- def test_plot_trace_invalid_varname_warning(models, kwargs):
298
- with pytest.warns(UserWarning, match="valid var.+should be provided"):
299
- axes = plot_trace(models.model_1, **kwargs)
300
- assert axes.shape
301
-
302
-
303
- def test_plot_trace_diverging_correctly_transposed():
304
- idata = load_arviz_data("centered_eight")
305
- idata.sample_stats["diverging"] = idata.sample_stats.diverging.T
306
- plot_trace(idata, divergences="bottom")
307
-
308
-
309
- @pytest.mark.parametrize(
310
- "bad_kwargs", [{"var_names": ["mu", "tau"], "lines": [("mu", {}, ["hey"])]}]
311
- )
312
- def test_plot_trace_bad_lines_value(models, bad_kwargs):
313
- with pytest.raises(ValueError, match="line-positions should be numeric"):
314
- plot_trace(models.model_1, **bad_kwargs)
315
-
316
-
317
- @pytest.mark.parametrize("prop", ["chain_prop", "compact_prop"])
318
- def test_plot_trace_futurewarning(models, prop):
319
- with pytest.warns(FutureWarning, match=f"{prop} as a tuple.+deprecated"):
320
- ax = plot_trace(models.model_1, **{prop: ("ls", ("-", "--"))})
321
- assert ax.shape
322
-
323
-
324
- @pytest.mark.parametrize("model_fits", [["model_1"], ["model_1", "model_2"]])
325
- @pytest.mark.parametrize(
326
- "args_expected",
327
- [
328
- ({}, 1),
329
- ({"var_names": "mu", "transform": lambda x: x + 1}, 1),
330
- ({"var_names": "mu", "rope": (-1, 1), "combine_dims": {"school"}}, 1),
331
- ({"r_hat": True, "quartiles": False}, 2),
332
- ({"var_names": ["mu"], "colors": "C0", "ess": True, "combined": True}, 2),
333
- (
334
- {
335
- "kind": "ridgeplot",
336
- "ridgeplot_truncate": False,
337
- "ridgeplot_quantiles": [0.25, 0.5, 0.75],
338
- },
339
- 1,
340
- ),
341
- ({"kind": "ridgeplot", "r_hat": True, "ess": True}, 3),
342
- ({"kind": "ridgeplot", "r_hat": True, "ess": True}, 3),
343
- ({"kind": "ridgeplot", "r_hat": True, "ess": True, "ridgeplot_alpha": 0}, 3),
344
- (
345
- {
346
- "var_names": ["mu", "theta"],
347
- "rope": {
348
- "mu": [{"rope": (-0.1, 0.1)}],
349
- "theta": [{"school": "Choate", "rope": (0.2, 0.5)}],
350
- },
351
- },
352
- 1,
353
- ),
354
- ],
355
- )
356
- def test_plot_forest(models, model_fits, args_expected):
357
- obj = [getattr(models, model_fit) for model_fit in model_fits]
358
- args, expected = args_expected
359
- axes = plot_forest(obj, **args)
360
- assert axes.size == expected
361
-
362
-
363
- def test_plot_forest_rope_exception():
364
- with pytest.raises(ValueError) as err:
365
- plot_forest({"x": [1]}, rope="not_correct_format")
366
- assert "Argument `rope` must be None, a dictionary like" in str(err.value)
367
-
368
-
369
- def test_plot_forest_single_value():
370
- axes = plot_forest({"x": [1]})
371
- assert axes.shape
372
-
373
-
374
- def test_plot_forest_ridge_discrete(discrete_model):
375
- axes = plot_forest(discrete_model, kind="ridgeplot")
376
- assert axes.shape
377
-
378
-
379
- @pytest.mark.parametrize("model_fits", [["model_1"], ["model_1", "model_2"]])
380
- def test_plot_forest_bad(models, model_fits):
381
- obj = [getattr(models, model_fit) for model_fit in model_fits]
382
- with pytest.raises(TypeError):
383
- plot_forest(obj, kind="bad_kind")
384
-
385
- with pytest.raises(ValueError):
386
- plot_forest(obj, model_names=[f"model_name_{i}" for i in range(len(obj) + 10)])
387
-
388
-
389
- @pytest.mark.parametrize("kind", ["kde", "hist"])
390
- def test_plot_energy(models, kind):
391
- assert plot_energy(models.model_1, kind=kind)
392
-
393
-
394
- def test_plot_energy_bad(models):
395
- with pytest.raises(ValueError):
396
- plot_energy(models.model_1, kind="bad_kind")
397
-
398
-
399
- def test_plot_energy_correctly_transposed():
400
- idata = load_arviz_data("centered_eight")
401
- idata.sample_stats["energy"] = idata.sample_stats.energy.T
402
- ax = plot_energy(idata)
403
- # legend has one entry for each KDE and 1 BFMI for each chain
404
- assert len(ax.legend_.texts) == 2 + len(idata.sample_stats.chain)
405
-
406
-
407
- def test_plot_parallel_raises_valueerror(df_trace): # pylint: disable=invalid-name
408
- with pytest.raises(ValueError):
409
- plot_parallel(df_trace)
410
-
411
-
412
- @pytest.mark.parametrize("norm_method", [None, "normal", "minmax", "rank"])
413
- def test_plot_parallel(models, norm_method):
414
- assert plot_parallel(models.model_1, var_names=["mu", "tau"], norm_method=norm_method)
415
-
416
-
417
- @pytest.mark.parametrize("var_names", [None, "mu", ["mu", "tau"]])
418
- def test_plot_parallel_exception(models, var_names):
419
- """Ensure that correct exception is raised when one variable is passed."""
420
- with pytest.raises(ValueError):
421
- assert plot_parallel(models.model_1, var_names=var_names, norm_method="foo")
422
-
423
-
424
- @pytest.mark.parametrize(
425
- "kwargs",
426
- [
427
- {"plot_kwargs": {"linestyle": "-"}},
428
- {"contour": True, "fill_last": False},
429
- {
430
- "contour": True,
431
- "contourf_kwargs": {"cmap": "plasma"},
432
- "contour_kwargs": {"linewidths": 1},
433
- },
434
- {"contour": False},
435
- {"contour": False, "pcolormesh_kwargs": {"cmap": "plasma"}},
436
- {"is_circular": False},
437
- {"is_circular": True},
438
- {"is_circular": "radians"},
439
- {"is_circular": "degrees"},
440
- {"adaptive": True},
441
- {"hdi_probs": [0.3, 0.9, 0.6]},
442
- {"hdi_probs": [0.3, 0.6, 0.9], "contourf_kwargs": {"cmap": "Blues"}},
443
- {"hdi_probs": [0.9, 0.6, 0.3], "contour_kwargs": {"alpha": 0}},
444
- ],
445
- )
446
- def test_plot_kde(continuous_model, kwargs):
447
- axes = plot_kde(continuous_model["x"], continuous_model["y"], **kwargs)
448
- axes1 = plot_kde(continuous_model["x"], continuous_model["y"], **kwargs)
449
- assert axes
450
- assert axes is axes1
451
-
452
-
453
- @pytest.mark.parametrize(
454
- "kwargs",
455
- [
456
- {"hdi_probs": [1, 2, 3]},
457
- {"hdi_probs": [-0.3, 0.6, 0.9]},
458
- {"hdi_probs": [0, 0.3, 0.6]},
459
- {"hdi_probs": [0.3, 0.6, 1]},
460
- ],
461
- )
462
- def test_plot_kde_hdi_probs_bad(continuous_model, kwargs):
463
- """Ensure invalid hdi probabilities are rejected."""
464
- with pytest.raises(ValueError):
465
- plot_kde(continuous_model["x"], continuous_model["y"], **kwargs)
466
-
467
-
468
- @pytest.mark.parametrize(
469
- "kwargs",
470
- [
471
- {"hdi_probs": [0.3, 0.6, 0.9], "contourf_kwargs": {"levels": [0, 0.5, 1]}},
472
- {"hdi_probs": [0.3, 0.6, 0.9], "contour_kwargs": {"levels": [0, 0.5, 1]}},
473
- ],
474
- )
475
- def test_plot_kde_hdi_probs_warning(continuous_model, kwargs):
476
- """Ensure warning is raised when too many keywords are specified."""
477
- with pytest.warns(UserWarning):
478
- axes = plot_kde(continuous_model["x"], continuous_model["y"], **kwargs)
479
- assert axes
480
-
481
-
482
- @pytest.mark.parametrize("shape", [(8,), (8, 8), (8, 8, 8)])
483
- def test_cov(shape):
484
- x = np.random.randn(*shape)
485
- if x.ndim <= 2:
486
- assert np.allclose(_cov(x), np.cov(x))
487
- else:
488
- with pytest.raises(ValueError):
489
- _cov(x)
490
-
491
-
492
- @pytest.mark.parametrize(
493
- "kwargs",
494
- [
495
- {"cumulative": True},
496
- {"cumulative": True, "plot_kwargs": {"linestyle": "--"}},
497
- {"rug": True},
498
- {"rug": True, "rug_kwargs": {"alpha": 0.2}, "rotated": True},
499
- ],
500
- )
501
- def test_plot_kde_cumulative(continuous_model, kwargs):
502
- axes = plot_kde(continuous_model["x"], quantiles=[0.25, 0.5, 0.75], **kwargs)
503
- assert axes
504
-
505
-
506
- @pytest.mark.parametrize(
507
- "kwargs",
508
- [
509
- {"kind": "hist"},
510
- {"kind": "kde"},
511
- {"is_circular": False},
512
- {"is_circular": False, "kind": "hist"},
513
- {"is_circular": True},
514
- {"is_circular": True, "kind": "hist"},
515
- {"is_circular": "radians"},
516
- {"is_circular": "radians", "kind": "hist"},
517
- {"is_circular": "degrees"},
518
- {"is_circular": "degrees", "kind": "hist"},
519
- ],
520
- )
521
- def test_plot_dist(continuous_model, kwargs):
522
- axes = plot_dist(continuous_model["x"], **kwargs)
523
- axes1 = plot_dist(continuous_model["x"], **kwargs)
524
- assert axes
525
- assert axes is axes1
526
-
527
-
528
- def test_plot_dist_hist(data_random):
529
- axes = plot_dist(data_random, hist_kwargs=dict(bins=30))
530
- assert axes
531
-
532
-
533
- def test_list_conversion(data_list):
534
- axes = plot_dist(data_list, hist_kwargs=dict(bins=30))
535
- assert axes
536
-
537
-
538
- @pytest.mark.parametrize(
539
- "kwargs",
540
- [
541
- {"plot_kwargs": {"linestyle": "-"}},
542
- {"contour": True, "fill_last": False},
543
- {"contour": False},
544
- ],
545
- )
546
- def test_plot_dist_2d_kde(continuous_model, kwargs):
547
- axes = plot_dist(continuous_model["x"], continuous_model["y"], **kwargs)
548
- assert axes
549
-
550
-
551
- @pytest.mark.parametrize(
552
- "kwargs", [{"plot_kwargs": {"linestyle": "-"}}, {"cumulative": True}, {"rug": True}]
553
- )
554
- def test_plot_kde_quantiles(continuous_model, kwargs):
555
- axes = plot_kde(continuous_model["x"], quantiles=[0.05, 0.5, 0.95], **kwargs)
556
- assert axes
557
-
558
-
559
- def test_plot_kde_inference_data(models):
560
- """
561
- Ensure that an exception is raised when plot_kde
562
- is used with an inference data or Xarray dataset object.
563
- """
564
- with pytest.raises(ValueError, match="Inference Data"):
565
- plot_kde(models.model_1)
566
- with pytest.raises(ValueError, match="Xarray"):
567
- plot_kde(models.model_1.posterior)
568
-
569
-
570
- @pytest.mark.slow
571
- @pytest.mark.parametrize(
572
- "kwargs",
573
- [
574
- {
575
- "var_names": "theta",
576
- "divergences": True,
577
- "coords": {"school": [0, 1]},
578
- "scatter_kwargs": {"marker": "x", "c": "C0"},
579
- "divergences_kwargs": {"marker": "*", "c": "C0"},
580
- },
581
- {
582
- "divergences": True,
583
- "scatter_kwargs": {"marker": "x", "c": "C0"},
584
- "divergences_kwargs": {"marker": "*", "c": "C0"},
585
- "var_names": ["theta", "mu"],
586
- },
587
- {"kind": "kde", "var_names": ["theta"]},
588
- {"kind": "hexbin", "colorbar": False, "var_names": ["theta"]},
589
- {"kind": "hexbin", "colorbar": True, "var_names": ["theta"]},
590
- {
591
- "kind": "hexbin",
592
- "var_names": ["theta"],
593
- "coords": {"school": [0, 1]},
594
- "colorbar": True,
595
- "hexbin_kwargs": {"cmap": "viridis"},
596
- "textsize": 20,
597
- },
598
- {
599
- "point_estimate": "mean",
600
- "reference_values": {"mu": 0, "tau": 0},
601
- "reference_values_kwargs": {"c": "C0", "marker": "*"},
602
- },
603
- {
604
- "var_names": ["mu", "tau"],
605
- "reference_values": {"mu": 0, "tau": 0},
606
- "labeller": MapLabeller({"mu": r"$\mu$", "theta": r"$\theta"}),
607
- },
608
- {
609
- "var_names": ["theta"],
610
- "reference_values": {"theta": [0.0] * 8},
611
- "labeller": MapLabeller({"theta": r"$\theta$"}),
612
- },
613
- {
614
- "var_names": ["theta"],
615
- "reference_values": {"theta": np.zeros(8)},
616
- "labeller": MapLabeller({"theta": r"$\theta$"}),
617
- },
618
- ],
619
- )
620
- def test_plot_pair(models, kwargs):
621
- ax = plot_pair(models.model_1, **kwargs)
622
- assert np.all(ax)
623
-
624
-
625
- @pytest.mark.parametrize(
626
- "kwargs", [{"kind": "scatter"}, {"kind": "kde"}, {"kind": "hexbin", "colorbar": True}]
627
- )
628
- def test_plot_pair_2var(discrete_model, fig_ax, kwargs):
629
- _, ax = fig_ax
630
- ax = plot_pair(discrete_model, ax=ax, **kwargs)
631
- assert ax
632
-
633
-
634
- def test_plot_pair_bad(models):
635
- with pytest.raises(ValueError):
636
- plot_pair(models.model_1, kind="bad_kind")
637
- with pytest.raises(Exception):
638
- plot_pair(models.model_1, var_names=["mu"])
639
-
640
-
641
- @pytest.mark.parametrize("has_sample_stats", [True, False])
642
- def test_plot_pair_divergences_warning(has_sample_stats):
643
- data = load_arviz_data("centered_eight")
644
- if has_sample_stats:
645
- # sample_stats present, diverging field missing
646
- data.sample_stats = data.sample_stats.rename({"diverging": "diverging_missing"})
647
- else:
648
- # sample_stats missing
649
- data = data.posterior # pylint: disable=no-member
650
- with pytest.warns(UserWarning):
651
- ax = plot_pair(data, divergences=True)
652
- assert np.all(ax)
653
-
654
-
655
- @pytest.mark.parametrize(
656
- "kwargs", [{}, {"marginals": True}, {"marginals": True, "var_names": ["mu", "tau"]}]
657
- )
658
- def test_plot_pair_overlaid(models, kwargs):
659
- ax = plot_pair(models.model_1, **kwargs)
660
- ax2 = plot_pair(models.model_2, ax=ax, **kwargs)
661
- assert ax is ax2
662
- assert ax.shape
663
-
664
-
665
- @pytest.mark.parametrize("marginals", [True, False])
666
- def test_plot_pair_combinedims(models, marginals):
667
- ax = plot_pair(
668
- models.model_1, var_names=["eta", "theta"], combine_dims={"school"}, marginals=marginals
669
- )
670
- if marginals:
671
- assert ax.shape == (2, 2)
672
- else:
673
- assert not isinstance(ax, np.ndarray)
674
-
675
-
676
- @pytest.mark.parametrize("marginals", [True, False])
677
- @pytest.mark.parametrize("max_subplots", [True, False])
678
- def test_plot_pair_shapes(marginals, max_subplots):
679
- rng = np.random.default_rng()
680
- idata = from_dict({"a": rng.standard_normal((4, 500, 5))})
681
- if max_subplots:
682
- with rc_context({"plot.max_subplots": 6}):
683
- with pytest.warns(UserWarning, match="3x3 grid"):
684
- ax = plot_pair(idata, marginals=marginals)
685
- else:
686
- ax = plot_pair(idata, marginals=marginals)
687
- side = 3 if max_subplots else (4 + marginals)
688
- assert ax.shape == (side, side)
689
-
690
-
691
- @pytest.mark.parametrize("sharex", ["col", None])
692
- @pytest.mark.parametrize("sharey", ["row", None])
693
- @pytest.mark.parametrize("marginals", [True, False])
694
- def test_plot_pair_shared(sharex, sharey, marginals):
695
- # Generate fake data and plot
696
- rng = np.random.default_rng()
697
- idata = from_dict({"a": rng.standard_normal((4, 500, 5))})
698
- numvars = 5 - (not marginals)
699
- if sharex is None and sharey is None:
700
- ax = plot_pair(idata, marginals=marginals)
701
- else:
702
- backend_kwargs = {}
703
- if sharex is not None:
704
- backend_kwargs["sharex"] = sharex
705
- if sharey is not None:
706
- backend_kwargs["sharey"] = sharey
707
- with pytest.warns(UserWarning):
708
- ax = plot_pair(idata, marginals=marginals, backend_kwargs=backend_kwargs)
709
-
710
- # Check x axes shared correctly
711
- for i in range(numvars):
712
- num_shared_x = numvars - i
713
- assert len(ax[-1, i].get_shared_x_axes().get_siblings(ax[-1, i])) == num_shared_x
714
-
715
- # Check y axes shared correctly
716
- for j in range(numvars):
717
- if marginals:
718
- num_shared_y = j
719
-
720
- # Check diagonal has unshared axis
721
- assert len(ax[j, j].get_shared_y_axes().get_siblings(ax[j, j])) == 1
722
-
723
- if j == 0:
724
- continue
725
- else:
726
- num_shared_y = j + 1
727
- assert len(ax[j, 0].get_shared_y_axes().get_siblings(ax[j, 0])) == num_shared_y
728
-
729
-
730
- @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
731
- @pytest.mark.parametrize("alpha", [None, 0.2, 1])
732
- @pytest.mark.parametrize("animated", [False, True])
733
- @pytest.mark.parametrize("observed", [True, False])
734
- @pytest.mark.parametrize("observed_rug", [False, True])
735
- def test_plot_ppc(models, kind, alpha, animated, observed, observed_rug):
736
- if animation and not animation.writers.is_available("ffmpeg"):
737
- pytest.skip("matplotlib animations within ArviZ require ffmpeg")
738
- animation_kwargs = {"blit": False}
739
- axes = plot_ppc(
740
- models.model_1,
741
- kind=kind,
742
- alpha=alpha,
743
- observed=observed,
744
- observed_rug=observed_rug,
745
- animated=animated,
746
- animation_kwargs=animation_kwargs,
747
- random_seed=3,
748
- )
749
- if animated:
750
- assert axes[0]
751
- assert axes[1]
752
- assert axes
753
-
754
-
755
- def test_plot_ppc_transposed():
756
- idata = load_arviz_data("rugby")
757
- idata.map(
758
- lambda ds: ds.assign(points=xr.concat((ds.home_points, ds.away_points), "field")),
759
- groups="observed_vars",
760
- inplace=True,
761
- )
762
- assert idata.posterior_predictive.points.dims == ("field", "chain", "draw", "match")
763
- ax = plot_ppc(
764
- idata,
765
- kind="scatter",
766
- var_names="points",
767
- flatten=["field"],
768
- coords={"match": ["Wales Italy"]},
769
- random_seed=3,
770
- num_pp_samples=8,
771
- )
772
- x, y = ax.get_lines()[2].get_data()
773
- assert not np.isclose(y[0], 0)
774
- assert np.all(np.array([47, 44, 15, 11]) == x)
775
-
776
-
777
- @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
778
- @pytest.mark.parametrize("jitter", [None, 0, 0.1, 1, 3])
779
- @pytest.mark.parametrize("animated", [False, True])
780
- def test_plot_ppc_multichain(kind, jitter, animated):
781
- if animation and not animation.writers.is_available("ffmpeg"):
782
- pytest.skip("matplotlib animations within ArviZ require ffmpeg")
783
- data = from_dict(
784
- posterior_predictive={
785
- "x": np.random.randn(4, 100, 30),
786
- "y_hat": np.random.randn(4, 100, 3, 10),
787
- },
788
- observed_data={"x": np.random.randn(30), "y": np.random.randn(3, 10)},
789
- )
790
- animation_kwargs = {"blit": False}
791
- axes = plot_ppc(
792
- data,
793
- kind=kind,
794
- data_pairs={"y": "y_hat"},
795
- jitter=jitter,
796
- animated=animated,
797
- animation_kwargs=animation_kwargs,
798
- random_seed=3,
799
- )
800
- if animated:
801
- assert np.all(axes[0])
802
- assert np.all(axes[1])
803
- else:
804
- assert np.all(axes)
805
-
806
-
807
- @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
808
- @pytest.mark.parametrize("animated", [False, True])
809
- def test_plot_ppc_discrete(kind, animated):
810
- if animation and not animation.writers.is_available("ffmpeg"):
811
- pytest.skip("matplotlib animations within ArviZ require ffmpeg")
812
- data = from_dict(
813
- observed_data={"obs": np.random.randint(1, 100, 15)},
814
- posterior_predictive={"obs": np.random.randint(1, 300, (1, 20, 15))},
815
- )
816
-
817
- animation_kwargs = {"blit": False}
818
- axes = plot_ppc(data, kind=kind, animated=animated, animation_kwargs=animation_kwargs)
819
- if animated:
820
- assert np.all(axes[0])
821
- assert np.all(axes[1])
822
- assert axes
823
-
824
-
825
- @pytest.mark.skipif(
826
- not animation.writers.is_available("ffmpeg"),
827
- reason="matplotlib animations within ArviZ require ffmpeg",
828
- )
829
- @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
830
- def test_plot_ppc_save_animation(models, kind):
831
- animation_kwargs = {"blit": False}
832
- axes, anim = plot_ppc(
833
- models.model_1,
834
- kind=kind,
835
- animated=True,
836
- animation_kwargs=animation_kwargs,
837
- num_pp_samples=5,
838
- random_seed=3,
839
- )
840
- assert axes
841
- assert anim
842
- animations_folder = "../saved_animations"
843
- os.makedirs(animations_folder, exist_ok=True)
844
- path = os.path.join(animations_folder, f"ppc_{kind}_animation.mp4")
845
- anim.save(path)
846
- assert os.path.exists(path)
847
- assert os.path.getsize(path)
848
-
849
-
850
- @pytest.mark.skipif(
851
- not animation.writers.is_available("ffmpeg"),
852
- reason="matplotlib animations within ArviZ require ffmpeg",
853
- )
854
- @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
855
- def test_plot_ppc_discrete_save_animation(kind):
856
- data = from_dict(
857
- observed_data={"obs": np.random.randint(1, 100, 15)},
858
- posterior_predictive={"obs": np.random.randint(1, 300, (1, 20, 15))},
859
- )
860
- animation_kwargs = {"blit": False}
861
- axes, anim = plot_ppc(
862
- data,
863
- kind=kind,
864
- animated=True,
865
- animation_kwargs=animation_kwargs,
866
- num_pp_samples=5,
867
- random_seed=3,
868
- )
869
- assert axes
870
- assert anim
871
- animations_folder = "../saved_animations"
872
- os.makedirs(animations_folder, exist_ok=True)
873
- path = os.path.join(animations_folder, f"ppc_discrete_{kind}_animation.mp4")
874
- anim.save(path)
875
- assert os.path.exists(path)
876
- assert os.path.getsize(path)
877
-
878
-
879
- @pytest.mark.skipif(
880
- not animation.writers.is_available("ffmpeg"),
881
- reason="matplotlib animations within ArviZ require ffmpeg",
882
- )
883
- @pytest.mark.parametrize("system", ["Windows", "Darwin"])
884
- def test_non_linux_blit(models, monkeypatch, system, caplog):
885
- import platform
886
-
887
- def mock_system():
888
- return system
889
-
890
- monkeypatch.setattr(platform, "system", mock_system)
891
-
892
- animation_kwargs = {"blit": True}
893
- axes, anim = plot_ppc(
894
- models.model_1,
895
- kind="kde",
896
- animated=True,
897
- animation_kwargs=animation_kwargs,
898
- num_pp_samples=5,
899
- random_seed=3,
900
- )
901
- records = caplog.records
902
- assert len(records) == 1
903
- assert records[0].levelname == "WARNING"
904
- assert axes
905
- assert anim
906
-
907
-
908
- @pytest.mark.parametrize(
909
- "kwargs",
910
- [
911
- {"flatten": []},
912
- {"flatten": [], "coords": {"obs_dim": [1, 2, 3]}},
913
- {"flatten": ["obs_dim"], "coords": {"obs_dim": [1, 2, 3]}},
914
- ],
915
- )
916
- def test_plot_ppc_grid(models, kwargs):
917
- axes = plot_ppc(models.model_1, kind="scatter", **kwargs)
918
- if not kwargs.get("flatten") and not kwargs.get("coords"):
919
- assert axes.size == 8
920
- elif not kwargs.get("flatten"):
921
- assert axes.size == 3
922
- else:
923
- assert not isinstance(axes, np.ndarray)
924
- assert np.ravel(axes).size == 1
925
-
926
-
927
- @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
928
- def test_plot_ppc_bad(models, kind):
929
- data = from_dict(posterior={"mu": np.random.randn()})
930
- with pytest.raises(TypeError):
931
- plot_ppc(data, kind=kind)
932
- with pytest.raises(TypeError):
933
- plot_ppc(models.model_1, kind="bad_val")
934
- with pytest.raises(TypeError):
935
- plot_ppc(models.model_1, num_pp_samples="bad_val")
936
-
937
-
938
- @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
939
- def test_plot_ppc_ax(models, kind, fig_ax):
940
- """Test ax argument of plot_ppc."""
941
- _, ax = fig_ax
942
- axes = plot_ppc(models.model_1, kind=kind, ax=ax)
943
- assert np.asarray(axes).item(0) is ax
944
-
945
-
946
- @pytest.mark.skipif(
947
- not animation.writers.is_available("ffmpeg"),
948
- reason="matplotlib animations within ArviZ require ffmpeg",
949
- )
950
- def test_plot_ppc_bad_ax(models, fig_ax):
951
- _, ax = fig_ax
952
- _, ax2 = plt.subplots(1, 2)
953
- with pytest.raises(ValueError, match="same figure"):
954
- plot_ppc(
955
- models.model_1, ax=[ax, *ax2], flatten=[], coords={"obs_dim": [1, 2, 3]}, animated=True
956
- )
957
- with pytest.raises(ValueError, match="2 axes"):
958
- plot_ppc(models.model_1, ax=ax2)
959
-
960
-
961
- def test_plot_legend(models):
962
- axes = plot_ppc(models.model_1)
963
- legend_texts = axes.get_legend().get_texts()
964
- result = [i.get_text() for i in legend_texts]
965
- expected = ["Posterior predictive", "Observed", "Posterior predictive mean"]
966
- assert result == expected
967
-
968
-
969
- @pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"]))
970
- @pytest.mark.parametrize("side", ["both", "left", "right"])
971
- @pytest.mark.parametrize("rug", [True])
972
- def test_plot_violin(models, var_names, side, rug):
973
- axes = plot_violin(models.model_1, var_names=var_names, side=side, rug=rug)
974
- assert axes.shape
975
-
976
-
977
- def test_plot_violin_ax(models):
978
- _, ax = plt.subplots(1)
979
- axes = plot_violin(models.model_1, var_names="mu", ax=ax)
980
- assert axes.shape
981
-
982
-
983
- def test_plot_violin_layout(models):
984
- axes = plot_violin(models.model_1, var_names=["mu", "tau"], sharey=False)
985
- assert axes.shape
986
-
987
-
988
- def test_plot_violin_discrete(discrete_model):
989
- axes = plot_violin(discrete_model)
990
- assert axes.shape
991
-
992
-
993
- @pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"]))
994
- def test_plot_violin_combinedims(models, var_names):
995
- axes = plot_violin(models.model_1, var_names=var_names, combine_dims={"school"})
996
- assert axes.shape
997
-
998
-
999
- def test_plot_violin_ax_combinedims(models):
1000
- _, ax = plt.subplots(1)
1001
- axes = plot_violin(models.model_1, var_names="mu", combine_dims={"school"}, ax=ax)
1002
- assert axes.shape
1003
-
1004
-
1005
- def test_plot_violin_layout_combinedims(models):
1006
- axes = plot_violin(
1007
- models.model_1, var_names=["mu", "tau"], combine_dims={"school"}, sharey=False
1008
- )
1009
- assert axes.shape
1010
-
1011
-
1012
- def test_plot_violin_discrete_combinedims(discrete_model):
1013
- axes = plot_violin(discrete_model, combine_dims={"school"})
1014
- assert axes.shape
1015
-
1016
-
1017
- def test_plot_autocorr_short_chain():
1018
- """Check that logic for small chain defaulting doesn't cause exception"""
1019
- chain = np.arange(10)
1020
- axes = plot_autocorr(chain)
1021
- assert axes
1022
-
1023
-
1024
- def test_plot_autocorr_uncombined(models):
1025
- axes = plot_autocorr(models.model_1, combined=False)
1026
- assert axes.size
1027
- max_subplots = (
1028
- np.inf if rcParams["plot.max_subplots"] is None else rcParams["plot.max_subplots"]
1029
- )
1030
- assert axes.size == min(72, max_subplots)
1031
-
1032
-
1033
- def test_plot_autocorr_combined(models):
1034
- axes = plot_autocorr(models.model_1, combined=True)
1035
- assert axes.size == 18
1036
-
1037
-
1038
- @pytest.mark.parametrize("var_names", (None, "mu", ["mu"], ["mu", "tau"]))
1039
- def test_plot_autocorr_var_names(models, var_names):
1040
- axes = plot_autocorr(models.model_1, var_names=var_names, combined=True)
1041
- if (isinstance(var_names, list) and len(var_names) == 1) or isinstance(var_names, str):
1042
- assert not isinstance(axes, np.ndarray)
1043
- else:
1044
- assert axes.shape
1045
-
1046
-
1047
- @pytest.mark.parametrize(
1048
- "kwargs",
1049
- [
1050
- {},
1051
- {"var_names": "mu"},
1052
- {"var_names": ("mu", "tau"), "coords": {"school": [0, 1]}},
1053
- {"var_names": "mu", "ref_line": True},
1054
- {
1055
- "var_names": "mu",
1056
- "ref_line_kwargs": {"lw": 2, "color": "C2"},
1057
- "bar_kwargs": {"width": 0.7},
1058
- },
1059
- {"var_names": "mu", "ref_line": False},
1060
- {"var_names": "mu", "kind": "vlines"},
1061
- {
1062
- "var_names": "mu",
1063
- "kind": "vlines",
1064
- "vlines_kwargs": {"lw": 0},
1065
- "marker_vlines_kwargs": {"lw": 3},
1066
- },
1067
- ],
1068
- )
1069
- def test_plot_rank(models, kwargs):
1070
- axes = plot_rank(models.model_1, **kwargs)
1071
- var_names = kwargs.get("var_names", [])
1072
- if isinstance(var_names, str):
1073
- assert not isinstance(axes, np.ndarray)
1074
- else:
1075
- assert axes.shape
1076
-
1077
-
1078
- @pytest.mark.parametrize(
1079
- "kwargs",
1080
- [
1081
- {},
1082
- {"var_names": "mu"},
1083
- {"var_names": ("mu", "tau")},
1084
- {"rope": (-2, 2)},
1085
- {"rope": {"mu": [{"rope": (-2, 2)}], "theta": [{"school": "Choate", "rope": (2, 4)}]}},
1086
- {"point_estimate": "mode"},
1087
- {"point_estimate": "median"},
1088
- {"hdi_prob": "hide", "label": ""},
1089
- {"point_estimate": None},
1090
- {"ref_val": 0},
1091
- {"ref_val": None},
1092
- {"ref_val": {"mu": [{"ref_val": 1}]}},
1093
- {"bins": None, "kind": "hist"},
1094
- {
1095
- "ref_val": {
1096
- "theta": [
1097
- # {"school": ["Choate", "Deerfield"], "ref_val": -1}, this is not working
1098
- {"school": "Lawrenceville", "ref_val": 3}
1099
- ]
1100
- }
1101
- },
1102
- ],
1103
- )
1104
- def test_plot_posterior(models, kwargs):
1105
- axes = plot_posterior(models.model_1, **kwargs)
1106
- if isinstance(kwargs.get("var_names"), str):
1107
- assert not isinstance(axes, np.ndarray)
1108
- else:
1109
- assert axes.shape
1110
-
1111
-
1112
- def test_plot_posterior_boolean():
1113
- data = np.random.choice(a=[False, True], size=(4, 100))
1114
- axes = plot_posterior(data)
1115
- assert axes
1116
- plt.draw()
1117
- labels = [label.get_text() for label in axes.get_xticklabels()]
1118
- assert all(item in labels for item in ("True", "False"))
1119
-
1120
-
1121
- @pytest.mark.parametrize("kwargs", [{}, {"point_estimate": "mode"}, {"bins": None, "kind": "hist"}])
1122
- def test_plot_posterior_discrete(discrete_model, kwargs):
1123
- axes = plot_posterior(discrete_model, **kwargs)
1124
- assert axes.shape
1125
-
1126
-
1127
- def test_plot_posterior_bad_type():
1128
- with pytest.raises(TypeError):
1129
- plot_posterior(np.array(["a", "b", "c"]))
1130
-
1131
-
1132
- def test_plot_posterior_bad(models):
1133
- with pytest.raises(ValueError):
1134
- plot_posterior(models.model_1, rope="bad_value")
1135
- with pytest.raises(ValueError):
1136
- plot_posterior(models.model_1, ref_val="bad_value")
1137
- with pytest.raises(ValueError):
1138
- plot_posterior(models.model_1, point_estimate="bad_value")
1139
-
1140
-
1141
- @pytest.mark.parametrize("point_estimate", ("mode", "mean", "median"))
1142
- def test_plot_posterior_point_estimates(models, point_estimate):
1143
- axes = plot_posterior(models.model_1, var_names=("mu", "tau"), point_estimate=point_estimate)
1144
- assert axes.size == 2
1145
-
1146
-
1147
- def test_plot_posterior_skipna():
1148
- sample = np.linspace(0, 1)
1149
- sample[:10] = np.nan
1150
- plot_posterior({"a": sample}, skipna=True)
1151
- with pytest.raises(ValueError):
1152
- plot_posterior({"a": sample}, skipna=False)
1153
-
1154
-
1155
- @pytest.mark.parametrize("kwargs", [{"var_names": ["mu", "theta"]}])
1156
- def test_plot_posterior_combinedims(models, kwargs):
1157
- axes = plot_posterior(models.model_1, combine_dims={"school"}, **kwargs)
1158
- if isinstance(kwargs.get("var_names"), str):
1159
- assert not isinstance(axes, np.ndarray)
1160
- else:
1161
- assert axes.shape
1162
-
1163
-
1164
- @pytest.mark.parametrize("kwargs", [{}, {"point_estimate": "mode"}, {"bins": None, "kind": "hist"}])
1165
- def test_plot_posterior_discrete_combinedims(discrete_multidim_model, kwargs):
1166
- axes = plot_posterior(discrete_multidim_model, combine_dims={"school"}, **kwargs)
1167
- assert axes.size == 2
1168
-
1169
-
1170
- @pytest.mark.parametrize("point_estimate", ("mode", "mean", "median"))
1171
- def test_plot_posterior_point_estimates_combinedims(models, point_estimate):
1172
- axes = plot_posterior(
1173
- models.model_1,
1174
- var_names=("mu", "tau"),
1175
- combine_dims={"school"},
1176
- point_estimate=point_estimate,
1177
- )
1178
- assert axes.size == 2
1179
-
1180
-
1181
- def test_plot_posterior_skipna_combinedims():
1182
- idata = load_arviz_data("centered_eight")
1183
- idata.posterior["theta"].loc[dict(school="Deerfield")] = np.nan
1184
- with pytest.raises(ValueError):
1185
- plot_posterior(idata, var_names="theta", combine_dims={"school"}, skipna=False)
1186
- ax = plot_posterior(idata, var_names="theta", combine_dims={"school"}, skipna=True)
1187
- assert not isinstance(ax, np.ndarray)
1188
-
1189
-
1190
- @pytest.mark.parametrize(
1191
- "kwargs", [{"insample_dev": True}, {"plot_standard_error": False}, {"plot_ic_diff": False}]
1192
- )
1193
- def test_plot_compare(models, kwargs):
1194
- model_compare = compare({"Model 1": models.model_1, "Model 2": models.model_2})
1195
-
1196
- axes = plot_compare(model_compare, **kwargs)
1197
- assert axes
1198
-
1199
-
1200
- def test_plot_compare_no_ic(models):
1201
- """Check exception is raised if model_compare doesn't contain a valid information criterion"""
1202
- model_compare = compare({"Model 1": models.model_1, "Model 2": models.model_2})
1203
-
1204
- # Drop column needed for plotting
1205
- model_compare = model_compare.drop("elpd_loo", axis=1)
1206
- with pytest.raises(ValueError) as err:
1207
- plot_compare(model_compare)
1208
-
1209
- assert "comp_df must contain one of the following" in str(err.value)
1210
- assert "['elpd_loo', 'elpd_waic']" in str(err.value)
1211
-
1212
-
1213
- @pytest.mark.parametrize(
1214
- "kwargs",
1215
- [
1216
- {"color": "0.5", "circular": True},
1217
- {"hdi_data": True, "fill_kwargs": {"alpha": 0}},
1218
- {"plot_kwargs": {"alpha": 0}},
1219
- {"smooth_kwargs": {"window_length": 33, "polyorder": 5, "mode": "mirror"}},
1220
- {"hdi_data": True, "smooth": False},
1221
- ],
1222
- )
1223
- def test_plot_hdi(models, data, kwargs):
1224
- hdi_data = kwargs.pop("hdi_data", None)
1225
- if hdi_data:
1226
- hdi_data = hdi(models.model_1.posterior["theta"])
1227
- ax = plot_hdi(data["y"], hdi_data=hdi_data, **kwargs)
1228
- else:
1229
- ax = plot_hdi(data["y"], models.model_1.posterior["theta"], **kwargs)
1230
- assert ax
1231
-
1232
-
1233
- def test_plot_hdi_warning():
1234
- """Check using both y and hdi_data sends a warning."""
1235
- x_data = np.random.normal(0, 1, 100)
1236
- y_data = np.random.normal(2 + x_data * 0.5, 0.5, (1, 200, 100))
1237
- hdi_data = hdi(y_data)
1238
- with pytest.warns(UserWarning, match="Both y and hdi_data"):
1239
- ax = plot_hdi(x_data, y=y_data, hdi_data=hdi_data)
1240
- assert ax
1241
-
1242
-
1243
- def test_plot_hdi_missing_arg_error():
1244
- """Check that both y and hdi_data missing raises an error."""
1245
- with pytest.raises(ValueError, match="One of {y, hdi_data"):
1246
- plot_hdi(np.arange(20))
1247
-
1248
-
1249
- def test_plot_hdi_dataset_error(models):
1250
- """Check hdi_data as multiple variable Dataset raises an error."""
1251
- hdi_data = hdi(models.model_1)
1252
- with pytest.raises(ValueError, match="Only single variable Dataset"):
1253
- plot_hdi(np.arange(8), hdi_data=hdi_data)
1254
-
1255
-
1256
- def test_plot_hdi_string_error():
1257
- """Check x as type string raises an error."""
1258
- x_data = ["a", "b", "c", "d"]
1259
- y_data = np.random.normal(0, 5, (1, 200, len(x_data)))
1260
- hdi_data = hdi(y_data)
1261
- with pytest.raises(
1262
- NotImplementedError,
1263
- match=re.escape(
1264
- (
1265
- "The `arviz.plot_hdi()` function does not support categorical data. "
1266
- "Consider using `arviz.plot_forest()`."
1267
- )
1268
- ),
1269
- ):
1270
- plot_hdi(x=x_data, y=y_data, hdi_data=hdi_data)
1271
-
1272
-
1273
- def test_plot_hdi_datetime_error():
1274
- """Check x as datetime raises an error."""
1275
- x_data = np.arange(start="2022-01-01", stop="2022-03-01", dtype=np.datetime64)
1276
- y_data = np.random.normal(0, 5, (1, 200, x_data.shape[0]))
1277
- hdi_data = hdi(y_data)
1278
- with pytest.raises(TypeError, match="Cannot deal with x as type datetime."):
1279
- plot_hdi(x=x_data, y=y_data, hdi_data=hdi_data)
1280
-
1281
-
1282
- @pytest.mark.parametrize("limits", [(-10.0, 10.0), (-5, 5), (None, None)])
1283
- def test_kde_scipy(limits):
1284
- """
1285
- Evaluates if sum of density is the same for our implementation
1286
- and the implementation in scipy
1287
- """
1288
- data = np.random.normal(0, 1, 10000)
1289
- grid, density_own = _kde(data, custom_lims=limits)
1290
- density_sp = gaussian_kde(data).evaluate(grid)
1291
- np.testing.assert_almost_equal(density_own.sum(), density_sp.sum(), 1)
1292
-
1293
-
1294
- @pytest.mark.parametrize("limits", [(-10.0, 10.0), (-5, 5), (None, None)])
1295
- def test_kde_cumulative(limits):
1296
- """
1297
- Evaluates if last value of cumulative density is 1
1298
- """
1299
- data = np.random.normal(0, 1, 1000)
1300
- density = _kde(data, custom_lims=limits, cumulative=True)[1]
1301
- np.testing.assert_almost_equal(round(density[-1], 3), 1)
1302
-
1303
-
1304
- def test_plot_ecdf_basic():
1305
- data = np.random.randn(4, 1000)
1306
- axes = plot_ecdf(data)
1307
- assert axes is not None
1308
-
1309
-
1310
- def test_plot_ecdf_eval_points():
1311
- """Check that BehaviourChangeWarning is raised if eval_points is not specified."""
1312
- data = np.random.randn(4, 1000)
1313
- eval_points = np.linspace(-3, 3, 100)
1314
- with pytest.warns(BehaviourChangeWarning):
1315
- axes = plot_ecdf(data)
1316
- assert axes is not None
1317
- with does_not_warn(BehaviourChangeWarning):
1318
- axes = plot_ecdf(data, eval_points=eval_points)
1319
- assert axes is not None
1320
-
1321
-
1322
- @pytest.mark.parametrize("confidence_bands", [True, "pointwise", "optimized", "simulated"])
1323
- @pytest.mark.parametrize("ndraws", [100, 10_000])
1324
- def test_plot_ecdf_confidence_bands(confidence_bands, ndraws):
1325
- """Check that all confidence_bands values correctly accepted"""
1326
- data = np.random.randn(4, ndraws // 4)
1327
- axes = plot_ecdf(data, confidence_bands=confidence_bands, cdf=norm(0, 1).cdf)
1328
- assert axes is not None
1329
-
1330
-
1331
- def test_plot_ecdf_values2():
1332
- data = np.random.randn(4, 1000)
1333
- data2 = np.random.randn(4, 1000)
1334
- axes = plot_ecdf(data, data2)
1335
- assert axes is not None
1336
-
1337
-
1338
- def test_plot_ecdf_cdf():
1339
- data = np.random.randn(4, 1000)
1340
- cdf = norm(0, 1).cdf
1341
- axes = plot_ecdf(data, cdf=cdf)
1342
- assert axes is not None
1343
-
1344
-
1345
- def test_plot_ecdf_error():
1346
- """Check that all error conditions are correctly raised."""
1347
- dist = norm(0, 1)
1348
- data = dist.rvs(1000)
1349
-
1350
- # cdf not specified
1351
- with pytest.raises(ValueError):
1352
- plot_ecdf(data, confidence_bands=True)
1353
- plot_ecdf(data, confidence_bands=True, cdf=dist.cdf)
1354
- with pytest.raises(ValueError):
1355
- plot_ecdf(data, difference=True)
1356
- plot_ecdf(data, difference=True, cdf=dist.cdf)
1357
- with pytest.raises(ValueError):
1358
- plot_ecdf(data, pit=True)
1359
- plot_ecdf(data, pit=True, cdf=dist.cdf)
1360
-
1361
- # contradictory confidence band types
1362
- with pytest.raises(ValueError):
1363
- plot_ecdf(data, cdf=dist.cdf, confidence_bands="simulated", pointwise=True)
1364
- with pytest.raises(ValueError):
1365
- plot_ecdf(data, cdf=dist.cdf, confidence_bands="optimized", pointwise=True)
1366
- plot_ecdf(data, cdf=dist.cdf, confidence_bands=True, pointwise=True)
1367
- plot_ecdf(data, cdf=dist.cdf, confidence_bands="pointwise")
1368
-
1369
- # contradictory band probabilities
1370
- with pytest.raises(ValueError):
1371
- plot_ecdf(data, cdf=dist.cdf, confidence_bands=True, ci_prob=0.9, fpr=0.1)
1372
- plot_ecdf(data, cdf=dist.cdf, confidence_bands=True, ci_prob=0.9)
1373
- plot_ecdf(data, cdf=dist.cdf, confidence_bands=True, fpr=0.1)
1374
-
1375
- # contradictory reference
1376
- data2 = dist.rvs(200)
1377
- with pytest.raises(ValueError):
1378
- plot_ecdf(data, data2, cdf=dist.cdf, difference=True)
1379
- plot_ecdf(data, data2, difference=True)
1380
- plot_ecdf(data, cdf=dist.cdf, difference=True)
1381
-
1382
-
1383
- def test_plot_ecdf_deprecations():
1384
- """Check that deprecations are raised correctly."""
1385
- dist = norm(0, 1)
1386
- data = dist.rvs(1000)
1387
- # base case, no deprecations
1388
- with does_not_warn(FutureWarning):
1389
- axes = plot_ecdf(data, cdf=dist.cdf, confidence_bands=True)
1390
- assert axes is not None
1391
-
1392
- # values2 is deprecated
1393
- data2 = dist.rvs(200)
1394
- with pytest.warns(FutureWarning):
1395
- axes = plot_ecdf(data, values2=data2, difference=True)
1396
-
1397
- # pit is deprecated
1398
- with pytest.warns(FutureWarning):
1399
- axes = plot_ecdf(data, cdf=dist.cdf, pit=True)
1400
- assert axes is not None
1401
-
1402
- # fpr is deprecated
1403
- with does_not_warn(FutureWarning):
1404
- axes = plot_ecdf(data, cdf=dist.cdf, ci_prob=0.9)
1405
- with pytest.warns(FutureWarning):
1406
- axes = plot_ecdf(data, cdf=dist.cdf, confidence_bands=True, fpr=0.1)
1407
- assert axes is not None
1408
-
1409
- # pointwise is deprecated
1410
- with does_not_warn(FutureWarning):
1411
- axes = plot_ecdf(data, cdf=dist.cdf, confidence_bands="pointwise")
1412
- with pytest.warns(FutureWarning):
1413
- axes = plot_ecdf(data, cdf=dist.cdf, confidence_bands=True, pointwise=True)
1414
-
1415
-
1416
- @pytest.mark.parametrize(
1417
- "kwargs",
1418
- [
1419
- {},
1420
- {"ic": "loo"},
1421
- {"xlabels": True, "scale": "log"},
1422
- {"color": "obs_dim", "xlabels": True},
1423
- {"color": "obs_dim", "legend": True},
1424
- {"ic": "loo", "color": "blue", "coords": {"obs_dim": slice(2, 5)}},
1425
- {"color": np.random.uniform(size=8), "threshold": 0.1},
1426
- {"threshold": 2},
1427
- ],
1428
- )
1429
- @pytest.mark.parametrize("add_model", [False, True])
1430
- @pytest.mark.parametrize("use_elpddata", [False, True])
1431
- def test_plot_elpd(models, add_model, use_elpddata, kwargs):
1432
- model_dict = {"Model 1": models.model_1, "Model 2": models.model_2}
1433
- if add_model:
1434
- model_dict["Model 3"] = create_model(seed=12)
1435
-
1436
- if use_elpddata:
1437
- ic = kwargs.get("ic", "waic")
1438
- scale = kwargs.get("scale", "deviance")
1439
- if ic == "waic":
1440
- model_dict = {k: waic(v, scale=scale, pointwise=True) for k, v in model_dict.items()}
1441
- else:
1442
- model_dict = {k: loo(v, scale=scale, pointwise=True) for k, v in model_dict.items()}
1443
-
1444
- axes = plot_elpd(model_dict, **kwargs)
1445
- assert np.all(axes)
1446
- if add_model:
1447
- assert axes.shape[0] == axes.shape[1]
1448
- assert axes.shape[0] == len(model_dict) - 1
1449
-
1450
-
1451
- @pytest.mark.parametrize(
1452
- "kwargs",
1453
- [
1454
- {},
1455
- {"ic": "loo"},
1456
- {"xlabels": True, "scale": "log"},
1457
- {"color": "dim1", "xlabels": True},
1458
- {"color": "dim2", "legend": True},
1459
- {"ic": "loo", "color": "blue", "coords": {"dim2": slice(2, 4)}},
1460
- {"color": np.random.uniform(size=35), "threshold": 0.1},
1461
- ],
1462
- )
1463
- @pytest.mark.parametrize("add_model", [False, True])
1464
- @pytest.mark.parametrize("use_elpddata", [False, True])
1465
- def test_plot_elpd_multidim(multidim_models, add_model, use_elpddata, kwargs):
1466
- model_dict = {"Model 1": multidim_models.model_1, "Model 2": multidim_models.model_2}
1467
- if add_model:
1468
- model_dict["Model 3"] = create_multidimensional_model(seed=12)
1469
-
1470
- if use_elpddata:
1471
- ic = kwargs.get("ic", "waic")
1472
- scale = kwargs.get("scale", "deviance")
1473
- if ic == "waic":
1474
- model_dict = {k: waic(v, scale=scale, pointwise=True) for k, v in model_dict.items()}
1475
- else:
1476
- model_dict = {k: loo(v, scale=scale, pointwise=True) for k, v in model_dict.items()}
1477
-
1478
- axes = plot_elpd(model_dict, **kwargs)
1479
- assert np.all(axes)
1480
- if add_model:
1481
- assert axes.shape[0] == axes.shape[1]
1482
- assert axes.shape[0] == len(model_dict) - 1
1483
-
1484
-
1485
- def test_plot_elpd_bad_ic(models):
1486
- model_dict = {
1487
- "Model 1": waic(models.model_1, pointwise=True),
1488
- "Model 2": loo(models.model_2, pointwise=True),
1489
- }
1490
- with pytest.raises(ValueError):
1491
- plot_elpd(model_dict, ic="bad_ic")
1492
-
1493
-
1494
- def test_plot_elpd_ic_error(models):
1495
- model_dict = {
1496
- "Model 1": waic(models.model_1, pointwise=True),
1497
- "Model 2": loo(models.model_2, pointwise=True),
1498
- }
1499
- with pytest.raises(ValueError):
1500
- plot_elpd(model_dict)
1501
-
1502
-
1503
- def test_plot_elpd_scale_error(models):
1504
- model_dict = {
1505
- "Model 1": waic(models.model_1, pointwise=True, scale="log"),
1506
- "Model 2": waic(models.model_2, pointwise=True, scale="deviance"),
1507
- }
1508
- with pytest.raises(ValueError):
1509
- plot_elpd(model_dict)
1510
-
1511
-
1512
- def test_plot_elpd_one_model(models):
1513
- model_dict = {"Model 1": models.model_1}
1514
- with pytest.raises(Exception):
1515
- plot_elpd(model_dict)
1516
-
1517
-
1518
- @pytest.mark.parametrize(
1519
- "kwargs",
1520
- [
1521
- {},
1522
- {"xlabels": True},
1523
- {"color": "obs_dim", "xlabels": True, "show_bins": True, "bin_format": "{0}"},
1524
- {"color": "obs_dim", "legend": True, "hover_label": True},
1525
- {"color": "blue", "coords": {"obs_dim": slice(2, 4)}},
1526
- {"color": np.random.uniform(size=8), "show_bins": True},
1527
- {
1528
- "color": np.random.uniform(size=(8, 3)),
1529
- "show_bins": True,
1530
- "show_hlines": True,
1531
- "threshold": 1,
1532
- },
1533
- ],
1534
- )
1535
- @pytest.mark.parametrize("input_type", ["elpd_data", "data_array", "array"])
1536
- def test_plot_khat(models, input_type, kwargs):
1537
- khats_data = loo(models.model_1, pointwise=True)
1538
-
1539
- if input_type == "data_array":
1540
- khats_data = khats_data.pareto_k
1541
- elif input_type == "array":
1542
- khats_data = khats_data.pareto_k.values
1543
- if "color" in kwargs and isinstance(kwargs["color"], str) and kwargs["color"] == "obs_dim":
1544
- kwargs["color"] = None
1545
-
1546
- axes = plot_khat(khats_data, **kwargs)
1547
- assert axes
1548
-
1549
-
1550
- @pytest.mark.parametrize(
1551
- "kwargs",
1552
- [
1553
- {},
1554
- {"xlabels": True},
1555
- {"color": "dim1", "xlabels": True, "show_bins": True, "bin_format": "{0}"},
1556
- {"color": "dim2", "legend": True, "hover_label": True},
1557
- {"color": "blue", "coords": {"dim2": slice(2, 4)}},
1558
- {"color": np.random.uniform(size=35), "show_bins": True},
1559
- {
1560
- "color": np.random.uniform(size=(35, 3)),
1561
- "show_bins": True,
1562
- "show_hlines": True,
1563
- "threshold": 1,
1564
- },
1565
- ],
1566
- )
1567
- @pytest.mark.parametrize("input_type", ["elpd_data", "data_array", "array"])
1568
- def test_plot_khat_multidim(multidim_models, input_type, kwargs):
1569
- khats_data = loo(multidim_models.model_1, pointwise=True)
1570
-
1571
- if input_type == "data_array":
1572
- khats_data = khats_data.pareto_k
1573
- elif input_type == "array":
1574
- khats_data = khats_data.pareto_k.values
1575
- if (
1576
- "color" in kwargs
1577
- and isinstance(kwargs["color"], str)
1578
- and kwargs["color"] in ("dim1", "dim2")
1579
- ):
1580
- kwargs["color"] = None
1581
-
1582
- axes = plot_khat(khats_data, **kwargs)
1583
- assert axes
1584
-
1585
-
1586
- def test_plot_khat_threshold():
1587
- khats = np.array([0, 0, 0.6, 0.6, 0.8, 0.9, 0.9, 2, 3, 4, 1.5])
1588
- axes = plot_khat(khats, threshold=1)
1589
- assert axes
1590
-
1591
-
1592
- def test_plot_khat_bad_input(models):
1593
- with pytest.raises(ValueError):
1594
- plot_khat(models.model_1.sample_stats)
1595
-
1596
-
1597
- @pytest.mark.parametrize(
1598
- "kwargs",
1599
- [
1600
- {},
1601
- {"var_names": ["theta"], "relative": True, "color": "r"},
1602
- {"coords": {"school": slice(4)}, "n_points": 10},
1603
- {"min_ess": 600, "hline_kwargs": {"color": "r"}},
1604
- ],
1605
- )
1606
- @pytest.mark.parametrize("kind", ["local", "quantile", "evolution"])
1607
- def test_plot_ess(models, kind, kwargs):
1608
- """Test plot_ess arguments common to all kind of plots."""
1609
- idata = models.model_1
1610
- ax = plot_ess(idata, kind=kind, **kwargs)
1611
- assert np.all(ax)
1612
-
1613
-
1614
- @pytest.mark.parametrize(
1615
- "kwargs",
1616
- [
1617
- {"rug": True},
1618
- {"rug": True, "rug_kind": "max_depth", "rug_kwargs": {"color": "c"}},
1619
- {"extra_methods": True},
1620
- {"extra_methods": True, "extra_kwargs": {"ls": ":"}, "text_kwargs": {"x": 0, "ha": "left"}},
1621
- {"extra_methods": True, "rug": True},
1622
- ],
1623
- )
1624
- @pytest.mark.parametrize("kind", ["local", "quantile"])
1625
- def test_plot_ess_local_quantile(models, kind, kwargs):
1626
- """Test specific arguments in kinds local and quantile of plot_ess."""
1627
- idata = models.model_1
1628
- ax = plot_ess(idata, kind=kind, **kwargs)
1629
- assert np.all(ax)
1630
-
1631
-
1632
- def test_plot_ess_evolution(models):
1633
- """Test specific arguments in evolution kind of plot_ess."""
1634
- idata = models.model_1
1635
- ax = plot_ess(idata, kind="evolution", extra_kwargs={"linestyle": "--"}, color="b")
1636
- assert np.all(ax)
1637
-
1638
-
1639
- def test_plot_ess_bad_kind(models):
1640
- """Test error when plot_ess receives an invalid kind."""
1641
- idata = models.model_1
1642
- with pytest.raises(ValueError, match="Invalid kind"):
1643
- plot_ess(idata, kind="bad kind")
1644
-
1645
-
1646
- @pytest.mark.parametrize("dim", ["chain", "draw"])
1647
- def test_plot_ess_bad_coords(models, dim):
1648
- """Test error when chain or dim are used as coords to select a data subset."""
1649
- idata = models.model_1
1650
- with pytest.raises(ValueError, match="invalid coordinates"):
1651
- plot_ess(idata, coords={dim: slice(3)})
1652
-
1653
-
1654
- def test_plot_ess_no_sample_stats(models):
1655
- """Test error when rug=True but sample_stats group is not present."""
1656
- idata = models.model_1
1657
- with pytest.raises(ValueError, match="must contain sample_stats"):
1658
- plot_ess(idata.posterior, rug=True)
1659
-
1660
-
1661
- def test_plot_ess_no_divergences(models):
1662
- """Test error when rug=True, but the variable defined by rug_kind is missing."""
1663
- idata = deepcopy(models.model_1)
1664
- idata.sample_stats = idata.sample_stats.rename({"diverging": "diverging_missing"})
1665
- with pytest.raises(ValueError, match="not contain diverging"):
1666
- plot_ess(idata, rug=True)
1667
-
1668
-
1669
- @pytest.mark.parametrize(
1670
- "kwargs",
1671
- [
1672
- {},
1673
- {"n_unif": 50, "legend": False},
1674
- {"use_hdi": True, "color": "gray"},
1675
- {"use_hdi": True, "hdi_prob": 0.68},
1676
- {"use_hdi": True, "hdi_kwargs": {"fill": 0.1}},
1677
- {"ecdf": True},
1678
- {"ecdf": True, "ecdf_fill": False, "plot_unif_kwargs": {"ls": "--"}},
1679
- {"ecdf": True, "hdi_prob": 0.97, "fill_kwargs": {"hatch": "/"}},
1680
- ],
1681
- )
1682
- def test_plot_loo_pit(models, kwargs):
1683
- axes = plot_loo_pit(idata=models.model_1, y="y", **kwargs)
1684
- assert axes
1685
-
1686
-
1687
- def test_plot_loo_pit_incompatible_args(models):
1688
- """Test error when both ecdf and use_hdi are True."""
1689
- with pytest.raises(ValueError, match="incompatible"):
1690
- plot_loo_pit(idata=models.model_1, y="y", ecdf=True, use_hdi=True)
1691
-
1692
-
1693
- @pytest.mark.parametrize(
1694
- "kwargs",
1695
- [
1696
- {},
1697
- {"var_names": ["theta"], "color": "r"},
1698
- {"rug": True, "rug_kwargs": {"color": "r"}},
1699
- {"errorbar": True, "rug": True, "rug_kind": "max_depth"},
1700
- {"errorbar": True, "coords": {"school": slice(4)}, "n_points": 10},
1701
- {"extra_methods": True, "rug": True},
1702
- {"extra_methods": True, "extra_kwargs": {"ls": ":"}, "text_kwargs": {"x": 0, "ha": "left"}},
1703
- ],
1704
- )
1705
- def test_plot_mcse(models, kwargs):
1706
- idata = models.model_1
1707
- ax = plot_mcse(idata, **kwargs)
1708
- assert np.all(ax)
1709
-
1710
-
1711
- @pytest.mark.parametrize("dim", ["chain", "draw"])
1712
- def test_plot_mcse_bad_coords(models, dim):
1713
- """Test error when chain or dim are used as coords to select a data subset."""
1714
- idata = models.model_1
1715
- with pytest.raises(ValueError, match="invalid coordinates"):
1716
- plot_mcse(idata, coords={dim: slice(3)})
1717
-
1718
-
1719
- def test_plot_mcse_no_sample_stats(models):
1720
- """Test error when rug=True but sample_stats group is not present."""
1721
- idata = models.model_1
1722
- with pytest.raises(ValueError, match="must contain sample_stats"):
1723
- plot_mcse(idata.posterior, rug=True)
1724
-
1725
-
1726
- def test_plot_mcse_no_divergences(models):
1727
- """Test error when rug=True, but the variable defined by rug_kind is missing."""
1728
- idata = deepcopy(models.model_1)
1729
- idata.sample_stats = idata.sample_stats.rename({"diverging": "diverging_missing"})
1730
- with pytest.raises(ValueError, match="not contain diverging"):
1731
- plot_mcse(idata, rug=True)
1732
-
1733
-
1734
- @pytest.mark.parametrize(
1735
- "kwargs",
1736
- [
1737
- {},
1738
- {"var_names": ["theta"]},
1739
- {"var_names": ["theta"], "coords": {"school": [0, 1]}},
1740
- {"var_names": ["eta"], "posterior_kwargs": {"rug": True, "rug_kwargs": {"color": "r"}}},
1741
- {"var_names": ["mu"], "prior_kwargs": {"fill_kwargs": {"alpha": 0.5}}},
1742
- {
1743
- "var_names": ["tau"],
1744
- "prior_kwargs": {"plot_kwargs": {"color": "r"}},
1745
- "posterior_kwargs": {"plot_kwargs": {"color": "b"}},
1746
- },
1747
- {"var_names": ["y"], "kind": "observed"},
1748
- ],
1749
- )
1750
- def test_plot_dist_comparison(models, kwargs):
1751
- idata = models.model_1
1752
- ax = plot_dist_comparison(idata, **kwargs)
1753
- assert np.all(ax)
1754
-
1755
-
1756
- def test_plot_dist_comparison_different_vars():
1757
- data = from_dict(
1758
- posterior={
1759
- "x": np.random.randn(4, 100, 30),
1760
- },
1761
- prior={"x_hat": np.random.randn(4, 100, 30)},
1762
- )
1763
- with pytest.raises(KeyError):
1764
- plot_dist_comparison(data, var_names="x")
1765
- ax = plot_dist_comparison(data, var_names=[["x_hat"], ["x"]])
1766
- assert np.all(ax)
1767
-
1768
-
1769
- def test_plot_dist_comparison_combinedims(models):
1770
- idata = models.model_1
1771
- ax = plot_dist_comparison(idata, combine_dims={"school"})
1772
- assert np.all(ax)
1773
-
1774
-
1775
- def test_plot_dist_comparison_different_vars_combinedims():
1776
- data = from_dict(
1777
- posterior={
1778
- "x": np.random.randn(4, 100, 30),
1779
- },
1780
- prior={"x_hat": np.random.randn(4, 100, 30)},
1781
- dims={"x": ["3rd_dim"], "x_hat": ["3rd_dim"]},
1782
- )
1783
- with pytest.raises(KeyError):
1784
- plot_dist_comparison(data, var_names="x", combine_dims={"3rd_dim"})
1785
- ax = plot_dist_comparison(data, var_names=[["x_hat"], ["x"]], combine_dims={"3rd_dim"})
1786
- assert np.all(ax)
1787
- assert ax.size == 3
1788
-
1789
-
1790
- @pytest.mark.parametrize(
1791
- "kwargs",
1792
- [
1793
- {},
1794
- {"reference": "analytical"},
1795
- {"kind": "p_value"},
1796
- {"kind": "t_stat", "t_stat": "std"},
1797
- {"kind": "t_stat", "t_stat": 0.5, "bpv": True},
1798
- ],
1799
- )
1800
- def test_plot_bpv(models, kwargs):
1801
- axes = plot_bpv(models.model_1, **kwargs)
1802
- assert not isinstance(axes, np.ndarray)
1803
-
1804
-
1805
- def test_plot_bpv_discrete():
1806
- fake_obs = {"a": np.random.poisson(2.5, 100)}
1807
- fake_pp = {"a": np.random.poisson(2.5, (1, 10, 100))}
1808
- fake_model = from_dict(posterior_predictive=fake_pp, observed_data=fake_obs)
1809
- axes = plot_bpv(fake_model)
1810
- assert not isinstance(axes, np.ndarray)
1811
-
1812
-
1813
- @pytest.mark.parametrize(
1814
- "kwargs",
1815
- [
1816
- {},
1817
- {
1818
- "binwidth": 0.5,
1819
- "stackratio": 2,
1820
- "nquantiles": 20,
1821
- },
1822
- {"point_interval": True},
1823
- {
1824
- "point_interval": True,
1825
- "dotsize": 1.2,
1826
- "point_estimate": "median",
1827
- "plot_kwargs": {"color": "grey"},
1828
- },
1829
- {
1830
- "point_interval": True,
1831
- "plot_kwargs": {"color": "grey"},
1832
- "nquantiles": 100,
1833
- "hdi_prob": 0.95,
1834
- "intervalcolor": "green",
1835
- },
1836
- {
1837
- "point_interval": True,
1838
- "plot_kwargs": {"color": "grey"},
1839
- "quartiles": False,
1840
- "linewidth": 2,
1841
- },
1842
- ],
1843
- )
1844
- def test_plot_dot(continuous_model, kwargs):
1845
- data = continuous_model["x"]
1846
- ax = plot_dot(data, **kwargs)
1847
- assert ax
1848
-
1849
-
1850
- @pytest.mark.parametrize(
1851
- "kwargs",
1852
- [
1853
- {"rotated": True},
1854
- {
1855
- "point_interval": True,
1856
- "rotated": True,
1857
- "dotcolor": "grey",
1858
- "binwidth": 0.5,
1859
- },
1860
- {
1861
- "rotated": True,
1862
- "point_interval": True,
1863
- "plot_kwargs": {"color": "grey"},
1864
- "nquantiles": 100,
1865
- "dotsize": 0.8,
1866
- "hdi_prob": 0.95,
1867
- "intervalcolor": "green",
1868
- },
1869
- ],
1870
- )
1871
- def test_plot_dot_rotated(continuous_model, kwargs):
1872
- data = continuous_model["x"]
1873
- ax = plot_dot(data, **kwargs)
1874
- assert ax
1875
-
1876
-
1877
- @pytest.mark.parametrize(
1878
- "kwargs",
1879
- [
1880
- {
1881
- "point_estimate": "mean",
1882
- "hdi_prob": 0.95,
1883
- "quartiles": False,
1884
- "linewidth": 2,
1885
- "markersize": 2,
1886
- "markercolor": "red",
1887
- "marker": "o",
1888
- "rotated": False,
1889
- "intervalcolor": "green",
1890
- },
1891
- ],
1892
- )
1893
- def test_plot_point_interval(continuous_model, kwargs):
1894
- _, ax = plt.subplots()
1895
- data = continuous_model["x"]
1896
- values = np.sort(data)
1897
- ax = plot_point_interval(ax, values, **kwargs)
1898
- assert ax
1899
-
1900
-
1901
- def test_wilkinson_algorithm(continuous_model):
1902
- data = continuous_model["x"]
1903
- values = np.sort(data)
1904
- _, stack_counts = wilkinson_algorithm(values, 0.5)
1905
- assert np.sum(stack_counts) == len(values)
1906
- stack_locs, stack_counts = wilkinson_algorithm([0.0, 1.0, 1.8, 3.0, 5.0], 1.0)
1907
- assert stack_locs == [0.0, 1.4, 3.0, 5.0]
1908
- assert stack_counts == [1, 2, 1, 1]
1909
-
1910
-
1911
- @pytest.mark.parametrize(
1912
- "kwargs",
1913
- [
1914
- {},
1915
- {"y_hat": "bad_name"},
1916
- {"x": "x1"},
1917
- {"x": ("x1", "x2")},
1918
- {
1919
- "x": ("x1", "x2"),
1920
- "y_kwargs": {"color": "blue", "marker": "^"},
1921
- "y_hat_plot_kwargs": {"color": "cyan"},
1922
- },
1923
- {"x": ("x1", "x2"), "y_model_plot_kwargs": {"color": "red"}},
1924
- {
1925
- "x": ("x1", "x2"),
1926
- "kind_pp": "hdi",
1927
- "kind_model": "hdi",
1928
- "y_model_fill_kwargs": {"color": "red"},
1929
- "y_hat_fill_kwargs": {"color": "cyan"},
1930
- },
1931
- ],
1932
- )
1933
- def test_plot_lm_1d(models, kwargs):
1934
- """Test functionality for 1D data."""
1935
- idata = models.model_1
1936
- if "constant_data" not in idata.groups():
1937
- y = idata.observed_data["y"]
1938
- x1data = y.coords[y.dims[0]]
1939
- idata.add_groups({"constant_data": {"_": x1data}})
1940
- idata.constant_data["x1"] = x1data
1941
- idata.constant_data["x2"] = x1data
1942
-
1943
- axes = plot_lm(idata=idata, y="y", y_model="eta", xjitter=True, **kwargs)
1944
- assert np.all(axes)
1945
-
1946
-
1947
- def test_plot_lm_multidim(multidim_models):
1948
- """Test functionality for multidimentional data."""
1949
- idata = multidim_models.model_1
1950
- axes = plot_lm(
1951
- idata=idata,
1952
- x=idata.observed_data["y"].coords["dim1"].values,
1953
- y="y",
1954
- xjitter=True,
1955
- plot_dim="dim1",
1956
- show=False,
1957
- figsize=(4, 16),
1958
- )
1959
- assert np.all(axes)
1960
-
1961
-
1962
- @pytest.mark.parametrize(
1963
- "val_err_kwargs",
1964
- [
1965
- {},
1966
- {"kind_pp": "bad_kind"},
1967
- {"kind_model": "bad_kind"},
1968
- ],
1969
- )
1970
- def test_plot_lm_valueerror(multidim_models, val_err_kwargs):
1971
- """Test error plot_dim gets no value for multidim data and wrong value in kind_... args."""
1972
- idata2 = multidim_models.model_1
1973
- with pytest.raises(ValueError):
1974
- plot_lm(idata=idata2, y="y", **val_err_kwargs)
1975
-
1976
-
1977
- @pytest.mark.parametrize(
1978
- "warn_kwargs",
1979
- [
1980
- {"y_hat": "bad_name"},
1981
- {"y_model": "bad_name"},
1982
- ],
1983
- )
1984
- def test_plot_lm_warning(models, warn_kwargs):
1985
- """Test Warning when needed groups or variables are not there in idata."""
1986
- idata1 = models.model_1
1987
- with pytest.warns(UserWarning):
1988
- plot_lm(
1989
- idata=from_dict(observed_data={"y": idata1.observed_data["y"].values}),
1990
- y="y",
1991
- **warn_kwargs,
1992
- )
1993
- with pytest.warns(UserWarning):
1994
- plot_lm(idata=idata1, y="y", **warn_kwargs)
1995
-
1996
-
1997
- def test_plot_lm_typeerror(models):
1998
- """Test error when invalid value passed to num_samples."""
1999
- idata1 = models.model_1
2000
- with pytest.raises(TypeError):
2001
- plot_lm(idata=idata1, y="y", num_samples=-1)
2002
-
2003
-
2004
- def test_plot_lm_list():
2005
- """Test the plots when input data is list or ndarray."""
2006
- y = [1, 2, 3, 4, 5]
2007
- assert plot_lm(y=y, x=np.arange(len(y)), show=False)
2008
-
2009
-
2010
- @pytest.mark.parametrize(
2011
- "kwargs",
2012
- [
2013
- {},
2014
- {"y_hat": "bad_name"},
2015
- {"x": "x"},
2016
- {"x": ("x", "x")},
2017
- {"y_holdout": "z"},
2018
- {"y_holdout": "z", "x_holdout": "x_pred"},
2019
- {"x": ("x", "x"), "y_holdout": "z", "x_holdout": ("x_pred", "x_pred")},
2020
- {"y_forecasts": "z"},
2021
- {"y_holdout": "z", "y_forecasts": "bad_name"},
2022
- ],
2023
- )
2024
- def test_plot_ts(kwargs):
2025
- """Test timeseries plots basic functionality."""
2026
- nchains = 4
2027
- ndraws = 500
2028
- obs_data = {
2029
- "y": 2 * np.arange(1, 9) + 3,
2030
- "z": 2 * np.arange(8, 12) + 3,
2031
- }
2032
-
2033
- posterior_predictive = {
2034
- "y": np.random.normal(
2035
- (obs_data["y"] * 1.2) - 3, size=(nchains, ndraws, len(obs_data["y"]))
2036
- ),
2037
- "z": np.random.normal(
2038
- (obs_data["z"] * 1.2) - 3, size=(nchains, ndraws, len(obs_data["z"]))
2039
- ),
2040
- }
2041
-
2042
- const_data = {"x": np.arange(1, 9), "x_pred": np.arange(8, 12)}
2043
-
2044
- idata = from_dict(
2045
- observed_data=obs_data,
2046
- posterior_predictive=posterior_predictive,
2047
- constant_data=const_data,
2048
- coords={"obs_dim": np.arange(1, 9), "pred_dim": np.arange(8, 12)},
2049
- dims={"y": ["obs_dim"], "z": ["pred_dim"]},
2050
- )
2051
-
2052
- ax = plot_ts(idata=idata, y="y", **kwargs)
2053
- assert np.all(ax)
2054
-
2055
-
2056
- @pytest.mark.parametrize(
2057
- "kwargs",
2058
- [
2059
- {},
2060
- {
2061
- "y_holdout": "z",
2062
- "holdout_dim": "holdout_dim1",
2063
- "x": ("x", "x"),
2064
- "x_holdout": ("x_pred", "x_pred"),
2065
- },
2066
- {"y_forecasts": "z", "holdout_dim": "holdout_dim1"},
2067
- ],
2068
- )
2069
- def test_plot_ts_multidim(kwargs):
2070
- """Test timeseries plots multidim functionality."""
2071
- nchains = 4
2072
- ndraws = 500
2073
- ndim1 = 5
2074
- ndim2 = 7
2075
- data = {
2076
- "y": np.random.normal(size=(ndim1, ndim2)),
2077
- "z": np.random.normal(size=(ndim1, ndim2)),
2078
- }
2079
-
2080
- posterior_predictive = {
2081
- "y": np.random.randn(nchains, ndraws, ndim1, ndim2),
2082
- "z": np.random.randn(nchains, ndraws, ndim1, ndim2),
2083
- }
2084
-
2085
- const_data = {"x": np.arange(1, 6), "x_pred": np.arange(5, 10)}
2086
-
2087
- idata = from_dict(
2088
- observed_data=data,
2089
- posterior_predictive=posterior_predictive,
2090
- constant_data=const_data,
2091
- dims={
2092
- "y": ["dim1", "dim2"],
2093
- "z": ["holdout_dim1", "holdout_dim2"],
2094
- },
2095
- coords={
2096
- "dim1": range(ndim1),
2097
- "dim2": range(ndim2),
2098
- "holdout_dim1": range(ndim1 - 1, ndim1 + 4),
2099
- "holdout_dim2": range(ndim2 - 1, ndim2 + 6),
2100
- },
2101
- )
2102
-
2103
- ax = plot_ts(idata=idata, y="y", plot_dim="dim1", **kwargs)
2104
- assert np.all(ax)
2105
-
2106
-
2107
- @pytest.mark.parametrize("val_err_kwargs", [{}, {"plot_dim": "dim1", "y_holdout": "y"}])
2108
- def test_plot_ts_valueerror(multidim_models, val_err_kwargs):
2109
- """Test error plot_dim gets no value for multidim data and wrong value in kind_... args."""
2110
- idata2 = multidim_models.model_1
2111
- with pytest.raises(ValueError):
2112
- plot_ts(idata=idata2, y="y", **val_err_kwargs)
2113
-
2114
-
2115
- def test_plot_bf():
2116
- idata = from_dict(
2117
- posterior={"a": np.random.normal(1, 0.5, 5000)}, prior={"a": np.random.normal(0, 1, 5000)}
2118
- )
2119
- _, bf_plot = plot_bf(idata, var_name="a", ref_val=0)
2120
- assert bf_plot is not None
2121
-
2122
-
2123
- def generate_lm_1d_data():
2124
- rng = np.random.default_rng()
2125
- return from_dict(
2126
- observed_data={"y": rng.normal(size=7)},
2127
- posterior_predictive={"y": rng.normal(size=(4, 1000, 7)) / 2},
2128
- posterior={"y_model": rng.normal(size=(4, 1000, 7))},
2129
- dims={"y": ["dim1"]},
2130
- coords={"dim1": range(7)},
2131
- )
2132
-
2133
-
2134
- def generate_lm_2d_data():
2135
- rng = np.random.default_rng()
2136
- return from_dict(
2137
- observed_data={"y": rng.normal(size=(5, 7))},
2138
- posterior_predictive={"y": rng.normal(size=(4, 1000, 5, 7)) / 2},
2139
- posterior={"y_model": rng.normal(size=(4, 1000, 5, 7))},
2140
- dims={"y": ["dim1", "dim2"]},
2141
- coords={"dim1": range(5), "dim2": range(7)},
2142
- )
2143
-
2144
-
2145
- @pytest.mark.parametrize("data", ("1d", "2d"))
2146
- @pytest.mark.parametrize("kind", ("lines", "hdi"))
2147
- @pytest.mark.parametrize("use_y_model", (True, False))
2148
- def test_plot_lm(data, kind, use_y_model):
2149
- if data == "1d":
2150
- idata = generate_lm_1d_data()
2151
- else:
2152
- idata = generate_lm_2d_data()
2153
-
2154
- kwargs = {"idata": idata, "y": "y", "kind_model": kind}
2155
- if data == "2d":
2156
- kwargs["plot_dim"] = "dim1"
2157
- if use_y_model:
2158
- kwargs["y_model"] = "y_model"
2159
- if kind == "lines":
2160
- kwargs["num_samples"] = 50
2161
-
2162
- ax = plot_lm(**kwargs)
2163
- assert ax is not None
2164
-
2165
-
2166
- @pytest.mark.parametrize(
2167
- "coords, expected_vars",
2168
- [
2169
- ({"school": ["Choate"]}, ["theta"]),
2170
- ({"school": ["Lawrenceville"]}, ["theta"]),
2171
- ({}, ["theta"]),
2172
- ],
2173
- )
2174
- def test_plot_autocorr_coords(coords, expected_vars):
2175
- """Test plot_autocorr with coords kwarg."""
2176
- idata = load_arviz_data("centered_eight")
2177
-
2178
- axes = plot_autocorr(idata, var_names=expected_vars, coords=coords, show=False)
2179
- assert axes is not None
2180
-
2181
-
2182
- def test_plot_forest_with_transform():
2183
- """Test if plot_forest runs successfully with a transform dictionary."""
2184
- data = xr.Dataset(
2185
- {
2186
- "var1": (["chain", "draw"], np.array([[1, 2, 3], [4, 5, 6]])),
2187
- "var2": (["chain", "draw"], np.array([[7, 8, 9], [10, 11, 12]])),
2188
- },
2189
- coords={"chain": [0, 1], "draw": [0, 1, 2]},
2190
- )
2191
- transform_dict = {
2192
- "var1": lambda x: x + 1,
2193
- "var2": lambda x: x * 2,
2194
- }
2195
-
2196
- axes = plot_forest(data, transform=transform_dict, show=False)
2197
- assert axes is not None