arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
@@ -1,1288 +0,0 @@
1
- # pylint: disable=redefined-outer-name,too-many-lines
2
- """Tests use the 'bokeh' backend."""
3
- from copy import deepcopy
4
-
5
- import numpy as np
6
- import pytest
7
- from pandas import DataFrame # pylint: disable=wrong-import-position
8
- from scipy.stats import norm # pylint: disable=wrong-import-position
9
-
10
- from ...data import from_dict, load_arviz_data # pylint: disable=wrong-import-position
11
- from ...labels import MapLabeller # pylint: disable=wrong-import-position
12
- from ...plots import ( # pylint: disable=wrong-import-position
13
- plot_autocorr,
14
- plot_bpv,
15
- plot_compare,
16
- plot_density,
17
- plot_dist,
18
- plot_dist_comparison,
19
- plot_dot,
20
- plot_ecdf,
21
- plot_elpd,
22
- plot_energy,
23
- plot_ess,
24
- plot_forest,
25
- plot_hdi,
26
- plot_kde,
27
- plot_khat,
28
- plot_lm,
29
- plot_loo_pit,
30
- plot_mcse,
31
- plot_pair,
32
- plot_parallel,
33
- plot_posterior,
34
- plot_ppc,
35
- plot_rank,
36
- plot_separation,
37
- plot_trace,
38
- plot_violin,
39
- )
40
- from ...rcparams import rc_context, rcParams # pylint: disable=wrong-import-position
41
- from ...stats import compare, hdi, loo, waic # pylint: disable=wrong-import-position
42
- from ..helpers import ( # pylint: disable=unused-import, wrong-import-position
43
- create_model,
44
- create_multidimensional_model,
45
- eight_schools_params,
46
- importorskip,
47
- models,
48
- multidim_models,
49
- )
50
-
51
- # Skip tests if bokeh not installed
52
- bkp = importorskip("bokeh.plotting") # pylint: disable=invalid-name
53
-
54
-
55
- rcParams["data.load"] = "eager"
56
-
57
-
58
- @pytest.fixture(scope="module")
59
- def data(eight_schools_params):
60
- data = eight_schools_params
61
- return data
62
-
63
-
64
- @pytest.fixture(scope="module")
65
- def df_trace():
66
- return DataFrame({"a": np.random.poisson(2.3, 100)})
67
-
68
-
69
- @pytest.fixture(scope="module")
70
- def discrete_model():
71
- """Simple fixture for random discrete model"""
72
- return {"x": np.random.randint(10, size=100), "y": np.random.randint(10, size=100)}
73
-
74
-
75
- @pytest.fixture(scope="module")
76
- def continuous_model():
77
- """Simple fixture for random continuous model"""
78
- return {"x": np.random.beta(2, 5, size=100), "y": np.random.beta(2, 5, size=100)}
79
-
80
-
81
- @pytest.mark.parametrize(
82
- "kwargs",
83
- [
84
- {"point_estimate": "mean"},
85
- {"point_estimate": "median"},
86
- {"hdi_prob": 0.94},
87
- {"hdi_prob": 1},
88
- {"outline": True},
89
- {"hdi_markers": ["v"]},
90
- {"shade": 1},
91
- ],
92
- )
93
- def test_plot_density_float(models, kwargs):
94
- obj = [getattr(models, model_fit) for model_fit in ["model_1", "model_2"]]
95
- axes = plot_density(obj, backend="bokeh", show=False, **kwargs)
96
- assert axes.shape[0] >= 6
97
- assert axes.shape[0] >= 3
98
-
99
-
100
- def test_plot_density_discrete(discrete_model):
101
- axes = plot_density(discrete_model, shade=0.9, backend="bokeh", show=False)
102
- assert axes.shape[0] == 1
103
-
104
-
105
- def test_plot_density_no_subset():
106
- """Test plot_density works when variables are not subset of one another (#1093)."""
107
- model_ab = from_dict(
108
- {
109
- "a": np.random.normal(size=200),
110
- "b": np.random.normal(size=200),
111
- }
112
- )
113
- model_bc = from_dict(
114
- {
115
- "b": np.random.normal(size=200),
116
- "c": np.random.normal(size=200),
117
- }
118
- )
119
- axes = plot_density([model_ab, model_bc], backend="bokeh", show=False)
120
- assert axes.size == 3
121
-
122
-
123
- def test_plot_density_one_var():
124
- """Test plot_density works when there is only one variable (#1401)."""
125
- model_ab = from_dict(
126
- {
127
- "a": np.random.normal(size=200),
128
- }
129
- )
130
- model_bc = from_dict(
131
- {
132
- "a": np.random.normal(size=200),
133
- }
134
- )
135
- axes = plot_density([model_ab, model_bc], backend="bokeh", show=False)
136
- assert axes.size == 1
137
-
138
-
139
- def test_plot_density_bad_kwargs(models):
140
- obj = [getattr(models, model_fit) for model_fit in ["model_1", "model_2"]]
141
- with pytest.raises(ValueError):
142
- plot_density(obj, point_estimate="bad_value", backend="bokeh", show=False)
143
-
144
- with pytest.raises(ValueError):
145
- plot_density(
146
- obj,
147
- data_labels=[f"bad_value_{i}" for i in range(len(obj) + 10)],
148
- backend="bokeh",
149
- show=False,
150
- )
151
-
152
- with pytest.raises(ValueError):
153
- plot_density(obj, hdi_prob=2, backend="bokeh", show=False)
154
-
155
-
156
- @pytest.mark.parametrize(
157
- "kwargs",
158
- [
159
- {},
160
- {"y_hat_line": True},
161
- {"expected_events": True},
162
- {"y_hat_line_kwargs": {"linestyle": "dotted"}},
163
- {"exp_events_kwargs": {"marker": "o"}},
164
- ],
165
- )
166
- def test_plot_separation(kwargs):
167
- idata = load_arviz_data("classification10d")
168
- ax = plot_separation(idata=idata, y="outcome", backend="bokeh", show=False, **kwargs)
169
- assert ax
170
-
171
-
172
- @pytest.mark.parametrize(
173
- "kwargs",
174
- [
175
- {},
176
- {"var_names": "mu"},
177
- {"var_names": ["mu", "tau"]},
178
- {"combined": True, "rug": True},
179
- {"compact": True, "legend": True},
180
- {"combined": True, "compact": True, "legend": True},
181
- {"divergences": "top"},
182
- {"divergences": False},
183
- {"kind": "rank_vlines"},
184
- {"kind": "rank_bars"},
185
- {"lines": [("mu", {}, [1, 2])]},
186
- {"lines": [("mu", {}, 8)]},
187
- ],
188
- )
189
- def test_plot_trace(models, kwargs):
190
- axes = plot_trace(models.model_1, backend="bokeh", show=False, **kwargs)
191
- assert axes.shape
192
-
193
-
194
- def test_plot_trace_discrete(discrete_model):
195
- axes = plot_trace(discrete_model, backend="bokeh", show=False)
196
- assert axes.shape
197
-
198
-
199
- def test_plot_trace_max_subplots_warning(models):
200
- with pytest.warns(UserWarning):
201
- with rc_context(rc={"plot.max_subplots": 2}):
202
- axes = plot_trace(models.model_1, backend="bokeh", show=False)
203
- assert axes.shape
204
-
205
-
206
- @pytest.mark.parametrize(
207
- "kwargs",
208
- [
209
- {"plot_kwargs": {"line_dash": "solid"}},
210
- {"contour": True, "fill_last": False},
211
- {
212
- "contour": True,
213
- "contourf_kwargs": {"cmap": "plasma"},
214
- "contour_kwargs": {"line_width": 1},
215
- },
216
- {"contour": False},
217
- {"contour": False, "pcolormesh_kwargs": {"cmap": "plasma"}},
218
- ],
219
- )
220
- def test_plot_kde(continuous_model, kwargs):
221
- axes = plot_kde(
222
- continuous_model["x"], continuous_model["y"], backend="bokeh", show=False, **kwargs
223
- )
224
- assert axes
225
-
226
-
227
- @pytest.mark.parametrize(
228
- "kwargs",
229
- [
230
- {"cumulative": True},
231
- {"cumulative": True, "plot_kwargs": {"line_dash": "dashed"}},
232
- {"rug": True},
233
- {"rug": True, "rug_kwargs": {"line_alpha": 0.2}, "rotated": True},
234
- ],
235
- )
236
- def test_plot_kde_cumulative(continuous_model, kwargs):
237
- axes = plot_kde(continuous_model["x"], backend="bokeh", show=False, **kwargs)
238
- assert axes
239
-
240
-
241
- @pytest.mark.parametrize(
242
- "kwargs",
243
- [
244
- {"kind": "hist"},
245
- {"kind": "kde"},
246
- {"is_circular": False},
247
- {"is_circular": False, "kind": "hist"},
248
- {"is_circular": True},
249
- {"is_circular": True, "kind": "hist"},
250
- {"is_circular": "radians"},
251
- {"is_circular": "radians", "kind": "hist"},
252
- {"is_circular": "degrees"},
253
- {"is_circular": "degrees", "kind": "hist"},
254
- ],
255
- )
256
- def test_plot_dist(continuous_model, kwargs):
257
- axes = plot_dist(continuous_model["x"], backend="bokeh", show=False, **kwargs)
258
- assert axes
259
-
260
-
261
- def test_plot_kde_1d(continuous_model):
262
- axes = plot_kde(continuous_model["y"], backend="bokeh", show=False)
263
- assert axes
264
-
265
-
266
- @pytest.mark.parametrize(
267
- "kwargs",
268
- [
269
- {"contour": True, "fill_last": False},
270
- {"contour": True, "contourf_kwargs": {"cmap": "plasma"}},
271
- {"contour": False},
272
- {"contour": False, "pcolormesh_kwargs": {"cmap": "plasma"}},
273
- {"contour": True, "contourf_kwargs": {"levels": 3}},
274
- {"contour": True, "contourf_kwargs": {"levels": [0.1, 0.2, 0.3]}},
275
- {"hdi_probs": [0.3, 0.9, 0.6]},
276
- {"hdi_probs": [0.3, 0.6, 0.9], "contourf_kwargs": {"cmap": "Blues"}},
277
- {"hdi_probs": [0.9, 0.6, 0.3], "contour_kwargs": {"alpha": 0}},
278
- ],
279
- )
280
- def test_plot_kde_2d(continuous_model, kwargs):
281
- axes = plot_kde(
282
- continuous_model["x"], continuous_model["y"], backend="bokeh", show=False, **kwargs
283
- )
284
- assert axes
285
-
286
-
287
- @pytest.mark.parametrize(
288
- "kwargs", [{"plot_kwargs": {"line_dash": "solid"}}, {"cumulative": True}, {"rug": True}]
289
- )
290
- def test_plot_kde_quantiles(continuous_model, kwargs):
291
- axes = plot_kde(
292
- continuous_model["x"], quantiles=[0.05, 0.5, 0.95], backend="bokeh", show=False, **kwargs
293
- )
294
- assert axes
295
-
296
-
297
- def test_plot_autocorr_short_chain():
298
- """Check that logic for small chain defaulting doesn't cause exception"""
299
- chain = np.arange(10)
300
- axes = plot_autocorr(chain, backend="bokeh", show=False)
301
- assert axes
302
-
303
-
304
- def test_plot_autocorr_uncombined(models):
305
- axes = plot_autocorr(models.model_1, combined=False, backend="bokeh", show=False)
306
- assert axes.shape[0] == 10
307
- max_subplots = (
308
- np.inf if rcParams["plot.max_subplots"] is None else rcParams["plot.max_subplots"]
309
- )
310
- assert len([ax for ax in axes.ravel() if ax is not None]) == min(72, max_subplots)
311
-
312
-
313
- def test_plot_autocorr_combined(models):
314
- axes = plot_autocorr(models.model_1, combined=True, backend="bokeh", show=False)
315
- assert axes.shape[0] == 6
316
- assert axes.shape[1] == 3
317
-
318
-
319
- @pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"]))
320
- def test_plot_autocorr_var_names(models, var_names):
321
- axes = plot_autocorr(
322
- models.model_1, var_names=var_names, combined=True, backend="bokeh", show=False
323
- )
324
- assert axes.shape
325
-
326
-
327
- @pytest.mark.parametrize(
328
- "kwargs", [{"insample_dev": False}, {"plot_standard_error": False}, {"plot_ic_diff": False}]
329
- )
330
- def test_plot_compare(models, kwargs):
331
- model_compare = compare({"Model 1": models.model_1, "Model 2": models.model_2})
332
-
333
- axes = plot_compare(model_compare, backend="bokeh", show=False, **kwargs)
334
- assert axes
335
-
336
-
337
- def test_plot_compare_no_ic(models):
338
- """Check exception is raised if model_compare doesn't contain a valid information criterion"""
339
- model_compare = compare({"Model 1": models.model_1, "Model 2": models.model_2})
340
-
341
- # Drop column needed for plotting
342
- model_compare = model_compare.drop("elpd_loo", axis=1)
343
- with pytest.raises(ValueError) as err:
344
- plot_compare(model_compare, backend="bokeh", show=False)
345
-
346
- assert "comp_df must contain one of the following" in str(err.value)
347
- assert "['elpd_loo', 'elpd_waic']" in str(err.value)
348
-
349
-
350
- def test_plot_ecdf_basic():
351
- data = np.random.randn(4, 1000)
352
- axes = plot_ecdf(data, backend="bokeh", show=False)
353
- assert axes is not None
354
-
355
-
356
- def test_plot_ecdf_values2():
357
- data = np.random.randn(4, 1000)
358
- data2 = np.random.randn(4, 500)
359
- axes = plot_ecdf(data, data2, backend="bokeh", show=False)
360
- assert axes is not None
361
-
362
-
363
- def test_plot_ecdf_cdf():
364
- data = np.random.randn(4, 1000)
365
- cdf = norm(0, 1).cdf
366
- axes = plot_ecdf(data, cdf=cdf, backend="bokeh", show=False)
367
- assert axes is not None
368
-
369
-
370
- @pytest.mark.parametrize(
371
- "kwargs", [{}, {"ic": "loo"}, {"xlabels": True, "scale": "log"}, {"threshold": 2}]
372
- )
373
- @pytest.mark.parametrize("add_model", [False, True])
374
- @pytest.mark.parametrize("use_elpddata", [False, True])
375
- def test_plot_elpd(models, add_model, use_elpddata, kwargs):
376
- model_dict = {"Model 1": models.model_1, "Model 2": models.model_2}
377
- if add_model:
378
- model_dict["Model 3"] = create_model(seed=12)
379
-
380
- if use_elpddata:
381
- ic = kwargs.get("ic", "waic")
382
- scale = kwargs.get("scale", "deviance")
383
- if ic == "waic":
384
- model_dict = {k: waic(v, scale=scale, pointwise=True) for k, v in model_dict.items()}
385
- else:
386
- model_dict = {k: loo(v, scale=scale, pointwise=True) for k, v in model_dict.items()}
387
-
388
- axes = plot_elpd(model_dict, backend="bokeh", show=False, **kwargs)
389
- assert np.any(axes)
390
- if add_model:
391
- assert axes.shape[0] == axes.shape[1]
392
- assert axes.shape[0] == len(model_dict) - 1
393
-
394
-
395
- @pytest.mark.parametrize("kwargs", [{}, {"ic": "loo"}, {"xlabels": True, "scale": "log"}])
396
- @pytest.mark.parametrize("add_model", [False, True])
397
- @pytest.mark.parametrize("use_elpddata", [False, True])
398
- def test_plot_elpd_multidim(multidim_models, add_model, use_elpddata, kwargs):
399
- model_dict = {"Model 1": multidim_models.model_1, "Model 2": multidim_models.model_2}
400
- if add_model:
401
- model_dict["Model 3"] = create_multidimensional_model(seed=12)
402
-
403
- if use_elpddata:
404
- ic = kwargs.get("ic", "waic")
405
- scale = kwargs.get("scale", "deviance")
406
- if ic == "waic":
407
- model_dict = {k: waic(v, scale=scale, pointwise=True) for k, v in model_dict.items()}
408
- else:
409
- model_dict = {k: loo(v, scale=scale, pointwise=True) for k, v in model_dict.items()}
410
-
411
- axes = plot_elpd(model_dict, backend="bokeh", show=False, **kwargs)
412
- assert np.any(axes)
413
- if add_model:
414
- assert axes.shape[0] == axes.shape[1]
415
- assert axes.shape[0] == len(model_dict) - 1
416
-
417
-
418
- @pytest.mark.parametrize("kind", ["kde", "hist"])
419
- def test_plot_energy(models, kind):
420
- assert plot_energy(models.model_1, kind=kind, backend="bokeh", show=False)
421
-
422
-
423
- def test_plot_energy_bad(models):
424
- with pytest.raises(ValueError):
425
- plot_energy(models.model_1, kind="bad_kind", backend="bokeh", show=False)
426
-
427
-
428
- @pytest.mark.parametrize(
429
- "kwargs",
430
- [
431
- {},
432
- {"var_names": ["theta"], "relative": True, "color": "r"},
433
- {"coords": {"school": slice(4)}, "n_points": 10},
434
- {"min_ess": 600, "hline_kwargs": {"line_color": "red"}},
435
- ],
436
- )
437
- @pytest.mark.parametrize("kind", ["local", "quantile", "evolution"])
438
- def test_plot_ess(models, kind, kwargs):
439
- """Test plot_ess arguments common to all kind of plots."""
440
- idata = models.model_1
441
- ax = plot_ess(idata, kind=kind, backend="bokeh", show=False, **kwargs)
442
- assert np.all(ax)
443
-
444
-
445
- @pytest.mark.parametrize(
446
- "kwargs",
447
- [
448
- {"rug": True},
449
- {"rug": True, "rug_kind": "max_depth", "rug_kwargs": {"color": "c"}},
450
- {"extra_methods": True},
451
- {"extra_methods": True, "extra_kwargs": {"ls": ":"}, "text_kwargs": {"x": 0, "ha": "left"}},
452
- {"extra_methods": True, "rug": True},
453
- ],
454
- )
455
- @pytest.mark.parametrize("kind", ["local", "quantile"])
456
- def test_plot_ess_local_quantile(models, kind, kwargs):
457
- """Test specific arguments in kinds local and quantile of plot_ess."""
458
- idata = models.model_1
459
- ax = plot_ess(idata, kind=kind, backend="bokeh", show=False, **kwargs)
460
- assert np.all(ax)
461
-
462
-
463
- def test_plot_ess_evolution(models):
464
- """Test specific arguments in evolution kind of plot_ess."""
465
- idata = models.model_1
466
- ax = plot_ess(
467
- idata,
468
- kind="evolution",
469
- extra_kwargs={"linestyle": "--"},
470
- color="b",
471
- backend="bokeh",
472
- show=False,
473
- )
474
- assert np.all(ax)
475
-
476
-
477
- def test_plot_ess_bad_kind(models):
478
- """Test error when plot_ess receives an invalid kind."""
479
- idata = models.model_1
480
- with pytest.raises(ValueError, match="Invalid kind"):
481
- plot_ess(idata, kind="bad kind", backend="bokeh", show=False)
482
-
483
-
484
- @pytest.mark.parametrize("dim", ["chain", "draw"])
485
- def test_plot_ess_bad_coords(models, dim):
486
- """Test error when chain or dim are used as coords to select a data subset."""
487
- idata = models.model_1
488
- with pytest.raises(ValueError, match="invalid coordinates"):
489
- plot_ess(idata, coords={dim: slice(3)}, backend="bokeh", show=False)
490
-
491
-
492
- def test_plot_ess_no_sample_stats(models):
493
- """Test error when rug=True but sample_stats group is not present."""
494
- idata = models.model_1
495
- with pytest.raises(ValueError, match="must contain sample_stats"):
496
- plot_ess(idata.posterior, rug=True, backend="bokeh", show=False)
497
-
498
-
499
- def test_plot_ess_no_divergences(models):
500
- """Test error when rug=True, but the variable defined by rug_kind is missing."""
501
- idata = deepcopy(models.model_1)
502
- idata.sample_stats = idata.sample_stats.rename({"diverging": "diverging_missing"})
503
- with pytest.raises(ValueError, match="not contain diverging"):
504
- plot_ess(idata, rug=True, backend="bokeh", show=False)
505
-
506
-
507
- @pytest.mark.parametrize("model_fits", [["model_1"], ["model_1", "model_2"]])
508
- @pytest.mark.parametrize(
509
- "args_expected",
510
- [
511
- ({}, 1),
512
- ({"var_names": "mu"}, 1),
513
- ({"var_names": "mu", "rope": (-1, 1)}, 1),
514
- ({"r_hat": True, "quartiles": False}, 2),
515
- ({"var_names": ["mu"], "colors": "black", "ess": True, "combined": True}, 2),
516
- (
517
- {
518
- "kind": "ridgeplot",
519
- "ridgeplot_truncate": False,
520
- "ridgeplot_quantiles": [0.25, 0.5, 0.75],
521
- },
522
- 1,
523
- ),
524
- ({"kind": "ridgeplot", "r_hat": True, "ess": True}, 3),
525
- ({"kind": "ridgeplot", "r_hat": True, "ess": True, "ridgeplot_alpha": 0}, 3),
526
- (
527
- {
528
- "var_names": ["mu", "tau"],
529
- "rope": {
530
- "mu": [{"rope": (-0.1, 0.1)}],
531
- "theta": [{"school": "Choate", "rope": (0.2, 0.5)}],
532
- },
533
- },
534
- 1,
535
- ),
536
- ],
537
- )
538
- def test_plot_forest(models, model_fits, args_expected):
539
- obj = [getattr(models, model_fit) for model_fit in model_fits]
540
- args, expected = args_expected
541
- axes = plot_forest(obj, backend="bokeh", show=False, **args)
542
- assert axes.shape == (1, expected)
543
-
544
-
545
- def test_plot_forest_rope_exception():
546
- with pytest.raises(ValueError) as err:
547
- plot_forest({"x": [1]}, rope="not_correct_format", backend="bokeh", show=False)
548
- assert "Argument `rope` must be None, a dictionary like" in str(err.value)
549
-
550
-
551
- def test_plot_forest_single_value():
552
- axes = plot_forest({"x": [1]}, backend="bokeh", show=False)
553
- assert axes.shape
554
-
555
-
556
- @pytest.mark.parametrize("model_fits", [["model_1"], ["model_1", "model_2"]])
557
- def test_plot_forest_bad(models, model_fits):
558
- obj = [getattr(models, model_fit) for model_fit in model_fits]
559
- with pytest.raises(TypeError):
560
- plot_forest(obj, kind="bad_kind", backend="bokeh", show=False)
561
-
562
- with pytest.raises(ValueError):
563
- plot_forest(
564
- obj,
565
- model_names=[f"model_name_{i}" for i in range(len(obj) + 10)],
566
- backend="bokeh",
567
- show=False,
568
- )
569
-
570
-
571
- @pytest.mark.parametrize(
572
- "kwargs",
573
- [
574
- {"color": "C5", "circular": True},
575
- {"hdi_data": True, "fill_kwargs": {"alpha": 0}},
576
- {"plot_kwargs": {"alpha": 0}},
577
- {"smooth_kwargs": {"window_length": 33, "polyorder": 5, "mode": "mirror"}},
578
- {"hdi_data": True, "smooth": False, "color": "xkcd:jade"},
579
- ],
580
- )
581
- def test_plot_hdi(models, data, kwargs):
582
- hdi_data = kwargs.pop("hdi_data", None)
583
- y_data = models.model_1.posterior["theta"]
584
- if hdi_data:
585
- hdi_data = hdi(y_data)
586
- axis = plot_hdi(data["y"], hdi_data=hdi_data, backend="bokeh", show=False, **kwargs)
587
- else:
588
- axis = plot_hdi(data["y"], y_data, backend="bokeh", show=False, **kwargs)
589
- assert axis
590
-
591
-
592
- @pytest.mark.parametrize(
593
- "kwargs",
594
- [
595
- {},
596
- {"xlabels": True},
597
- {"color": "obs_dim", "xlabels": True, "show_bins": True, "bin_format": "{0}"},
598
- {"color": "obs_dim", "legend": True, "hover_label": True},
599
- {"color": "blue", "coords": {"obs_dim": slice(2, 4)}},
600
- {"color": np.random.uniform(size=8), "show_bins": True},
601
- {
602
- "color": np.random.uniform(size=(8, 3)),
603
- "show_bins": True,
604
- "show_hlines": True,
605
- "threshold": 1,
606
- },
607
- ],
608
- )
609
- @pytest.mark.parametrize("input_type", ["elpd_data", "data_array", "array"])
610
- def test_plot_khat(models, input_type, kwargs):
611
- khats_data = loo(models.model_1, pointwise=True)
612
-
613
- if input_type == "data_array":
614
- khats_data = khats_data.pareto_k
615
- elif input_type == "array":
616
- khats_data = khats_data.pareto_k.values
617
- if "color" in kwargs and isinstance(kwargs["color"], str) and kwargs["color"] == "obs_dim":
618
- kwargs["color"] = None
619
-
620
- axes = plot_khat(khats_data, backend="bokeh", show=False, **kwargs)
621
- assert axes
622
-
623
-
624
- @pytest.mark.parametrize(
625
- "kwargs",
626
- [
627
- {},
628
- {"xlabels": True},
629
- {"color": "dim1", "xlabels": True, "show_bins": True, "bin_format": "{0}"},
630
- {"color": "dim2", "legend": True, "hover_label": True},
631
- {"color": "blue", "coords": {"dim2": slice(2, 4)}},
632
- {"color": np.random.uniform(size=35), "show_bins": True},
633
- {
634
- "color": np.random.uniform(size=(35, 3)),
635
- "show_bins": True,
636
- "show_hlines": True,
637
- "threshold": 1,
638
- },
639
- ],
640
- )
641
- @pytest.mark.parametrize("input_type", ["elpd_data", "data_array", "array"])
642
- def test_plot_khat_multidim(multidim_models, input_type, kwargs):
643
- khats_data = loo(multidim_models.model_1, pointwise=True)
644
-
645
- if input_type == "data_array":
646
- khats_data = khats_data.pareto_k
647
- elif input_type == "array":
648
- khats_data = khats_data.pareto_k.values
649
- if (
650
- "color" in kwargs
651
- and isinstance(kwargs["color"], str)
652
- and kwargs["color"] in ("dim1", "dim2")
653
- ):
654
- kwargs["color"] = None
655
-
656
- axes = plot_khat(khats_data, backend="bokeh", show=False, **kwargs)
657
- assert axes
658
-
659
-
660
- def test_plot_khat_threshold():
661
- khats = np.array([0, 0, 0.6, 0.6, 0.8, 0.9, 0.9, 2, 3, 4, 1.5])
662
- axes = plot_khat(khats, threshold=1, backend="bokeh", show=False)
663
- assert axes
664
-
665
-
666
- def test_plot_khat_bad_input(models):
667
- with pytest.raises(ValueError):
668
- plot_khat(models.model_1.sample_stats, backend="bokeh", show=False)
669
-
670
-
671
- @pytest.mark.parametrize(
672
- "kwargs",
673
- [
674
- {},
675
- {"n_unif": 50},
676
- {"use_hdi": True, "color": "gray"},
677
- {"use_hdi": True, "hdi_prob": 0.68},
678
- {"use_hdi": True, "hdi_kwargs": {"line_dash": "dashed", "alpha": 0}},
679
- {"ecdf": True},
680
- {"ecdf": True, "ecdf_fill": False, "plot_unif_kwargs": {"line_dash": "--"}},
681
- {"ecdf": True, "hdi_prob": 0.97, "fill_kwargs": {"color": "red"}},
682
- ],
683
- )
684
- def test_plot_loo_pit(models, kwargs):
685
- axes = plot_loo_pit(idata=models.model_1, y="y", backend="bokeh", show=False, **kwargs)
686
- assert axes
687
-
688
-
689
- def test_plot_loo_pit_incompatible_args(models):
690
- """Test error when both ecdf and use_hdi are True."""
691
- with pytest.raises(ValueError, match="incompatible"):
692
- plot_loo_pit(
693
- idata=models.model_1, y="y", ecdf=True, use_hdi=True, backend="bokeh", show=False
694
- )
695
-
696
-
697
- @pytest.mark.parametrize(
698
- "args",
699
- [
700
- {"y": "str"},
701
- {"y": "DataArray", "y_hat": "str"},
702
- {"y": "ndarray", "y_hat": "str"},
703
- {"y": "ndarray", "y_hat": "DataArray"},
704
- {"y": "ndarray", "y_hat": "ndarray"},
705
- ],
706
- )
707
- def test_plot_loo_pit_label(models, args):
708
- if args["y"] == "str":
709
- y = "y"
710
- elif args["y"] == "DataArray":
711
- y = models.model_1.observed_data.y
712
- elif args["y"] == "ndarray":
713
- y = models.model_1.observed_data.y.values
714
-
715
- if args.get("y_hat") == "str":
716
- y_hat = "y"
717
- elif args.get("y_hat") == "DataArray":
718
- y_hat = models.model_1.posterior_predictive.y.stack(__sample__=("chain", "draw"))
719
- elif args.get("y_hat") == "ndarray":
720
- y_hat = models.model_1.posterior_predictive.y.stack(__sample__=("chain", "draw")).values
721
- else:
722
- y_hat = None
723
-
724
- ax = plot_loo_pit(idata=models.model_1, y=y, y_hat=y_hat, backend="bokeh", show=False)
725
- assert ax
726
-
727
-
728
- @pytest.mark.parametrize(
729
- "kwargs",
730
- [
731
- {},
732
- {"var_names": ["theta"], "color": "r"},
733
- {"rug": True, "rug_kwargs": {"color": "r"}},
734
- {"errorbar": True, "rug": True, "rug_kind": "max_depth"},
735
- {"errorbar": True, "coords": {"school": slice(4)}, "n_points": 10},
736
- {"extra_methods": True, "rug": True},
737
- {"extra_methods": True, "extra_kwargs": {"ls": ":"}, "text_kwargs": {"x": 0, "ha": "left"}},
738
- ],
739
- )
740
- def test_plot_mcse(models, kwargs):
741
- idata = models.model_1
742
- ax = plot_mcse(idata, backend="bokeh", show=False, **kwargs)
743
- assert np.all(ax)
744
-
745
-
746
- @pytest.mark.parametrize("dim", ["chain", "draw"])
747
- def test_plot_mcse_bad_coords(models, dim):
748
- """Test error when chain or dim are used as coords to select a data subset."""
749
- idata = models.model_1
750
- with pytest.raises(ValueError, match="invalid coordinates"):
751
- plot_mcse(idata, coords={dim: slice(3)}, backend="bokeh", show=False)
752
-
753
-
754
- def test_plot_mcse_no_sample_stats(models):
755
- """Test error when rug=True but sample_stats group is not present."""
756
- idata = models.model_1
757
- with pytest.raises(ValueError, match="must contain sample_stats"):
758
- plot_mcse(idata.posterior, rug=True, backend="bokeh", show=False)
759
-
760
-
761
- def test_plot_mcse_no_divergences(models):
762
- """Test error when rug=True, but the variable defined by rug_kind is missing."""
763
- idata = deepcopy(models.model_1)
764
- idata.sample_stats = idata.sample_stats.rename({"diverging": "diverging_missing"})
765
- with pytest.raises(ValueError, match="not contain diverging"):
766
- plot_mcse(idata, rug=True, backend="bokeh", show=False)
767
-
768
-
769
- @pytest.mark.slow
770
- @pytest.mark.parametrize(
771
- "kwargs",
772
- [
773
- {"var_names": "theta", "divergences": True, "coords": {"school": [0, 1]}},
774
- {"divergences": True, "var_names": ["theta", "mu"]},
775
- {"kind": "kde", "var_names": ["theta"]},
776
- {"kind": "hexbin", "var_names": ["theta"]},
777
- {
778
- "kind": "hexbin",
779
- "var_names": ["theta"],
780
- "coords": {"school": [0, 1]},
781
- "textsize": 20,
782
- },
783
- {
784
- "point_estimate": "mean",
785
- "reference_values": {"mu": 0, "tau": 0},
786
- "reference_values_kwargs": {"line_color": "blue"},
787
- },
788
- {
789
- "var_names": ["mu", "tau"],
790
- "reference_values": {"mu": 0, "tau": 0},
791
- "labeller": MapLabeller({"mu": r"$\mu$", "theta": r"$\theta"}),
792
- },
793
- {
794
- "var_names": ["theta"],
795
- "reference_values": {"theta": [0.0] * 8},
796
- "labeller": MapLabeller({"theta": r"$\theta$"}),
797
- },
798
- {
799
- "var_names": ["theta"],
800
- "reference_values": {"theta": np.zeros(8)},
801
- "labeller": MapLabeller({"theta": r"$\theta$"}),
802
- },
803
- ],
804
- )
805
- def test_plot_pair(models, kwargs):
806
- ax = plot_pair(models.model_1, backend="bokeh", show=False, **kwargs)
807
- assert np.any(ax)
808
-
809
-
810
- @pytest.mark.parametrize("kwargs", [{"kind": "scatter"}, {"kind": "kde"}, {"kind": "hexbin"}])
811
- def test_plot_pair_2var(discrete_model, kwargs):
812
- ax = plot_pair(
813
- discrete_model, ax=np.atleast_2d(bkp.figure()), backend="bokeh", show=False, **kwargs
814
- )
815
- assert ax
816
-
817
-
818
- def test_plot_pair_bad(models):
819
- with pytest.raises(ValueError):
820
- plot_pair(models.model_1, kind="bad_kind", backend="bokeh", show=False)
821
- with pytest.raises(Exception):
822
- plot_pair(models.model_1, var_names=["mu"], backend="bokeh", show=False)
823
-
824
-
825
- @pytest.mark.parametrize("has_sample_stats", [True, False])
826
- def test_plot_pair_divergences_warning(has_sample_stats):
827
- data = load_arviz_data("centered_eight")
828
- if has_sample_stats:
829
- # sample_stats present, diverging field missing
830
- data.sample_stats = data.sample_stats.rename({"diverging": "diverging_missing"})
831
- else:
832
- # sample_stats missing
833
- data = data.posterior # pylint: disable=no-member
834
- with pytest.warns(UserWarning):
835
- ax = plot_pair(data, divergences=True, backend="bokeh", show=False)
836
- assert np.any(ax)
837
-
838
-
839
- def test_plot_parallel_raises_valueerror(df_trace): # pylint: disable=invalid-name
840
- with pytest.raises(ValueError):
841
- plot_parallel(df_trace, backend="bokeh", show=False)
842
-
843
-
844
- @pytest.mark.parametrize("norm_method", [None, "normal", "minmax", "rank"])
845
- def test_plot_parallel(models, norm_method):
846
- assert plot_parallel(
847
- models.model_1,
848
- var_names=["mu", "tau"],
849
- norm_method=norm_method,
850
- backend="bokeh",
851
- show=False,
852
- )
853
-
854
-
855
- @pytest.mark.parametrize("var_names", [None, "mu", ["mu", "tau"]])
856
- def test_plot_parallel_exception(models, var_names):
857
- """Ensure that correct exception is raised when one variable is passed."""
858
- with pytest.raises(ValueError):
859
- assert plot_parallel(
860
- models.model_1, var_names=var_names, norm_method="foo", backend="bokeh", show=False
861
- )
862
-
863
-
864
- @pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"]))
865
- @pytest.mark.parametrize("side", ["both", "left", "right"])
866
- @pytest.mark.parametrize("rug", [True])
867
- def test_plot_violin(models, var_names, side, rug):
868
- axes = plot_violin(
869
- models.model_1, var_names=var_names, side=side, rug=rug, backend="bokeh", show=False
870
- )
871
- assert axes.shape
872
-
873
-
874
- def test_plot_violin_ax(models):
875
- ax = bkp.figure()
876
- axes = plot_violin(models.model_1, var_names="mu", ax=ax, backend="bokeh", show=False)
877
- assert axes.shape
878
-
879
-
880
- def test_plot_violin_layout(models):
881
- axes = plot_violin(
882
- models.model_1, var_names=["mu", "tau"], sharey=False, backend="bokeh", show=False
883
- )
884
- assert axes.shape
885
-
886
-
887
- def test_plot_violin_discrete(discrete_model):
888
- axes = plot_violin(discrete_model, backend="bokeh", show=False)
889
- assert axes.shape
890
-
891
-
892
- @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
893
- @pytest.mark.parametrize("alpha", [None, 0.2, 1])
894
- @pytest.mark.parametrize("observed", [True, False])
895
- @pytest.mark.parametrize("observed_rug", [False, True])
896
- def test_plot_ppc(models, kind, alpha, observed, observed_rug):
897
- axes = plot_ppc(
898
- models.model_1,
899
- kind=kind,
900
- alpha=alpha,
901
- observed=observed,
902
- observed_rug=observed_rug,
903
- random_seed=3,
904
- backend="bokeh",
905
- show=False,
906
- )
907
- assert axes
908
-
909
-
910
- def test_plot_ppc_textsize(models):
911
- axes = plot_ppc(
912
- models.model_1,
913
- textsize=10,
914
- random_seed=3,
915
- backend="bokeh",
916
- show=False,
917
- )
918
- assert axes
919
-
920
-
921
- @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
922
- @pytest.mark.parametrize("jitter", [None, 0, 0.1, 1, 3])
923
- def test_plot_ppc_multichain(kind, jitter):
924
- data = from_dict(
925
- posterior_predictive={
926
- "x": np.random.randn(4, 100, 30),
927
- "y_hat": np.random.randn(4, 100, 3, 10),
928
- },
929
- observed_data={"x": np.random.randn(30), "y": np.random.randn(3, 10)},
930
- )
931
- axes = plot_ppc(
932
- data,
933
- kind=kind,
934
- data_pairs={"y": "y_hat"},
935
- jitter=jitter,
936
- random_seed=3,
937
- backend="bokeh",
938
- show=False,
939
- )
940
- assert np.all(axes)
941
-
942
-
943
- @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
944
- def test_plot_ppc_discrete(kind):
945
- data = from_dict(
946
- observed_data={"obs": np.random.randint(1, 100, 15)},
947
- posterior_predictive={"obs": np.random.randint(1, 300, (1, 20, 15))},
948
- )
949
-
950
- axes = plot_ppc(data, kind=kind, backend="bokeh", show=False)
951
- assert axes
952
-
953
-
954
- def test_plot_ppc_grid(models):
955
- axes = plot_ppc(models.model_1, kind="scatter", flatten=[], backend="bokeh", show=False)
956
- assert len(axes.ravel()) == 8
957
- axes = plot_ppc(
958
- models.model_1,
959
- kind="scatter",
960
- flatten=[],
961
- coords={"obs_dim": [1, 2, 3]},
962
- backend="bokeh",
963
- show=False,
964
- )
965
- assert len(axes.ravel()) == 3
966
- axes = plot_ppc(
967
- models.model_1,
968
- kind="scatter",
969
- flatten=["obs_dim"],
970
- coords={"obs_dim": [1, 2, 3]},
971
- backend="bokeh",
972
- show=False,
973
- )
974
- assert len(axes.ravel()) == 1
975
-
976
-
977
- @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
978
- def test_plot_ppc_bad(models, kind):
979
- data = from_dict(posterior={"mu": np.random.randn()})
980
- with pytest.raises(TypeError):
981
- plot_ppc(data, kind=kind, backend="bokeh", show=False)
982
- with pytest.raises(TypeError):
983
- plot_ppc(models.model_1, kind="bad_val", backend="bokeh", show=False)
984
- with pytest.raises(TypeError):
985
- plot_ppc(models.model_1, num_pp_samples="bad_val", backend="bokeh", show=False)
986
-
987
-
988
- @pytest.mark.parametrize("kind", ["kde", "cumulative", "scatter"])
989
- def test_plot_ppc_ax(models, kind):
990
- """Test ax argument of plot_ppc."""
991
- ax = bkp.figure()
992
- axes = plot_ppc(models.model_1, kind=kind, ax=ax, backend="bokeh", show=False)
993
- assert axes[0, 0] is ax
994
-
995
-
996
- @pytest.mark.parametrize(
997
- "kwargs",
998
- [
999
- {},
1000
- {"var_names": "mu"},
1001
- {"var_names": ("mu", "tau")},
1002
- {"rope": (-2, 2)},
1003
- {"rope": {"mu": [{"rope": (-2, 2)}], "theta": [{"school": "Choate", "rope": (2, 4)}]}},
1004
- {"point_estimate": "mode"},
1005
- {"point_estimate": "median"},
1006
- {"point_estimate": None},
1007
- {"hdi_prob": "hide", "legend_label": ""},
1008
- {"ref_val": 0},
1009
- {"ref_val": None},
1010
- {"ref_val": {"mu": [{"ref_val": 1}]}},
1011
- {"bins": None, "kind": "hist"},
1012
- {
1013
- "ref_val": {
1014
- "theta": [
1015
- # {"school": ["Choate", "Deerfield"], "ref_val": -1}, this is not working
1016
- {"school": "Lawrenceville", "ref_val": 3}
1017
- ]
1018
- }
1019
- },
1020
- ],
1021
- )
1022
- def test_plot_posterior(models, kwargs):
1023
- axes = plot_posterior(models.model_1, backend="bokeh", show=False, **kwargs)
1024
- assert axes.shape
1025
-
1026
-
1027
- @pytest.mark.parametrize("kwargs", [{}, {"point_estimate": "mode"}, {"bins": None, "kind": "hist"}])
1028
- def test_plot_posterior_discrete(discrete_model, kwargs):
1029
- axes = plot_posterior(discrete_model, backend="bokeh", show=False, **kwargs)
1030
- assert axes.shape
1031
-
1032
-
1033
- def test_plot_posterior_boolean():
1034
- data = np.random.choice(a=[False, True], size=(4, 100))
1035
- axes = plot_posterior(data, backend="bokeh", show=False)
1036
- assert axes.shape
1037
-
1038
-
1039
- def test_plot_posterior_bad_type():
1040
- with pytest.raises(TypeError):
1041
- plot_posterior(np.array(["a", "b", "c"]), backend="bokeh", show=False)
1042
-
1043
-
1044
- def test_plot_posterior_bad(models):
1045
- with pytest.raises(ValueError):
1046
- plot_posterior(models.model_1, backend="bokeh", show=False, rope="bad_value")
1047
- with pytest.raises(ValueError):
1048
- plot_posterior(models.model_1, ref_val="bad_value", backend="bokeh", show=False)
1049
- with pytest.raises(ValueError):
1050
- plot_posterior(models.model_1, point_estimate="bad_value", backend="bokeh", show=False)
1051
-
1052
-
1053
- @pytest.mark.parametrize("point_estimate", ("mode", "mean", "median"))
1054
- def test_plot_posterior_point_estimates(models, point_estimate):
1055
- axes = plot_posterior(
1056
- models.model_1,
1057
- var_names=("mu", "tau"),
1058
- point_estimate=point_estimate,
1059
- backend="bokeh",
1060
- show=False,
1061
- )
1062
- assert axes.shape == (1, 2)
1063
-
1064
-
1065
- def test_plot_posterior_skipna():
1066
- sample = np.linspace(0, 1)
1067
- sample[:10] = np.nan
1068
- plot_posterior({"a": sample}, backend="bokeh", show=False, skipna=True)
1069
- with pytest.raises(ValueError):
1070
- plot_posterior({"a": sample}, backend="bokeh", show=False, skipna=False)
1071
-
1072
-
1073
- @pytest.mark.parametrize(
1074
- "kwargs",
1075
- [
1076
- {},
1077
- {"var_names": "mu"},
1078
- {"var_names": ("mu", "tau"), "coords": {"school": [0, 1]}},
1079
- {"var_names": "mu", "ref_line": True},
1080
- {
1081
- "var_names": "mu",
1082
- "ref_line_kwargs": {"line_width": 2, "line_color": "red"},
1083
- "bar_kwargs": {"width": 50},
1084
- },
1085
- {"var_names": "mu", "ref_line": False},
1086
- {"var_names": "mu", "kind": "vlines"},
1087
- {
1088
- "var_names": "mu",
1089
- "kind": "vlines",
1090
- "vlines_kwargs": {"line_width": 0},
1091
- "marker_vlines_kwargs": {"radius": 20},
1092
- },
1093
- ],
1094
- )
1095
- def test_plot_rank(models, kwargs):
1096
- axes = plot_rank(models.model_1, backend="bokeh", show=False, **kwargs)
1097
- assert axes.shape
1098
-
1099
-
1100
- def test_plot_dist_comparison_warn(models):
1101
- with pytest.raises(NotImplementedError, match="The bokeh backend.+Use matplotlib backend."):
1102
- plot_dist_comparison(models.model_1, backend="bokeh")
1103
-
1104
-
1105
- @pytest.mark.parametrize(
1106
- "kwargs",
1107
- [
1108
- {},
1109
- {"reference": "analytical"},
1110
- {"kind": "p_value"},
1111
- {"kind": "t_stat", "t_stat": "std"},
1112
- {"kind": "t_stat", "t_stat": 0.5, "bpv": True},
1113
- ],
1114
- )
1115
- def test_plot_bpv(models, kwargs):
1116
- axes = plot_bpv(models.model_1, backend="bokeh", show=False, **kwargs)
1117
- assert axes.shape
1118
-
1119
-
1120
- def test_plot_bpv_discrete():
1121
- fake_obs = {"a": np.random.poisson(2.5, 100)}
1122
- fake_pp = {"a": np.random.poisson(2.5, (1, 10, 100))}
1123
- fake_model = from_dict(posterior_predictive=fake_pp, observed_data=fake_obs)
1124
- axes = plot_bpv(
1125
- fake_model,
1126
- backend="bokeh",
1127
- show=False,
1128
- )
1129
- assert axes.shape
1130
-
1131
-
1132
- @pytest.mark.parametrize(
1133
- "kwargs",
1134
- [
1135
- {},
1136
- {
1137
- "binwidth": 0.5,
1138
- "stackratio": 2,
1139
- "nquantiles": 20,
1140
- },
1141
- {"point_interval": True},
1142
- {
1143
- "point_interval": True,
1144
- "dotsize": 1.2,
1145
- "point_estimate": "median",
1146
- "plot_kwargs": {"color": "grey"},
1147
- },
1148
- {
1149
- "point_interval": True,
1150
- "plot_kwargs": {"color": "grey"},
1151
- "nquantiles": 100,
1152
- "hdi_prob": 0.95,
1153
- "intervalcolor": "green",
1154
- },
1155
- {
1156
- "point_interval": True,
1157
- "plot_kwargs": {"color": "grey"},
1158
- "quartiles": False,
1159
- "linewidth": 2,
1160
- },
1161
- ],
1162
- )
1163
- def test_plot_dot(continuous_model, kwargs):
1164
- data = continuous_model["x"]
1165
- ax = plot_dot(data, **kwargs, backend="bokeh", show=False)
1166
- assert ax
1167
-
1168
-
1169
- @pytest.mark.parametrize(
1170
- "kwargs",
1171
- [
1172
- {"rotated": True},
1173
- {
1174
- "point_interval": True,
1175
- "rotated": True,
1176
- "dotcolor": "grey",
1177
- "binwidth": 0.5,
1178
- },
1179
- {
1180
- "rotated": True,
1181
- "point_interval": True,
1182
- "plot_kwargs": {"color": "grey"},
1183
- "nquantiles": 100,
1184
- "dotsize": 0.8,
1185
- "hdi_prob": 0.95,
1186
- "intervalcolor": "green",
1187
- },
1188
- ],
1189
- )
1190
- def test_plot_dot_rotated(continuous_model, kwargs):
1191
- data = continuous_model["x"]
1192
- ax = plot_dot(data, **kwargs, backend="bokeh", show=False)
1193
- assert ax
1194
-
1195
-
1196
- @pytest.mark.parametrize(
1197
- "kwargs",
1198
- [
1199
- {},
1200
- {"y_hat": "bad_name"},
1201
- {"x": "x1"},
1202
- {"x": ("x1", "x2")},
1203
- {
1204
- "x": ("x1", "x2"),
1205
- "y_kwargs": {"fill_color": "blue"},
1206
- "y_hat_plot_kwargs": {"fill_color": "orange"},
1207
- "legend": True,
1208
- },
1209
- {"x": ("x1", "x2"), "y_model_plot_kwargs": {"line_color": "red"}},
1210
- {
1211
- "x": ("x1", "x2"),
1212
- "kind_pp": "hdi",
1213
- "kind_model": "hdi",
1214
- "y_model_fill_kwargs": {"color": "red"},
1215
- "y_hat_fill_kwargs": {"color": "cyan"},
1216
- },
1217
- ],
1218
- )
1219
- def test_plot_lm_1d(models, kwargs):
1220
- """Test functionality for 1D data."""
1221
- idata = models.model_1
1222
- if "constant_data" not in idata.groups():
1223
- y = idata.observed_data["y"]
1224
- x1data = y.coords[y.dims[0]]
1225
- idata.add_groups({"constant_data": {"_": x1data}})
1226
- idata.constant_data["x1"] = x1data
1227
- idata.constant_data["x2"] = x1data
1228
-
1229
- axes = plot_lm(
1230
- idata=idata, y="y", y_model="eta", backend="bokeh", xjitter=True, show=False, **kwargs
1231
- )
1232
- assert np.all(axes)
1233
-
1234
-
1235
- def test_plot_lm_multidim(multidim_models):
1236
- """Test functionality for multidimentional data."""
1237
- idata = multidim_models.model_1
1238
- axes = plot_lm(idata=idata, y="y", plot_dim="dim1", show=False, backend="bokeh")
1239
- assert np.any(axes)
1240
-
1241
-
1242
- def test_plot_lm_list():
1243
- """Test the plots when input data is list or ndarray."""
1244
- y = [1, 2, 3, 4, 5]
1245
- assert plot_lm(y=y, x=np.arange(len(y)), show=False, backend="bokeh")
1246
-
1247
-
1248
- def generate_lm_1d_data():
1249
- rng = np.random.default_rng()
1250
- return from_dict(
1251
- observed_data={"y": rng.normal(size=7)},
1252
- posterior_predictive={"y": rng.normal(size=(4, 1000, 7)) / 2},
1253
- posterior={"y_model": rng.normal(size=(4, 1000, 7))},
1254
- dims={"y": ["dim1"]},
1255
- coords={"dim1": range(7)},
1256
- )
1257
-
1258
-
1259
- def generate_lm_2d_data():
1260
- rng = np.random.default_rng()
1261
- return from_dict(
1262
- observed_data={"y": rng.normal(size=(5, 7))},
1263
- posterior_predictive={"y": rng.normal(size=(4, 1000, 5, 7)) / 2},
1264
- posterior={"y_model": rng.normal(size=(4, 1000, 5, 7))},
1265
- dims={"y": ["dim1", "dim2"]},
1266
- coords={"dim1": range(5), "dim2": range(7)},
1267
- )
1268
-
1269
-
1270
- @pytest.mark.parametrize("data", ("1d", "2d"))
1271
- @pytest.mark.parametrize("kind", ("lines", "hdi"))
1272
- @pytest.mark.parametrize("use_y_model", (True, False))
1273
- def test_plot_lm(data, kind, use_y_model):
1274
- if data == "1d":
1275
- idata = generate_lm_1d_data()
1276
- else:
1277
- idata = generate_lm_2d_data()
1278
-
1279
- kwargs = {"idata": idata, "y": "y", "kind_model": kind, "backend": "bokeh", "show": False}
1280
- if data == "2d":
1281
- kwargs["plot_dim"] = "dim1"
1282
- if use_y_model:
1283
- kwargs["y_model"] = "y_model"
1284
- if kind == "lines":
1285
- kwargs["num_samples"] = 50
1286
-
1287
- ax = plot_lm(**kwargs)
1288
- assert ax is not None