arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (185)
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
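The file-level listing above can be reproduced locally. A minimal sketch, assuming both wheel files have already been downloaded to the working directory (the local filenames are illustrative) and using only the Python standard library:

import zipfile

def wheel_contents(path):
    # A wheel is a zip archive; map each member file to its uncompressed size.
    with zipfile.ZipFile(path) as whl:
        return {info.filename: info.file_size for info in whl.infolist()}

old = wheel_contents("arviz-0.23.3-py3-none-any.whl")  # assumed local path
new = wheel_contents("arviz-1.0.0rc0-py3-none-any.whl")  # assumed local path

removed = sorted(set(old) - set(new))  # files only in 0.23.3
added = sorted(set(new) - set(old))  # files only in 1.0.0rc0
print(f"{len(added)} files added, {len(removed)} files removed")

This recovers only which files were added or removed; the per-file +/- line counts shown above additionally require extracting each member file and diffing its text.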
arviz/tests/helpers.py DELETED
@@ -1,677 +0,0 @@
- # pylint: disable=redefined-outer-name, comparison-with-callable, protected-access
- """Test helper functions."""
- import gzip
- import importlib
- import logging
- import os
- import sys
- from typing import Any, Dict, List, Optional, Tuple, Union
- import warnings
- from contextlib import contextmanager
-
- import cloudpickle
- import numpy as np
- import pytest
- from _pytest.outcomes import Skipped
- from packaging.version import Version
-
- from ..data import InferenceData, from_dict
-
- _log = logging.getLogger(__name__)
-
-
- class RandomVariableTestClass:
-     """Example class for random variables."""
-
-     def __init__(self, name):
-         self.name = name
-
-     def __repr__(self):
-         """Return argument to constructor as string representation."""
-         return self.name
-
-
- @contextmanager
- def does_not_warn(warning=Warning):
-     with warnings.catch_warnings(record=True) as caught_warnings:
-         warnings.simplefilter("always")
-         yield
-     for w in caught_warnings:
-         if issubclass(w.category, warning):
-             raise AssertionError(
-                 f"Expected no {warning.__name__} but caught warning with message: {w.message}"
-             )
-
-
- @pytest.fixture(scope="module")
- def eight_schools_params():
-     """Share setup for eight schools."""
-     return {
-         "J": 8,
-         "y": np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
-         "sigma": np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
-     }
-
-
- @pytest.fixture(scope="module")
- def draws():
-     """Share default draw count."""
-     return 500
-
-
- @pytest.fixture(scope="module")
- def chains():
-     """Share default chain count."""
-     return 2
-
-
- def create_model(seed=10, transpose=False):
-     """Create model with fake data."""
-     np.random.seed(seed)
-     nchains = 4
-     ndraws = 500
-     data = {
-         "J": 8,
-         "y": np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]),
-         "sigma": np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]),
-     }
-     posterior = {
-         "mu": np.random.randn(nchains, ndraws),
-         "tau": abs(np.random.randn(nchains, ndraws)),
-         "eta": np.random.randn(nchains, ndraws, data["J"]),
-         "theta": np.random.randn(nchains, ndraws, data["J"]),
-     }
-     posterior_predictive = {"y": np.random.randn(nchains, ndraws, len(data["y"]))}
-     sample_stats = {
-         "energy": np.random.randn(nchains, ndraws),
-         "diverging": np.random.randn(nchains, ndraws) > 0.90,
-         "max_depth": np.random.randn(nchains, ndraws) > 0.90,
-     }
-     log_likelihood = {
-         "y": np.random.randn(nchains, ndraws, data["J"]),
-     }
-     prior = {
-         "mu": np.random.randn(nchains, ndraws) / 2,
-         "tau": abs(np.random.randn(nchains, ndraws)) / 2,
-         "eta": np.random.randn(nchains, ndraws, data["J"]) / 2,
-         "theta": np.random.randn(nchains, ndraws, data["J"]) / 2,
-     }
-     prior_predictive = {"y": np.random.randn(nchains, ndraws, len(data["y"])) / 2}
-     sample_stats_prior = {
-         "energy": np.random.randn(nchains, ndraws),
-         "diverging": (np.random.randn(nchains, ndraws) > 0.95).astype(int),
-     }
-     model = from_dict(
-         posterior=posterior,
-         posterior_predictive=posterior_predictive,
-         sample_stats=sample_stats,
-         log_likelihood=log_likelihood,
-         prior=prior,
-         prior_predictive=prior_predictive,
-         sample_stats_prior=sample_stats_prior,
-         observed_data={"y": data["y"]},
-         dims={
-             "y": ["obs_dim"],
-             "log_likelihood": ["obs_dim"],
-             "theta": ["school"],
-             "eta": ["school"],
-         },
-         coords={"obs_dim": range(data["J"])},
-     )
-     if transpose:
-         for group in model._groups:
-             group_dataset = getattr(model, group)
-             if all(dim in group_dataset.dims for dim in ("draw", "chain")):
-                 setattr(model, group, group_dataset.transpose(*["draw", "chain"], ...))
-     return model
-
-
- def create_multidimensional_model(seed=10, transpose=False):
-     """Create model with fake data."""
-     np.random.seed(seed)
-     nchains = 4
-     ndraws = 500
-     ndim1 = 5
-     ndim2 = 7
-     data = {
-         "y": np.random.normal(size=(ndim1, ndim2)),
-         "sigma": np.random.normal(size=(ndim1, ndim2)),
-     }
-     posterior = {
-         "mu": np.random.randn(nchains, ndraws),
-         "tau": abs(np.random.randn(nchains, ndraws)),
-         "eta": np.random.randn(nchains, ndraws, ndim1, ndim2),
-         "theta": np.random.randn(nchains, ndraws, ndim1, ndim2),
-     }
-     posterior_predictive = {"y": np.random.randn(nchains, ndraws, ndim1, ndim2)}
-     sample_stats = {
-         "energy": np.random.randn(nchains, ndraws),
-         "diverging": np.random.randn(nchains, ndraws) > 0.90,
-     }
-     log_likelihood = {
-         "y": np.random.randn(nchains, ndraws, ndim1, ndim2),
-     }
-     prior = {
-         "mu": np.random.randn(nchains, ndraws) / 2,
-         "tau": abs(np.random.randn(nchains, ndraws)) / 2,
-         "eta": np.random.randn(nchains, ndraws, ndim1, ndim2) / 2,
-         "theta": np.random.randn(nchains, ndraws, ndim1, ndim2) / 2,
-     }
-     prior_predictive = {"y": np.random.randn(nchains, ndraws, ndim1, ndim2) / 2}
-     sample_stats_prior = {
-         "energy": np.random.randn(nchains, ndraws),
-         "diverging": (np.random.randn(nchains, ndraws) > 0.95).astype(int),
-     }
-     model = from_dict(
-         posterior=posterior,
-         posterior_predictive=posterior_predictive,
-         sample_stats=sample_stats,
-         log_likelihood=log_likelihood,
-         prior=prior,
-         prior_predictive=prior_predictive,
-         sample_stats_prior=sample_stats_prior,
-         observed_data={"y": data["y"]},
-         dims={"y": ["dim1", "dim2"], "log_likelihood": ["dim1", "dim2"]},
-         coords={"dim1": range(ndim1), "dim2": range(ndim2)},
-     )
-     if transpose:
-         for group in model._groups:
-             group_dataset = getattr(model, group)
-             if all(dim in group_dataset.dims for dim in ("draw", "chain")):
-                 setattr(model, group, group_dataset.transpose(*["draw", "chain"], ...))
-     return model
-
-
- def create_data_random(groups=None, seed=10):
-     """Create InferenceData object using random data."""
-     if groups is None:
-         groups = ["posterior", "sample_stats", "observed_data", "posterior_predictive"]
-     rng = np.random.default_rng(seed)
-     data = rng.normal(size=(4, 500, 8))
-     idata_dict = dict(
-         posterior={"a": data[..., 0], "b": data},
-         sample_stats={"a": data[..., 0], "b": data},
-         observed_data={"b": data[0, 0, :]},
-         posterior_predictive={"a": data[..., 0], "b": data},
-         prior={"a": data[..., 0], "b": data},
-         prior_predictive={"a": data[..., 0], "b": data},
-         warmup_posterior={"a": data[..., 0], "b": data},
-         warmup_posterior_predictive={"a": data[..., 0], "b": data},
-         warmup_prior={"a": data[..., 0], "b": data},
-     )
-     idata = from_dict(
-         **{group: ary for group, ary in idata_dict.items() if group in groups}, save_warmup=True
-     )
-     return idata
-
-
- @pytest.fixture()
- def data_random():
-     """Fixture containing InferenceData object using random data."""
-     idata = create_data_random()
-     return idata
-
-
- @pytest.fixture(scope="module")
- def models():
-     """Fixture containing 2 mock inference data instances for testing."""
-     # blank line to keep black and pydocstyle happy
-
-     class Models:
-         model_1 = create_model(seed=10)
-         model_2 = create_model(seed=11, transpose=True)
-
-     return Models()
-
-
- @pytest.fixture(scope="module")
- def multidim_models():
-     """Fixture containing 2 mock inference data instances with multidimensional data for testing."""
-     # blank line to keep black and pydocstyle happy
-
-     class Models:
-         model_1 = create_multidimensional_model(seed=10)
-         model_2 = create_multidimensional_model(seed=11, transpose=True)
-
-     return Models()
-
-
- def check_multiple_attrs(
-     test_dict: Dict[str, List[str]], parent: InferenceData
- ) -> List[Union[str, Tuple[str, str]]]:
-     """Perform multiple hasattr checks on InferenceData objects.
-
-     It is thought to first check if the parent object contains a given dataset,
-     and then (if present) check the attributes of the dataset.
-
-     Given the output of the function, all mismatches between expectation and reality can
-     be retrieved: a single string indicates a group mismatch and a tuple of strings
-     ``(group, var)`` indicates a mismatch in the variable ``var`` of ``group``.
-
-     Parameters
-     ----------
-     test_dict: dict of {str : list of str}
-         Its structure should be `{dataset1_name: [var1, var2], dataset2_name: [var]}`.
-         A ``~`` at the beginning of a dataset or variable name indicates the name NOT
-         being present must be asserted.
-     parent: InferenceData
-         InferenceData object on which to check the attributes.
-
-     Returns
-     -------
-     list
-         List containing the failed checks. It will contain either the dataset_name or a
-         tuple (dataset_name, var) for all non present attributes.
-
-     Examples
-     --------
-     The output below indicates that ``posterior`` group was expected but not found, and
-     variables ``a`` and ``b``:
-
-         ["posterior", ("prior", "a"), ("prior", "b")]
-
-     Another example could be the following:
-
-         [("posterior", "a"), "~observed_data", ("sample_stats", "~log_likelihood")]
-
-     In this case, the output indicates that variable ``a`` was not found in ``posterior``
-     as it was expected, however, in the other two cases, the preceding ``~`` (kept from the
-     input negation notation) indicates that ``observed_data`` group should not be present
-     but was found in the InferenceData and that ``log_likelihood`` variable was found
-     in ``sample_stats``, also against what was expected.
-
-     """
-     failed_attrs: List[Union[str, Tuple[str, str]]] = []
-     for dataset_name, attributes in test_dict.items():
-         if dataset_name.startswith("~"):
-             if hasattr(parent, dataset_name[1:]):
-                 failed_attrs.append(dataset_name)
-         elif hasattr(parent, dataset_name):
-             dataset = getattr(parent, dataset_name)
-             for attribute in attributes:
-                 if attribute.startswith("~"):
-                     if hasattr(dataset, attribute[1:]):
-                         failed_attrs.append((dataset_name, attribute))
-                 elif not hasattr(dataset, attribute):
-                     failed_attrs.append((dataset_name, attribute))
-         else:
-             failed_attrs.append(dataset_name)
-     return failed_attrs
-
-
- def emcee_version():
-     """Check emcee version.
-
-     Returns
-     -------
-     int
-         Major version number
-
-     """
-     import emcee
-
-     return int(emcee.__version__[0])
-
-
- def needs_emcee3_func():
-     """Check if emcee3 is required."""
-     # pylint: disable=invalid-name
-     needs_emcee3 = pytest.mark.skipif(emcee_version() < 3, reason="emcee3 required")
-     return needs_emcee3
-
-
- def _emcee_lnprior(theta):
-     """Proper function to allow pickling."""
-     mu, tau, eta = theta[0], theta[1], theta[2:]
-     # Half-cauchy prior, hwhm=25
-     if tau < 0:
-         return -np.inf
-     prior_tau = -np.log(tau**2 + 25**2)
-     prior_mu = -((mu / 10) ** 2) # normal prior, loc=0, scale=10
-     prior_eta = -np.sum(eta**2) # normal prior, loc=0, scale=1
-     return prior_mu + prior_tau + prior_eta
-
-
- def _emcee_lnprob(theta, y, sigma):
-     """Proper function to allow pickling."""
-     mu, tau, eta = theta[0], theta[1], theta[2:]
-     prior = _emcee_lnprior(theta)
-     like_vect = -(((mu + tau * eta - y) / sigma) ** 2)
-     like = np.sum(like_vect)
-     return like + prior, (like_vect, np.random.normal((mu + tau * eta), sigma))
-
-
- def emcee_schools_model(data, draws, chains):
-     """Schools model in emcee."""
-     import emcee
-
-     chains = 10 * chains # emcee is sad with too few walkers
-     y = data["y"]
-     sigma = data["sigma"]
-     J = data["J"] # pylint: disable=invalid-name
-     ndim = J + 2
-
-     pos = np.random.normal(size=(chains, ndim))
-     pos[:, 1] = np.absolute(pos[:, 1]) # pylint: disable=unsupported-assignment-operation
-
-     if emcee_version() < 3:
-         sampler = emcee.EnsembleSampler(chains, ndim, _emcee_lnprob, args=(y, sigma))
-         # pylint: enable=unexpected-keyword-arg
-         sampler.run_mcmc(pos, draws)
-     else:
-         here = os.path.dirname(os.path.abspath(__file__))
-         data_directory = os.path.join(here, "saved_models")
-         filepath = os.path.join(data_directory, "reader_testfile.h5")
-         backend = emcee.backends.HDFBackend(filepath) # pylint: disable=no-member
-         backend.reset(chains, ndim)
-         # pylint: disable=unexpected-keyword-arg
-         sampler = emcee.EnsembleSampler(
-             chains, ndim, _emcee_lnprob, args=(y, sigma), backend=backend
-         )
-         # pylint: enable=unexpected-keyword-arg
-         sampler.run_mcmc(pos, draws, store=True)
-     return sampler
-
-
- # pylint:disable=no-member,no-value-for-parameter,invalid-name
- def _pyro_noncentered_model(J, sigma, y=None):
-     import pyro
-     import pyro.distributions as dist
-
-     mu = pyro.sample("mu", dist.Normal(0, 5))
-     tau = pyro.sample("tau", dist.HalfCauchy(5))
-     with pyro.plate("J", J):
-         eta = pyro.sample("eta", dist.Normal(0, 1))
-         theta = mu + tau * eta
-         return pyro.sample("obs", dist.Normal(theta, sigma), obs=y)
-
-
- def pyro_noncentered_schools(data, draws, chains):
-     """Non-centered eight schools implementation in Pyro."""
-     import torch
-     from pyro.infer import MCMC, NUTS
-
-     y = torch.from_numpy(data["y"]).float()
-     sigma = torch.from_numpy(data["sigma"]).float()
-
-     nuts_kernel = NUTS(_pyro_noncentered_model, jit_compile=True, ignore_jit_warnings=True)
-     posterior = MCMC(nuts_kernel, num_samples=draws, warmup_steps=draws, num_chains=chains)
-     posterior.run(data["J"], sigma, y)
-
-     # This block lets the posterior be pickled
-     posterior.sampler = None
-     posterior.kernel.potential_fn = None
-     return posterior
-
-
- # pylint:disable=no-member,no-value-for-parameter,invalid-name
- def _numpyro_noncentered_model(J, sigma, y=None):
-     import numpyro
-     import numpyro.distributions as dist
-
-     mu = numpyro.sample("mu", dist.Normal(0, 5))
-     tau = numpyro.sample("tau", dist.HalfCauchy(5))
-     with numpyro.plate("J", J):
-         eta = numpyro.sample("eta", dist.Normal(0, 1))
-         theta = mu + tau * eta
-         return numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)
-
-
- def numpyro_schools_model(data, draws, chains):
-     """Centered eight schools implementation in NumPyro."""
-     from jax.random import PRNGKey
-     from numpyro.infer import MCMC, NUTS
-
-     mcmc = MCMC(
-         NUTS(_numpyro_noncentered_model),
-         num_warmup=draws,
-         num_samples=draws,
-         num_chains=chains,
-         chain_method="sequential",
-     )
-     mcmc.run(PRNGKey(0), extra_fields=("num_steps", "energy"), **data)
-
-     # This block lets the posterior be pickled
-     mcmc.sampler._sample_fn = None # pylint: disable=protected-access
-     mcmc.sampler._init_fn = None # pylint: disable=protected-access
-     mcmc.sampler._postprocess_fn = None # pylint: disable=protected-access
-     mcmc.sampler._potential_fn = None # pylint: disable=protected-access
-     mcmc.sampler._potential_fn_gen = None # pylint: disable=protected-access
-     mcmc._cache = {} # pylint: disable=protected-access
-     return mcmc
-
-
- def pystan_noncentered_schools(data, draws, chains):
-     """Non-centered eight schools implementation for pystan."""
-     schools_code = """
-         data {
-             int<lower=0> J;
-             array[J] real y;
-             array[J] real<lower=0> sigma;
-         }
-
-         parameters {
-             real mu;
-             real<lower=0> tau;
-             array[J] real eta;
-         }
-
-         transformed parameters {
-             array[J] real theta;
-             for (j in 1:J)
-                 theta[j] = mu + tau * eta[j];
-         }
-
-         model {
-             mu ~ normal(0, 5);
-             tau ~ cauchy(0, 5);
-             eta ~ normal(0, 1);
-             y ~ normal(theta, sigma);
-         }
-
-         generated quantities {
-             array[J] real log_lik;
-             array[J] real y_hat;
-             for (j in 1:J) {
-                 log_lik[j] = normal_lpdf(y[j] | theta[j], sigma[j]);
-                 y_hat[j] = normal_rng(theta[j], sigma[j]);
-             }
-         }
-     """
-     if pystan_version() == 2:
-         import pystan # pylint: disable=import-error
-
-         stan_model = pystan.StanModel(model_code=schools_code)
-         fit = stan_model.sampling(
-             data=data,
-             iter=draws + 500,
-             warmup=500,
-             chains=chains,
-             check_hmc_diagnostics=False,
-             control=dict(adapt_engaged=False),
-         )
-     else:
-         import stan # pylint: disable=import-error
-
-         stan_model = stan.build(schools_code, data=data)
-         fit = stan_model.sample(
-             num_chains=chains, num_samples=draws, num_warmup=500, save_warmup=True
-         )
-     return stan_model, fit
-
-
- def bm_schools_model(data, draws, chains):
-     import beanmachine.ppl as bm # pylint: disable=import-error
-     import torch
-     import torch.distributions as dist
-
-     class EightSchools:
-         @bm.random_variable
-         def mu(self):
-             return dist.Normal(0, 5)
-
-         @bm.random_variable
-         def tau(self):
-             return dist.HalfCauchy(5)
-
-         @bm.random_variable
-         def eta(self):
-             return dist.Normal(0, 1).expand((data["J"],))
-
-         @bm.functional
-         def theta(self):
-             return self.mu() + self.tau() * self.eta()
-
-         @bm.random_variable
-         def obs(self):
-             return dist.Normal(self.theta(), torch.from_numpy(data["sigma"]).float())
-
-     model = EightSchools()
-
-     prior = bm.GlobalNoUTurnSampler().infer(
-         queries=[model.mu(), model.tau(), model.eta()],
-         observations={},
-         num_samples=draws,
-         num_adaptive_samples=500,
-         num_chains=chains,
-     )
-
-     posterior = bm.GlobalNoUTurnSampler().infer(
-         queries=[model.mu(), model.tau(), model.eta()],
-         observations={model.obs(): torch.from_numpy(data["y"]).float()},
-         num_samples=draws,
-         num_adaptive_samples=500,
-         num_chains=chains,
-     )
-     return model, prior, posterior
-
-
- def library_handle(library):
-     """Import a library and return the handle."""
-     if library == "pystan":
-         try:
-             module = importlib.import_module("pystan")
-         except ImportError:
-             module = importlib.import_module("stan")
-     else:
-         module = importlib.import_module(library)
-     return module
-
-
- def load_cached_models(eight_schools_data, draws, chains, libs=None):
-     """Load pystan, emcee, and pyro models from pickle."""
-     here = os.path.dirname(os.path.abspath(__file__))
-     supported = (
-         ("pystan", pystan_noncentered_schools),
-         ("emcee", emcee_schools_model),
-         ("pyro", pyro_noncentered_schools),
-         ("numpyro", numpyro_schools_model),
-         # ("beanmachine", bm_schools_model), # ignore beanmachine until it supports torch>=2
-     )
-     data_directory = os.path.join(here, "saved_models")
-     models = {}
-
-     if isinstance(libs, str):
-         libs = [libs]
-
-     for library_name, func in supported:
-         if libs is not None and library_name not in libs:
-             continue
-         library = library_handle(library_name)
-         if library.__name__ == "stan":
-             # PyStan3 does not support pickling
-             # httpstan caches models automatically
-             _log.info("Generating and loading stan model")
-             models["pystan"] = func(eight_schools_data, draws, chains)
-             continue
-
-         py_version = sys.version_info
-         fname = "{0.major}.{0.minor}_{1.__name__}_{1.__version__}_{2}_{3}_{4}.pkl.gzip".format(
-             py_version, library, sys.platform, draws, chains
-         )
-
-         path = os.path.join(data_directory, fname)
-         if not os.path.exists(path):
-             with gzip.open(path, "wb") as buff:
-                 try:
-                     _log.info("Generating and caching %s", fname)
-                     cloudpickle.dump(func(eight_schools_data, draws, chains), buff)
-                 except AttributeError as err:
-                     raise AttributeError(f"Failed caching {library_name}") from err
-
-         with gzip.open(path, "rb") as buff:
-             _log.info("Loading %s from cache", fname)
-             models[library.__name__] = cloudpickle.load(buff)
-
-     return models
-
-
- def pystan_version():
-     """Check PyStan version.
-
-     Returns
-     -------
-     int
-         Major version number
-
-     """
-     try:
-         import pystan # pylint: disable=import-error
-
-         version = int(pystan.__version__[0])
-     except ImportError:
-         try:
-             import stan # pylint: disable=import-error
-
-             version = int(stan.__version__[0])
-         except ImportError:
-             version = None
-     return version
-
-
- def test_precompile_models(eight_schools_params, draws, chains):
-     """Precompile model files."""
-     load_cached_models(eight_schools_params, draws, chains)
-
-
- def importorskip(
-     modname: str, minversion: Optional[str] = None, reason: Optional[str] = None
- ) -> Any:
-     """Import and return the requested module ``modname``.
-
-     Doesn't allow skips on CI machine.
-     Borrowed and modified from ``pytest.importorskip``.
-     :param str modname: the name of the module to import
-     :param str minversion: if given, the imported module's ``__version__``
-         attribute must be at least this minimal version, otherwise the test is
-         still skipped.
-     :param str reason: if given, this reason is shown as the message when the
-         module cannot be imported.
-     :returns: The imported module. This should be assigned to its canonical
-         name.
-     Example::
-         docutils = pytest.importorskip("docutils")
-     """
-     # Unless ARVIZ_REQUIRE_ALL_DEPS is defined, tests that require a missing dependency are skipped
-     # if set, missing optional dependencies trigger failed tests.
-     if "ARVIZ_REQUIRE_ALL_DEPS" not in os.environ:
-         return pytest.importorskip(modname=modname, minversion=minversion, reason=reason)
-
-     compile(modname, "", "eval") # to catch syntaxerrors
-
-     with warnings.catch_warnings():
-         # make sure to ignore ImportWarnings that might happen because
-         # of existing directories with the same name we're trying to
-         # import but without a __init__.py file
-         warnings.simplefilter("ignore")
-         __import__(modname)
-     mod = sys.modules[modname]
-     if minversion is None:
-         return mod
-     verattr = getattr(mod, "__version__", None)
-     if verattr is None or Version(verattr) < Version(minversion):
-         raise Skipped(
-             "module %r has __version__ %r, required is: %r" % (modname, verattr, minversion),
-             allow_module_level=True,
-         )
-     return mod
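
For context on how the deleted module was consumed, here is a minimal sketch of a test built on check_multiple_attrs and create_model. It is illustrative rather than taken from the arviz test suite, and it only runs against arviz 0.23.3, since the module is removed in 1.0.0rc0:

from arviz.tests.helpers import check_multiple_attrs, create_model

idata = create_model(seed=10)
# Expect mu and tau in the posterior group, and no warmup_posterior group
# (a leading "~" asserts absence).
test_dict = {"posterior": ["mu", "tau"], "~warmup_posterior": []}
fails = check_multiple_attrs(test_dict, idata)
assert not fails, f"unexpected mismatches: {fails}"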