arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
@@ -1,434 +0,0 @@
1
- # pylint: disable=no-member, invalid-name, redefined-outer-name, too-many-public-methods
2
- from collections import namedtuple
3
- import numpy as np
4
- import pytest
5
-
6
- from ...data.io_numpyro import from_numpyro # pylint: disable=wrong-import-position
7
- from ..helpers import ( # pylint: disable=unused-import, wrong-import-position
8
- chains,
9
- check_multiple_attrs,
10
- draws,
11
- eight_schools_params,
12
- importorskip,
13
- load_cached_models,
14
- )
15
-
16
- # Skip all tests if jax or numpyro not installed
17
- jax = importorskip("jax")
18
- PRNGKey = jax.random.PRNGKey
19
- numpyro = importorskip("numpyro")
20
- Predictive = numpyro.infer.Predictive
21
-
22
-
23
- class TestDataNumPyro:
24
- @pytest.fixture(scope="class")
25
- def data(self, eight_schools_params, draws, chains):
26
- class Data:
27
- obj = load_cached_models(eight_schools_params, draws, chains, "numpyro")["numpyro"]
28
-
29
- return Data
30
-
31
- @pytest.fixture(scope="class")
32
- def predictions_params(self):
33
- """Predictions data for eight schools."""
34
- return {
35
- "J": 8,
36
- "sigma": np.array([5.0, 7.0, 12.0, 4.0, 6.0, 10.0, 3.0, 9.0]),
37
- }
38
-
39
- @pytest.fixture(scope="class")
40
- def predictions_data(self, data, predictions_params):
41
- """Generate predictions for predictions_params"""
42
- posterior_samples = data.obj.get_samples()
43
- model = data.obj.sampler.model
44
- predictions = Predictive(model, posterior_samples)(
45
- PRNGKey(2), predictions_params["J"], predictions_params["sigma"]
46
- )
47
- return predictions
48
-
49
- def get_inference_data(
50
- self, data, eight_schools_params, predictions_data, predictions_params, infer_dims=False
51
- ):
52
- posterior_samples = data.obj.get_samples()
53
- model = data.obj.sampler.model
54
- posterior_predictive = Predictive(model, posterior_samples)(
55
- PRNGKey(1), eight_schools_params["J"], eight_schools_params["sigma"]
56
- )
57
- prior = Predictive(model, num_samples=500)(
58
- PRNGKey(2), eight_schools_params["J"], eight_schools_params["sigma"]
59
- )
60
- dims = {"theta": ["school"], "eta": ["school"], "obs": ["school"]}
61
- pred_dims = {"theta": ["school_pred"], "eta": ["school_pred"], "obs": ["school_pred"]}
62
- if infer_dims:
63
- dims = None
64
- pred_dims = None
65
-
66
- predictions = predictions_data
67
- return from_numpyro(
68
- posterior=data.obj,
69
- prior=prior,
70
- posterior_predictive=posterior_predictive,
71
- predictions=predictions,
72
- coords={
73
- "school": np.arange(eight_schools_params["J"]),
74
- "school_pred": np.arange(predictions_params["J"]),
75
- },
76
- dims=dims,
77
- pred_dims=pred_dims,
78
- )
79
-
80
- def test_inference_data_namedtuple(self, data):
81
- samples = data.obj.get_samples()
82
- Samples = namedtuple("Samples", samples)
83
- data_namedtuple = Samples(**samples)
84
- _old_fn = data.obj.get_samples
85
- data.obj.get_samples = lambda *args, **kwargs: data_namedtuple
86
- inference_data = from_numpyro(
87
- posterior=data.obj,
88
- dims={}, # This mock test needs to turn off autodims like so or mock group_by_chain
89
- )
90
- assert isinstance(data.obj.get_samples(), Samples)
91
- data.obj.get_samples = _old_fn
92
- for key in samples:
93
- assert key in inference_data.posterior
94
-
95
- def test_inference_data(self, data, eight_schools_params, predictions_data, predictions_params):
96
- inference_data = self.get_inference_data(
97
- data, eight_schools_params, predictions_data, predictions_params
98
- )
99
- test_dict = {
100
- "posterior": ["mu", "tau", "eta"],
101
- "sample_stats": ["diverging"],
102
- "log_likelihood": ["obs"],
103
- "posterior_predictive": ["obs"],
104
- "predictions": ["obs"],
105
- "prior": ["mu", "tau", "eta"],
106
- "prior_predictive": ["obs"],
107
- "observed_data": ["obs"],
108
- }
109
- fails = check_multiple_attrs(test_dict, inference_data)
110
- assert not fails
111
-
112
- # test dims
113
- dims = inference_data.posterior_predictive.sizes["school"]
114
- pred_dims = inference_data.predictions.sizes["school_pred"]
115
- assert dims == 8
116
- assert pred_dims == 8
117
-
118
- def test_inference_data_no_posterior(
119
- self, data, eight_schools_params, predictions_data, predictions_params
120
- ):
121
- posterior_samples = data.obj.get_samples()
122
- model = data.obj.sampler.model
123
- posterior_predictive = Predictive(model, posterior_samples)(
124
- PRNGKey(1), eight_schools_params["J"], eight_schools_params["sigma"]
125
- )
126
- prior = Predictive(model, num_samples=500)(
127
- PRNGKey(2), eight_schools_params["J"], eight_schools_params["sigma"]
128
- )
129
- predictions = predictions_data
130
- constant_data = {"J": 8, "sigma": eight_schools_params["sigma"]}
131
- predictions_constant_data = predictions_params
132
- # only prior
133
- inference_data = from_numpyro(prior=prior)
134
- test_dict = {"prior": ["mu", "tau", "eta"]}
135
- fails = check_multiple_attrs(test_dict, inference_data)
136
- assert not fails, f"only prior: {fails}"
137
- # only posterior_predictive
138
- inference_data = from_numpyro(posterior_predictive=posterior_predictive)
139
- test_dict = {"posterior_predictive": ["obs"]}
140
- fails = check_multiple_attrs(test_dict, inference_data)
141
- assert not fails, f"only posterior_predictive: {fails}"
142
- # only predictions
143
- inference_data = from_numpyro(predictions=predictions)
144
- test_dict = {"predictions": ["obs"]}
145
- fails = check_multiple_attrs(test_dict, inference_data)
146
- assert not fails, f"only predictions: {fails}"
147
- # only constant_data
148
- inference_data = from_numpyro(constant_data=constant_data)
149
- test_dict = {"constant_data": ["J", "sigma"]}
150
- fails = check_multiple_attrs(test_dict, inference_data)
151
- assert not fails, f"only constant_data: {fails}"
152
- # only predictions_constant_data
153
- inference_data = from_numpyro(predictions_constant_data=predictions_constant_data)
154
- test_dict = {"predictions_constant_data": ["J", "sigma"]}
155
- fails = check_multiple_attrs(test_dict, inference_data)
156
- assert not fails, f"only predictions_constant_data: {fails}"
157
- # prior and posterior_predictive
158
- idata = from_numpyro(
159
- prior=prior,
160
- posterior_predictive=posterior_predictive,
161
- coords={"school": np.arange(eight_schools_params["J"])},
162
- dims={"theta": ["school"], "eta": ["school"]},
163
- )
164
- test_dict = {"posterior_predictive": ["obs"], "prior": ["mu", "tau", "eta", "obs"]}
165
- fails = check_multiple_attrs(test_dict, idata)
166
- assert not fails, f"prior and posterior_predictive: {fails}"
167
-
168
- def test_inference_data_only_posterior(self, data):
169
- idata = from_numpyro(data.obj)
170
- test_dict = {
171
- "posterior": ["mu", "tau", "eta"],
172
- "sample_stats": ["diverging"],
173
- "log_likelihood": ["obs"],
174
- }
175
- fails = check_multiple_attrs(test_dict, idata)
176
- assert not fails
177
-
178
- def test_multiple_observed_rv(self):
179
- import numpyro
180
- import numpyro.distributions as dist
181
- from numpyro.infer import MCMC, NUTS
182
-
183
- y1 = np.random.randn(10)
184
- y2 = np.random.randn(100)
185
-
186
- def model_example_multiple_obs(y1=None, y2=None):
187
- x = numpyro.sample("x", dist.Normal(1, 3))
188
- numpyro.sample("y1", dist.Normal(x, 1), obs=y1)
189
- numpyro.sample("y2", dist.Normal(x, 1), obs=y2)
190
-
191
- nuts_kernel = NUTS(model_example_multiple_obs)
192
- mcmc = MCMC(nuts_kernel, num_samples=10, num_warmup=2)
193
- mcmc.run(PRNGKey(0), y1=y1, y2=y2)
194
- inference_data = from_numpyro(mcmc)
195
- test_dict = {
196
- "posterior": ["x"],
197
- "sample_stats": ["diverging"],
198
- "log_likelihood": ["y1", "y2"],
199
- "observed_data": ["y1", "y2"],
200
- }
201
- fails = check_multiple_attrs(test_dict, inference_data)
202
- # from ..stats import waic
203
- # waic_results = waic(inference_data)
204
- # print(waic_results)
205
- # print(waic_results.keys())
206
- # print(waic_results.waic, waic_results.waic_se)
207
- assert not fails
208
- assert not hasattr(inference_data.sample_stats, "log_likelihood")
209
-
210
- def test_inference_data_constant_data(self):
211
- import numpyro
212
- import numpyro.distributions as dist
213
- from numpyro.infer import MCMC, NUTS
214
-
215
- x1 = 10
216
- x2 = 12
217
- y1 = np.random.randn(10)
218
-
219
- def model_constant_data(x, y1=None):
220
- _x = numpyro.sample("x", dist.Normal(1, 3))
221
- numpyro.sample("y1", dist.Normal(x * _x, 1), obs=y1)
222
-
223
- nuts_kernel = NUTS(model_constant_data)
224
- mcmc = MCMC(nuts_kernel, num_samples=10, num_warmup=2)
225
- mcmc.run(PRNGKey(0), x=x1, y1=y1)
226
- posterior = mcmc.get_samples()
227
- posterior_predictive = Predictive(model_constant_data, posterior)(PRNGKey(1), x1)
228
- predictions = Predictive(model_constant_data, posterior)(PRNGKey(2), x2)
229
- inference_data = from_numpyro(
230
- mcmc,
231
- posterior_predictive=posterior_predictive,
232
- predictions=predictions,
233
- constant_data={"x1": x1},
234
- predictions_constant_data={"x2": x2},
235
- )
236
- test_dict = {
237
- "posterior": ["x"],
238
- "posterior_predictive": ["y1"],
239
- "sample_stats": ["diverging"],
240
- "log_likelihood": ["y1"],
241
- "predictions": ["y1"],
242
- "observed_data": ["y1"],
243
- "constant_data": ["x1"],
244
- "predictions_constant_data": ["x2"],
245
- }
246
- fails = check_multiple_attrs(test_dict, inference_data)
247
- assert not fails
248
-
249
- def test_inference_data_num_chains(self, predictions_data, chains):
250
- predictions = predictions_data
251
- inference_data = from_numpyro(predictions=predictions, num_chains=chains)
252
- nchains = inference_data.predictions.sizes["chain"]
253
- assert nchains == chains
254
-
255
- @pytest.mark.parametrize("nchains", [1, 2])
256
- @pytest.mark.parametrize("thin", [1, 2, 3, 5, 10])
257
- def test_mcmc_with_thinning(self, nchains, thin):
258
- import numpyro
259
- import numpyro.distributions as dist
260
- from numpyro.infer import MCMC, NUTS
261
-
262
- x = np.random.normal(10, 3, size=100)
263
-
264
- def model(x):
265
- numpyro.sample(
266
- "x",
267
- dist.Normal(
268
- numpyro.sample("loc", dist.Uniform(0, 20)),
269
- numpyro.sample("scale", dist.Uniform(0, 20)),
270
- ),
271
- obs=x,
272
- )
273
-
274
- nuts_kernel = NUTS(model)
275
- mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=400, num_chains=nchains, thinning=thin)
276
- mcmc.run(PRNGKey(0), x=x)
277
-
278
- inference_data = from_numpyro(mcmc)
279
- assert inference_data.posterior["loc"].shape == (nchains, 400 // thin)
280
-
281
- def test_mcmc_improper_uniform(self):
282
- import numpyro
283
- import numpyro.distributions as dist
284
- from numpyro.infer import MCMC, NUTS
285
-
286
- def model():
287
- x = numpyro.sample("x", dist.ImproperUniform(dist.constraints.positive, (), ()))
288
- return numpyro.sample("y", dist.Normal(x, 1), obs=1.0)
289
-
290
- mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
291
- mcmc.run(PRNGKey(0))
292
- inference_data = from_numpyro(mcmc)
293
- assert inference_data.observed_data
294
-
295
- def test_mcmc_infer_dims(self):
296
- import numpyro
297
- import numpyro.distributions as dist
298
- from numpyro.infer import MCMC, NUTS
299
-
300
- def model():
301
- # note: group2 gets assigned dim=-1 and group1 is assigned dim=-2
302
- with numpyro.plate("group2", 5), numpyro.plate("group1", 10):
303
- _ = numpyro.sample("param", dist.Normal(0, 1))
304
-
305
- mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
306
- mcmc.run(PRNGKey(0))
307
- inference_data = from_numpyro(
308
- mcmc, coords={"group1": np.arange(10), "group2": np.arange(5)}
309
- )
310
- assert inference_data.posterior.param.dims == ("chain", "draw", "group1", "group2")
311
- assert all(dim in inference_data.posterior.param.coords for dim in ("group1", "group2"))
312
-
313
- def test_mcmc_infer_unsorted_dims(self):
314
- import numpyro
315
- import numpyro.distributions as dist
316
- from numpyro.infer import MCMC, NUTS
317
-
318
- def model():
319
- group1_plate = numpyro.plate("group1", 10, dim=-1)
320
- group2_plate = numpyro.plate("group2", 5, dim=-2)
321
-
322
- # the plate contexts are entered in a different order than the pre-defined dims
323
- # we should make sure this still works because the trace has all of the info it needs
324
- with group2_plate, group1_plate:
325
- _ = numpyro.sample("param", dist.Normal(0, 1))
326
-
327
- mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
328
- mcmc.run(PRNGKey(0))
329
- inference_data = from_numpyro(
330
- mcmc, coords={"group1": np.arange(10), "group2": np.arange(5)}
331
- )
332
- assert inference_data.posterior.param.dims == ("chain", "draw", "group2", "group1")
333
- assert all(dim in inference_data.posterior.param.coords for dim in ("group1", "group2"))
334
-
335
- def test_mcmc_infer_dims_no_coords(self):
336
- import numpyro
337
- import numpyro.distributions as dist
338
- from numpyro.infer import MCMC, NUTS
339
-
340
- def model():
341
- with numpyro.plate("group", 5):
342
- _ = numpyro.sample("param", dist.Normal(0, 1))
343
-
344
- mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
345
- mcmc.run(PRNGKey(0))
346
- inference_data = from_numpyro(mcmc)
347
- assert inference_data.posterior.param.dims == ("chain", "draw", "group")
348
-
349
- def test_mcmc_event_dims(self):
350
- import numpyro
351
- import numpyro.distributions as dist
352
- from numpyro.infer import MCMC, NUTS
353
-
354
- def model():
355
- _ = numpyro.sample(
356
- "gamma", dist.ZeroSumNormal(1, event_shape=(10,)), infer={"event_dims": ["groups"]}
357
- )
358
-
359
- mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
360
- mcmc.run(PRNGKey(0))
361
- inference_data = from_numpyro(mcmc, coords={"groups": np.arange(10)})
362
- assert inference_data.posterior.gamma.dims == ("chain", "draw", "groups")
363
- assert "groups" in inference_data.posterior.gamma.coords
364
-
365
- @pytest.mark.xfail
366
- def test_mcmc_inferred_dims_univariate(self):
367
- import numpyro
368
- import numpyro.distributions as dist
369
- from numpyro.infer import MCMC, NUTS
370
- import jax.numpy as jnp
371
-
372
- def model():
373
- alpha = numpyro.sample("alpha", dist.Normal(0, 1))
374
- sigma = numpyro.sample("sigma", dist.HalfNormal(1))
375
- with numpyro.plate("obs_idx", 3):
376
- # mu is plated by obs_idx, but isnt broadcasted to the plate shape
377
- # the expected behavior is that this should cause a failure
378
- mu = numpyro.deterministic("mu", alpha)
379
- return numpyro.sample("y", dist.Normal(mu, sigma), obs=jnp.array([-1, 0, 1]))
380
-
381
- mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
382
- mcmc.run(PRNGKey(0))
383
- inference_data = from_numpyro(mcmc, coords={"obs_idx": np.arange(3)})
384
- assert inference_data.posterior.mu.dims == ("chain", "draw", "obs_idx")
385
- assert "obs_idx" in inference_data.posterior.mu.coords
386
-
387
- def test_mcmc_extra_event_dims(self):
388
- import numpyro
389
- import numpyro.distributions as dist
390
- from numpyro.infer import MCMC, NUTS
391
-
392
- def model():
393
- gamma = numpyro.sample("gamma", dist.ZeroSumNormal(1, event_shape=(10,)))
394
- _ = numpyro.deterministic("gamma_plus1", gamma + 1)
395
-
396
- mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
397
- mcmc.run(PRNGKey(0))
398
- inference_data = from_numpyro(
399
- mcmc, coords={"groups": np.arange(10)}, extra_event_dims={"gamma_plus1": ["groups"]}
400
- )
401
- assert inference_data.posterior.gamma_plus1.dims == ("chain", "draw", "groups")
402
- assert "groups" in inference_data.posterior.gamma_plus1.coords
403
-
404
- def test_mcmc_predictions_infer_dims(
405
- self, data, eight_schools_params, predictions_data, predictions_params
406
- ):
407
- inference_data = self.get_inference_data(
408
- data, eight_schools_params, predictions_data, predictions_params, infer_dims=True
409
- )
410
- assert inference_data.predictions.obs.dims == ("chain", "draw", "J")
411
- assert "J" in inference_data.predictions.obs.coords
412
-
413
- def test_potential_energy_sign_conversion(self):
414
- """Test that potential energy is converted to log probability (lp) with correct sign."""
415
- import numpyro
416
- import numpyro.distributions as dist
417
- from numpyro.infer import MCMC, NUTS
418
-
419
- num_samples = 10
420
-
421
- def simple_model():
422
- numpyro.sample("x", dist.Normal(0, 1))
423
-
424
- nuts_kernel = NUTS(simple_model)
425
- mcmc = MCMC(nuts_kernel, num_samples=num_samples, num_warmup=5)
426
- mcmc.run(PRNGKey(0), extra_fields=["potential_energy"])
427
-
428
- # Get the raw extra fields from NumPyro
429
- extra_fields = mcmc.get_extra_fields(group_by_chain=True)
430
- # Convert to ArviZ InferenceData
431
- inference_data = from_numpyro(mcmc)
432
- arviz_lp = inference_data["sample_stats"]["lp"].values
433
-
434
- np.testing.assert_array_equal(arviz_lp, -extra_fields["potential_energy"])
@@ -1,119 +0,0 @@
1
- # pylint: disable=no-member, invalid-name, redefined-outer-name, unused-import
2
- import sys
3
- import typing as tp
4
-
5
- import numpy as np
6
- import pytest
7
-
8
- from ... import InferenceData, from_pyjags, waic
9
- from ...data.io_pyjags import (
10
- _convert_arviz_dict_to_pyjags_dict,
11
- _convert_pyjags_dict_to_arviz_dict,
12
- _extract_arviz_dict_from_inference_data,
13
- )
14
- from ..helpers import check_multiple_attrs, eight_schools_params
15
-
16
- pytest.skip("Uses deprecated numpy C-api", allow_module_level=True)
17
-
18
- PYJAGS_POSTERIOR_DICT = {
19
- "b": np.random.randn(3, 10, 3),
20
- "int": np.random.randn(1, 10, 3),
21
- "log_like": np.random.randn(1, 10, 3),
22
- }
23
- PYJAGS_PRIOR_DICT = {"b": np.random.randn(3, 10, 3), "int": np.random.randn(1, 10, 3)}
24
-
25
-
26
- PARAMETERS = ("mu", "tau", "theta_tilde")
27
- VARIABLES = tuple(list(PARAMETERS) + ["log_like"])
28
-
29
- NUMBER_OF_WARMUP_SAMPLES = 1000
30
- NUMBER_OF_POST_WARMUP_SAMPLES = 5000
31
-
32
-
33
- def verify_equality_of_numpy_values_dictionaries(
34
- dict_1: tp.Mapping[tp.Any, np.ndarray], dict_2: tp.Mapping[tp.Any, np.ndarray]
35
- ) -> bool:
36
- if dict_1.keys() != dict_2.keys():
37
- return False
38
-
39
- for key in dict_1.keys():
40
- if not np.all(dict_1[key] == dict_2[key]):
41
- return False
42
-
43
- return True
44
-
45
-
46
- class TestDataPyJAGSWithoutEstimation:
47
- def test_convert_pyjags_samples_dictionary_to_arviz_samples_dictionary(self):
48
- arviz_samples_dict_from_pyjags_samples_dict = _convert_pyjags_dict_to_arviz_dict(
49
- PYJAGS_POSTERIOR_DICT
50
- )
51
-
52
- pyjags_dict_from_arviz_dict_from_pyjags_dict = _convert_arviz_dict_to_pyjags_dict(
53
- arviz_samples_dict_from_pyjags_samples_dict
54
- )
55
-
56
- assert verify_equality_of_numpy_values_dictionaries(
57
- PYJAGS_POSTERIOR_DICT,
58
- pyjags_dict_from_arviz_dict_from_pyjags_dict,
59
- )
60
-
61
- def test_extract_samples_dictionary_from_arviz_inference_data(self):
62
- arviz_samples_dict_from_pyjags_samples_dict = _convert_pyjags_dict_to_arviz_dict(
63
- PYJAGS_POSTERIOR_DICT
64
- )
65
-
66
- arviz_inference_data_from_pyjags_samples_dict = from_pyjags(PYJAGS_POSTERIOR_DICT)
67
- arviz_dict_from_idata_from_pyjags_dict = _extract_arviz_dict_from_inference_data(
68
- arviz_inference_data_from_pyjags_samples_dict
69
- )
70
-
71
- assert verify_equality_of_numpy_values_dictionaries(
72
- arviz_samples_dict_from_pyjags_samples_dict,
73
- arviz_dict_from_idata_from_pyjags_dict,
74
- )
75
-
76
- def test_roundtrip_from_pyjags_via_arviz_to_pyjags(self):
77
- arviz_inference_data_from_pyjags_samples_dict = from_pyjags(PYJAGS_POSTERIOR_DICT)
78
- arviz_dict_from_idata_from_pyjags_dict = _extract_arviz_dict_from_inference_data(
79
- arviz_inference_data_from_pyjags_samples_dict
80
- )
81
-
82
- pyjags_dict_from_arviz_idata = _convert_arviz_dict_to_pyjags_dict(
83
- arviz_dict_from_idata_from_pyjags_dict
84
- )
85
-
86
- assert verify_equality_of_numpy_values_dictionaries(
87
- PYJAGS_POSTERIOR_DICT, pyjags_dict_from_arviz_idata
88
- )
89
-
90
- @pytest.mark.parametrize("posterior", [None, PYJAGS_POSTERIOR_DICT])
91
- @pytest.mark.parametrize("prior", [None, PYJAGS_PRIOR_DICT])
92
- @pytest.mark.parametrize("save_warmup", [True, False])
93
- @pytest.mark.parametrize("warmup_iterations", [0, 5])
94
- def test_inference_data_attrs(self, posterior, prior, save_warmup, warmup_iterations: int):
95
- arviz_inference_data_from_pyjags_samples_dict = from_pyjags(
96
- posterior=posterior,
97
- prior=prior,
98
- log_likelihood={"y": "log_like"},
99
- save_warmup=save_warmup,
100
- warmup_iterations=warmup_iterations,
101
- )
102
- posterior_warmup_prefix = (
103
- "" if save_warmup and warmup_iterations > 0 and posterior is not None else "~"
104
- )
105
- prior_warmup_prefix = (
106
- "" if save_warmup and warmup_iterations > 0 and prior is not None else "~"
107
- )
108
- print(f'posterior_warmup_prefix="{posterior_warmup_prefix}"')
109
- test_dict = {
110
- f'{"~" if posterior is None else ""}posterior': ["b", "int"],
111
- f'{"~" if prior is None else ""}prior': ["b", "int"],
112
- f'{"~" if posterior is None else ""}log_likelihood': ["y"],
113
- f"{posterior_warmup_prefix}warmup_posterior": ["b", "int"],
114
- f"{prior_warmup_prefix}warmup_prior": ["b", "int"],
115
- f"{posterior_warmup_prefix}warmup_log_likelihood": ["y"],
116
- }
117
-
118
- fails = check_multiple_attrs(test_dict, arviz_inference_data_from_pyjags_samples_dict)
119
- assert not fails