arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -357
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.1.dist-info/METADATA +0 -263
  184. arviz-0.23.1.dist-info/RECORD +0 -183
  185. arviz-0.23.1.dist-info/top_level.txt +0 -1
@@ -1,260 +0,0 @@
1
- # pylint: disable=no-member, invalid-name, redefined-outer-name
2
- import numpy as np
3
- import packaging
4
- import pytest
5
-
6
- from ...data.io_pyro import from_pyro # pylint: disable=wrong-import-position
7
- from ..helpers import ( # pylint: disable=unused-import, wrong-import-position
8
- chains,
9
- check_multiple_attrs,
10
- draws,
11
- eight_schools_params,
12
- importorskip,
13
- load_cached_models,
14
- )
15
-
16
- # Skip all tests if pyro or pytorch not installed
17
- torch = importorskip("torch")
18
- pyro = importorskip("pyro")
19
- Predictive = pyro.infer.Predictive
20
- dist = pyro.distributions
21
-
22
-
23
- class TestDataPyro:
24
- @pytest.fixture(scope="class")
25
- def data(self, eight_schools_params, draws, chains):
26
- class Data:
27
- obj = load_cached_models(eight_schools_params, draws, chains, "pyro")["pyro"]
28
-
29
- return Data
30
-
31
- @pytest.fixture(scope="class")
32
- def predictions_params(self):
33
- """Predictions data for eight schools."""
34
- return {
35
- "J": 8,
36
- "sigma": np.array([5.0, 7.0, 12.0, 4.0, 6.0, 10.0, 3.0, 9.0]),
37
- }
38
-
39
- @pytest.fixture(scope="class")
40
- def predictions_data(self, data, predictions_params):
41
- """Generate predictions for predictions_params"""
42
- posterior_samples = data.obj.get_samples()
43
- model = data.obj.kernel.model
44
- predictions = Predictive(model, posterior_samples)(
45
- predictions_params["J"], torch.from_numpy(predictions_params["sigma"]).float()
46
- )
47
- return predictions
48
-
49
- def get_inference_data(self, data, eight_schools_params, predictions_data):
50
- posterior_samples = data.obj.get_samples()
51
- model = data.obj.kernel.model
52
- posterior_predictive = Predictive(model, posterior_samples)(
53
- eight_schools_params["J"], torch.from_numpy(eight_schools_params["sigma"]).float()
54
- )
55
- prior = Predictive(model, num_samples=500)(
56
- eight_schools_params["J"], torch.from_numpy(eight_schools_params["sigma"]).float()
57
- )
58
- predictions = predictions_data
59
- return from_pyro(
60
- posterior=data.obj,
61
- prior=prior,
62
- posterior_predictive=posterior_predictive,
63
- predictions=predictions,
64
- coords={
65
- "school": np.arange(eight_schools_params["J"]),
66
- "school_pred": np.arange(eight_schools_params["J"]),
67
- },
68
- dims={"theta": ["school"], "eta": ["school"], "obs": ["school"]},
69
- pred_dims={"theta": ["school_pred"], "eta": ["school_pred"], "obs": ["school_pred"]},
70
- )
71
-
72
- def test_inference_data(self, data, eight_schools_params, predictions_data):
73
- inference_data = self.get_inference_data(data, eight_schools_params, predictions_data)
74
- test_dict = {
75
- "posterior": ["mu", "tau", "eta"],
76
- "sample_stats": ["diverging"],
77
- "posterior_predictive": ["obs"],
78
- "predictions": ["obs"],
79
- "prior": ["mu", "tau", "eta"],
80
- "prior_predictive": ["obs"],
81
- }
82
- fails = check_multiple_attrs(test_dict, inference_data)
83
- assert not fails
84
-
85
- # test dims
86
- dims = inference_data.posterior_predictive.sizes["school"]
87
- pred_dims = inference_data.predictions.sizes["school_pred"]
88
- assert dims == 8
89
- assert pred_dims == 8
90
-
91
- @pytest.mark.skipif(
92
- packaging.version.parse(pyro.__version__) < packaging.version.parse("1.0.0"),
93
- reason="requires pyro 1.0.0 or higher",
94
- )
95
- def test_inference_data_has_log_likelihood_and_observed_data(self, data):
96
- idata = from_pyro(data.obj)
97
- test_dict = {"log_likelihood": ["obs"], "observed_data": ["obs"]}
98
- fails = check_multiple_attrs(test_dict, idata)
99
- assert not fails
100
-
101
- def test_inference_data_no_posterior(
102
- self, data, eight_schools_params, predictions_data, predictions_params
103
- ):
104
- posterior_samples = data.obj.get_samples()
105
- model = data.obj.kernel.model
106
- posterior_predictive = Predictive(model, posterior_samples)(
107
- eight_schools_params["J"], torch.from_numpy(eight_schools_params["sigma"]).float()
108
- )
109
- prior = Predictive(model, num_samples=500)(
110
- eight_schools_params["J"], torch.from_numpy(eight_schools_params["sigma"]).float()
111
- )
112
- predictions = predictions_data
113
- constant_data = {"J": 8, "sigma": eight_schools_params["sigma"]}
114
- predictions_constant_data = predictions_params
115
- # only prior
116
- inference_data = from_pyro(prior=prior)
117
- test_dict = {"prior": ["mu", "tau", "eta"]}
118
- fails = check_multiple_attrs(test_dict, inference_data)
119
- assert not fails, f"only prior: {fails}"
120
- # only posterior_predictive
121
- inference_data = from_pyro(posterior_predictive=posterior_predictive)
122
- test_dict = {"posterior_predictive": ["obs"]}
123
- fails = check_multiple_attrs(test_dict, inference_data)
124
- assert not fails, f"only posterior_predictive: {fails}"
125
- # only predictions
126
- inference_data = from_pyro(predictions=predictions)
127
- test_dict = {"predictions": ["obs"]}
128
- fails = check_multiple_attrs(test_dict, inference_data)
129
- assert not fails, f"only predictions: {fails}"
130
- # only constant_data
131
- inference_data = from_pyro(constant_data=constant_data)
132
- test_dict = {"constant_data": ["J", "sigma"]}
133
- fails = check_multiple_attrs(test_dict, inference_data)
134
- assert not fails, f"only constant_data: {fails}"
135
- # only predictions_constant_data
136
- inference_data = from_pyro(predictions_constant_data=predictions_constant_data)
137
- test_dict = {"predictions_constant_data": ["J", "sigma"]}
138
- fails = check_multiple_attrs(test_dict, inference_data)
139
- assert not fails, f"only predictions_constant_data: {fails}"
140
- # prior and posterior_predictive
141
- idata = from_pyro(
142
- prior=prior,
143
- posterior_predictive=posterior_predictive,
144
- coords={"school": np.arange(eight_schools_params["J"])},
145
- dims={"theta": ["school"], "eta": ["school"]},
146
- )
147
- test_dict = {"posterior_predictive": ["obs"], "prior": ["mu", "tau", "eta", "obs"]}
148
- fails = check_multiple_attrs(test_dict, idata)
149
- assert not fails, f"prior and posterior_predictive: {fails}"
150
-
151
- def test_inference_data_only_posterior(self, data):
152
- idata = from_pyro(data.obj)
153
- test_dict = {"posterior": ["mu", "tau", "eta"], "sample_stats": ["diverging"]}
154
- fails = check_multiple_attrs(test_dict, idata)
155
- assert not fails
156
-
157
- @pytest.mark.skipif(
158
- packaging.version.parse(pyro.__version__) < packaging.version.parse("1.0.0"),
159
- reason="requires pyro 1.0.0 or higher",
160
- )
161
- def test_inference_data_only_posterior_has_log_likelihood(self, data):
162
- idata = from_pyro(data.obj)
163
- test_dict = {"log_likelihood": ["obs"]}
164
- fails = check_multiple_attrs(test_dict, idata)
165
- assert not fails
166
-
167
- def test_multiple_observed_rv(self):
168
- y1 = torch.randn(10)
169
- y2 = torch.randn(10)
170
-
171
- def model_example_multiple_obs(y1=None, y2=None):
172
- x = pyro.sample("x", dist.Normal(1, 3))
173
- pyro.sample("y1", dist.Normal(x, 1), obs=y1)
174
- pyro.sample("y2", dist.Normal(x, 1), obs=y2)
175
-
176
- nuts_kernel = pyro.infer.NUTS(model_example_multiple_obs)
177
- mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=10)
178
- mcmc.run(y1=y1, y2=y2)
179
- inference_data = from_pyro(mcmc)
180
- test_dict = {
181
- "posterior": ["x"],
182
- "sample_stats": ["diverging"],
183
- "log_likelihood": ["y1", "y2"],
184
- "observed_data": ["y1", "y2"],
185
- }
186
- fails = check_multiple_attrs(test_dict, inference_data)
187
- assert not fails
188
- assert not hasattr(inference_data.sample_stats, "log_likelihood")
189
-
190
- def test_inference_data_constant_data(self):
191
- x1 = 10
192
- x2 = 12
193
- y1 = torch.randn(10)
194
-
195
- def model_constant_data(x, y1=None):
196
- _x = pyro.sample("x", dist.Normal(1, 3))
197
- pyro.sample("y1", dist.Normal(x * _x, 1), obs=y1)
198
-
199
- nuts_kernel = pyro.infer.NUTS(model_constant_data)
200
- mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=10)
201
- mcmc.run(x=x1, y1=y1)
202
- posterior = mcmc.get_samples()
203
- posterior_predictive = Predictive(model_constant_data, posterior)(x1)
204
- predictions = Predictive(model_constant_data, posterior)(x2)
205
- inference_data = from_pyro(
206
- mcmc,
207
- posterior_predictive=posterior_predictive,
208
- predictions=predictions,
209
- constant_data={"x1": x1},
210
- predictions_constant_data={"x2": x2},
211
- )
212
- test_dict = {
213
- "posterior": ["x"],
214
- "posterior_predictive": ["y1"],
215
- "sample_stats": ["diverging"],
216
- "log_likelihood": ["y1"],
217
- "predictions": ["y1"],
218
- "observed_data": ["y1"],
219
- "constant_data": ["x1"],
220
- "predictions_constant_data": ["x2"],
221
- }
222
- fails = check_multiple_attrs(test_dict, inference_data)
223
- assert not fails
224
-
225
- def test_inference_data_num_chains(self, predictions_data, chains):
226
- predictions = predictions_data
227
- inference_data = from_pyro(predictions=predictions, num_chains=chains)
228
- nchains = inference_data.predictions.sizes["chain"]
229
- assert nchains == chains
230
-
231
- @pytest.mark.parametrize("log_likelihood", [True, False])
232
- def test_log_likelihood(self, log_likelihood):
233
- """Test behaviour when log likelihood cannot be retrieved.
234
-
235
- If log_likelihood=True there is a warning to say log_likelihood group is skipped,
236
- if log_likelihood=False there is no warning and log_likelihood is skipped.
237
- """
238
- x = torch.randn((10, 2))
239
- y = torch.randn(10)
240
-
241
- def model_constant_data(x, y=None):
242
- beta = pyro.sample("beta", dist.Normal(torch.ones(2), 3))
243
- pyro.sample("y", dist.Normal(x.matmul(beta), 1), obs=y)
244
-
245
- nuts_kernel = pyro.infer.NUTS(model_constant_data)
246
- mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=10)
247
- mcmc.run(x=x, y=y)
248
- if log_likelihood:
249
- with pytest.warns(UserWarning, match="Could not get vectorized trace"):
250
- inference_data = from_pyro(mcmc, log_likelihood=log_likelihood)
251
- else:
252
- inference_data = from_pyro(mcmc, log_likelihood=log_likelihood)
253
- test_dict = {
254
- "posterior": ["beta"],
255
- "sample_stats": ["diverging"],
256
- "~log_likelihood": [""],
257
- "observed_data": ["y"],
258
- }
259
- fails = check_multiple_attrs(test_dict, inference_data)
260
- assert not fails
@@ -1,307 +0,0 @@
1
- # pylint: disable=no-member, invalid-name, redefined-outer-name, too-many-function-args
2
- import importlib
3
- from collections import OrderedDict
4
- import os
5
-
6
- import numpy as np
7
- import pytest
8
-
9
- from ... import from_pystan
10
-
11
- from ...data.io_pystan import get_draws, get_draws_stan3 # pylint: disable=unused-import
12
- from ..helpers import ( # pylint: disable=unused-import
13
- chains,
14
- check_multiple_attrs,
15
- draws,
16
- eight_schools_params,
17
- importorskip,
18
- load_cached_models,
19
- pystan_version,
20
- )
21
-
22
- # Check if either pystan or pystan3 is installed
23
- pystan_installed = (importlib.util.find_spec("pystan") is not None) or (
24
- importlib.util.find_spec("stan") is not None
25
- )
26
-
27
-
28
- @pytest.mark.skipif(
29
- not (pystan_installed or "ARVIZ_REQUIRE_ALL_DEPS" in os.environ),
30
- reason="test requires pystan/pystan3 which is not installed",
31
- )
32
- class TestDataPyStan:
33
- @pytest.fixture(scope="class")
34
- def data(self, eight_schools_params, draws, chains):
35
- class Data:
36
- model, obj = load_cached_models(eight_schools_params, draws, chains, "pystan")["pystan"]
37
-
38
- return Data
39
-
40
- def get_inference_data(self, data, eight_schools_params):
41
- """vars as str."""
42
- return from_pystan(
43
- posterior=data.obj,
44
- posterior_predictive="y_hat",
45
- predictions="y_hat", # wrong, but fine for testing
46
- prior=data.obj,
47
- prior_predictive="y_hat",
48
- observed_data="y",
49
- constant_data="sigma",
50
- predictions_constant_data="sigma", # wrong, but fine for testing
51
- log_likelihood={"y": "log_lik"},
52
- coords={"school": np.arange(eight_schools_params["J"])},
53
- dims={
54
- "theta": ["school"],
55
- "y": ["school"],
56
- "sigma": ["school"],
57
- "y_hat": ["school"],
58
- "eta": ["school"],
59
- },
60
- posterior_model=data.model,
61
- prior_model=data.model,
62
- )
63
-
64
- def get_inference_data2(self, data, eight_schools_params):
65
- """vars as lists."""
66
- return from_pystan(
67
- posterior=data.obj,
68
- posterior_predictive=["y_hat"],
69
- predictions=["y_hat"], # wrong, but fine for testing
70
- prior=data.obj,
71
- prior_predictive=["y_hat"],
72
- observed_data=["y"],
73
- log_likelihood="log_lik",
74
- coords={
75
- "school": np.arange(eight_schools_params["J"]),
76
- "log_likelihood_dim": np.arange(eight_schools_params["J"]),
77
- },
78
- dims={
79
- "theta": ["school"],
80
- "y": ["school"],
81
- "y_hat": ["school"],
82
- "eta": ["school"],
83
- "log_lik": ["log_likelihood_dim"],
84
- },
85
- posterior_model=data.model,
86
- prior_model=data.model,
87
- )
88
-
89
- def get_inference_data3(self, data, eight_schools_params):
90
- """multiple vars as lists."""
91
- return from_pystan(
92
- posterior=data.obj,
93
- posterior_predictive=["y_hat", "log_lik"], # wrong, but fine for testing
94
- predictions=["y_hat", "log_lik"], # wrong, but fine for testing
95
- prior=data.obj,
96
- prior_predictive=["y_hat", "log_lik"], # wrong, but fine for testing
97
- constant_data=["sigma", "y"], # wrong, but fine for testing
98
- predictions_constant_data=["sigma", "y"], # wrong, but fine for testing
99
- coords={"school": np.arange(eight_schools_params["J"])},
100
- dims={
101
- "theta": ["school"],
102
- "y": ["school"],
103
- "sigma": ["school"],
104
- "y_hat": ["school"],
105
- "eta": ["school"],
106
- },
107
- posterior_model=data.model,
108
- prior_model=data.model,
109
- )
110
-
111
- def get_inference_data4(self, data):
112
- """minimal input."""
113
- return from_pystan(
114
- posterior=data.obj,
115
- posterior_predictive=None,
116
- prior=data.obj,
117
- prior_predictive=None,
118
- coords=None,
119
- dims=None,
120
- posterior_model=data.model,
121
- log_likelihood=[],
122
- prior_model=data.model,
123
- save_warmup=True,
124
- )
125
-
126
- def get_inference_data5(self, data):
127
- """minimal input."""
128
- return from_pystan(
129
- posterior=data.obj,
130
- posterior_predictive=None,
131
- prior=data.obj,
132
- prior_predictive=None,
133
- coords=None,
134
- dims=None,
135
- posterior_model=data.model,
136
- log_likelihood=False,
137
- prior_model=data.model,
138
- save_warmup=True,
139
- dtypes={"eta": int},
140
- )
141
-
142
- def test_sampler_stats(self, data, eight_schools_params):
143
- inference_data = self.get_inference_data(data, eight_schools_params)
144
- test_dict = {"sample_stats": ["diverging"]}
145
- fails = check_multiple_attrs(test_dict, inference_data)
146
- assert not fails
147
-
148
- def test_inference_data(self, data, eight_schools_params):
149
- inference_data1 = self.get_inference_data(data, eight_schools_params)
150
- inference_data2 = self.get_inference_data2(data, eight_schools_params)
151
- inference_data3 = self.get_inference_data3(data, eight_schools_params)
152
- inference_data4 = self.get_inference_data4(data)
153
- inference_data5 = self.get_inference_data5(data)
154
- # inference_data 1
155
- test_dict = {
156
- "posterior": ["theta", "~log_lik"],
157
- "posterior_predictive": ["y_hat"],
158
- "predictions": ["y_hat"],
159
- "observed_data": ["y"],
160
- "constant_data": ["sigma"],
161
- "predictions_constant_data": ["sigma"],
162
- "sample_stats": ["diverging", "lp"],
163
- "log_likelihood": ["y", "~log_lik"],
164
- "prior": ["theta"],
165
- }
166
- fails = check_multiple_attrs(test_dict, inference_data1)
167
- assert not fails
168
- # inference_data 2
169
- test_dict = {
170
- "posterior_predictive": ["y_hat"],
171
- "predictions": ["y_hat"],
172
- "observed_data": ["y"],
173
- "sample_stats_prior": ["diverging"],
174
- "sample_stats": ["diverging", "lp"],
175
- "log_likelihood": ["log_lik"],
176
- "prior_predictive": ["y_hat"],
177
- }
178
- fails = check_multiple_attrs(test_dict, inference_data2)
179
- assert not fails
180
- assert any(
181
- item in inference_data2.posterior.attrs for item in ["stan_code", "program_code"]
182
- )
183
- assert any(
184
- item in inference_data2.sample_stats.attrs for item in ["stan_code", "program_code"]
185
- )
186
- # inference_data 3
187
- test_dict = {
188
- "posterior_predictive": ["y_hat", "log_lik"],
189
- "predictions": ["y_hat", "log_lik"],
190
- "constant_data": ["sigma", "y"],
191
- "predictions_constant_data": ["sigma", "y"],
192
- "sample_stats_prior": ["diverging"],
193
- "sample_stats": ["diverging", "lp"],
194
- "log_likelihood": ["log_lik"],
195
- "prior_predictive": ["y_hat", "log_lik"],
196
- }
197
- fails = check_multiple_attrs(test_dict, inference_data3)
198
- assert not fails
199
- # inference_data 4
200
- test_dict = {
201
- "posterior": ["theta"],
202
- "prior": ["theta"],
203
- "sample_stats": ["diverging", "lp"],
204
- "~log_likelihood": [""],
205
- "warmup_posterior": ["theta"],
206
- "warmup_sample_stats": ["diverging", "lp"],
207
- }
208
- fails = check_multiple_attrs(test_dict, inference_data4)
209
- assert not fails
210
- # inference_data 5
211
- test_dict = {
212
- "posterior": ["theta"],
213
- "prior": ["theta"],
214
- "sample_stats": ["diverging", "lp"],
215
- "~log_likelihood": [""],
216
- "warmup_posterior": ["theta"],
217
- "warmup_sample_stats": ["diverging", "lp"],
218
- }
219
- fails = check_multiple_attrs(test_dict, inference_data5)
220
- assert not fails
221
- assert inference_data5.posterior.eta.dtype.kind == "i"
222
-
223
- def test_invalid_fit(self, data):
224
- if pystan_version() == 2:
225
- model = data.model
226
- model_data = {
227
- "J": 8,
228
- "y": np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
229
- "sigma": np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
230
- }
231
- fit_test_grad = model.sampling(
232
- data=model_data, test_grad=True, check_hmc_diagnostics=False
233
- )
234
- with pytest.raises(AttributeError):
235
- _ = from_pystan(posterior=fit_test_grad)
236
- fit = model.sampling(data=model_data, iter=100, chains=1, check_hmc_diagnostics=False)
237
- del fit.sim["samples"]
238
- with pytest.raises(AttributeError):
239
- _ = from_pystan(posterior=fit)
240
-
241
- def test_empty_parameter(self):
242
- model_code = """
243
- parameters {
244
- real y;
245
- vector[3] x;
246
- vector[0] a;
247
- vector[2] z;
248
- }
249
- model {
250
- y ~ normal(0,1);
251
- }
252
- """
253
- if pystan_version() == 2:
254
- from pystan import StanModel # pylint: disable=import-error
255
-
256
- model = StanModel(model_code=model_code)
257
- fit = model.sampling(iter=500, chains=2, check_hmc_diagnostics=False)
258
- else:
259
- import stan # pylint: disable=import-error
260
-
261
- model = stan.build(model_code)
262
- fit = model.sample(num_samples=500, num_chains=2)
263
-
264
- posterior = from_pystan(posterior=fit)
265
- test_dict = {"posterior": ["y", "x", "z", "~a"], "sample_stats": ["diverging"]}
266
- fails = check_multiple_attrs(test_dict, posterior)
267
- assert not fails
268
-
269
- def test_get_draws(self, data):
270
- fit = data.obj
271
- if pystan_version() == 2:
272
- draws, _ = get_draws(fit, variables=["theta", "theta"])
273
- else:
274
- draws, _ = get_draws_stan3(fit, variables=["theta", "theta"])
275
- assert draws.get("theta") is not None
276
-
277
- @pytest.mark.skipif(pystan_version() != 2, reason="PyStan 2.x required")
278
- def test_index_order(self, data, eight_schools_params):
279
- """Test 0-indexed data."""
280
- # Skip test if pystan not installed
281
- pystan = importorskip("pystan") # pylint: disable=import-error
282
-
283
- fit = data.model.sampling(data=eight_schools_params)
284
- if pystan.__version__ >= "2.18":
285
- # make 1-indexed to 0-indexed
286
- for holder in fit.sim["samples"]:
287
- new_chains = OrderedDict()
288
- for i, (key, values) in enumerate(holder.chains.items()):
289
- if "[" in key:
290
- name, *shape = key.replace("]", "").split("[")
291
- shape = [str(int(item) - 1) for items in shape for item in items.split(",")]
292
- key = f"{name}[{','.join(shape)}]"
293
- new_chains[key] = np.full_like(values, fill_value=float(i))
294
- setattr(holder, "chains", new_chains)
295
- fit.sim["fnames_oi"] = list(fit.sim["samples"][0].chains.keys())
296
- idata = from_pystan(posterior=fit)
297
- assert idata is not None
298
- for j, fpar in enumerate(fit.sim["fnames_oi"]):
299
- par, *shape = fpar.replace("]", "").split("[")
300
- if par in {"lp__", "log_lik"}:
301
- continue
302
- assert hasattr(idata.posterior, par), (par, list(idata.posterior.data_vars))
303
- if shape:
304
- shape = [slice(None), slice(None)] + list(map(int, shape))
305
- assert idata.posterior[par][tuple(shape)].values.mean() == float(j)
306
- else:
307
- assert idata.posterior[par].values.mean() == float(j)