arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -357
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.1.dist-info/METADATA +0 -263
  184. arviz-0.23.1.dist-info/RECORD +0 -183
  185. arviz-0.23.1.dist-info/top_level.txt +0 -1
arviz/tests/base_tests/test_utils.py DELETED
@@ -1,376 +0,0 @@
1
- """Tests for arviz.utils."""
2
-
3
- # pylint: disable=redefined-outer-name, no-member
4
- from unittest.mock import Mock
5
-
6
- import numpy as np
7
- import pytest
8
- import scipy.stats as st
9
-
10
- from ...data import dict_to_dataset, from_dict, load_arviz_data
11
- from ...stats.density_utils import _circular_mean, _normalize_angle, _find_hdi_contours
12
- from ...utils import (
13
- _stack,
14
- _subset_list,
15
- _var_names,
16
- expand_dims,
17
- flatten_inference_data_to_dict,
18
- one_de,
19
- two_de,
20
- )
21
- from ..helpers import RandomVariableTestClass
22
-
23
-
24
- @pytest.fixture(scope="session")
25
- def inference_data():
26
- centered_eight = load_arviz_data("centered_eight")
27
- return centered_eight
28
-
29
-
30
- @pytest.fixture(scope="session")
31
- def data():
32
- centered_eight = load_arviz_data("centered_eight")
33
- return centered_eight.posterior
34
-
35
-
36
- @pytest.mark.parametrize(
37
- "var_names_expected",
38
- [
39
- ("mu", ["mu"]),
40
- (None, None),
41
- (["mu", "tau"], ["mu", "tau"]),
42
- ("~mu", ["theta", "tau"]),
43
- (["~mu"], ["theta", "tau"]),
44
- ],
45
- )
46
- def test_var_names(var_names_expected, data):
47
- """Test var_name handling"""
48
- var_names, expected = var_names_expected
49
- assert _var_names(var_names, data) == expected
50
-
51
-
52
- def test_var_names_warning():
53
- """Test confusing var_name handling"""
54
- data = from_dict(
55
- posterior={
56
- "~mu": np.random.randn(2, 10),
57
- "mu": -np.random.randn(2, 10), # pylint: disable=invalid-unary-operand-type
58
- "theta": np.random.randn(2, 10, 8),
59
- }
60
- ).posterior
61
- var_names = expected = ["~mu"]
62
- with pytest.warns(UserWarning):
63
- assert _var_names(var_names, data) == expected
64
-
65
-
66
- def test_var_names_key_error(data):
67
- with pytest.raises(KeyError, match="bad_var_name"):
68
- _var_names(("theta", "tau", "bad_var_name"), data)
69
-
70
-
71
- @pytest.mark.parametrize(
72
- "var_args",
73
- [
74
- (["ta"], ["beta1", "beta2", "theta"], "like"),
75
- (["~beta"], ["phi", "theta"], "like"),
76
- (["beta[0-9]+"], ["beta1", "beta2"], "regex"),
77
- (["^p"], ["phi"], "regex"),
78
- (["~^t"], ["beta1", "beta2", "phi"], "regex"),
79
- ],
80
- )
81
- def test_var_names_filter_multiple_input(var_args):
82
- samples = np.random.randn(10)
83
- data1 = dict_to_dataset({"beta1": samples, "beta2": samples, "phi": samples})
84
- data2 = dict_to_dataset({"beta1": samples, "beta2": samples, "theta": samples})
85
- data = [data1, data2]
86
- var_names, expected, filter_vars = var_args
87
- assert _var_names(var_names, data, filter_vars) == expected
88
-
89
-
90
- @pytest.mark.parametrize(
91
- "var_args",
92
- [
93
- (["alpha", "beta"], ["alpha", "beta1", "beta2"], "like"),
94
- (["~beta"], ["alpha", "p1", "p2", "phi", "theta", "theta_t"], "like"),
95
- (["theta"], ["theta", "theta_t"], "like"),
96
- (["~theta"], ["alpha", "beta1", "beta2", "p1", "p2", "phi"], "like"),
97
- (["p"], ["alpha", "p1", "p2", "phi"], "like"),
98
- (["~p"], ["beta1", "beta2", "theta", "theta_t"], "like"),
99
- (["^bet"], ["beta1", "beta2"], "regex"),
100
- (["^p"], ["p1", "p2", "phi"], "regex"),
101
- (["~^p"], ["alpha", "beta1", "beta2", "theta", "theta_t"], "regex"),
102
- (["p[0-9]+"], ["p1", "p2"], "regex"),
103
- (["~p[0-9]+"], ["alpha", "beta1", "beta2", "phi", "theta", "theta_t"], "regex"),
104
- ],
105
- )
106
- def test_var_names_filter(var_args):
107
- """Test var_names filter with partial naming or regular expressions."""
108
- samples = np.random.randn(10)
109
- data = dict_to_dataset(
110
- {
111
- "alpha": samples,
112
- "beta1": samples,
113
- "beta2": samples,
114
- "p1": samples,
115
- "p2": samples,
116
- "phi": samples,
117
- "theta": samples,
118
- "theta_t": samples,
119
- }
120
- )
121
- var_names, expected, filter_vars = var_args
122
- assert _var_names(var_names, data, filter_vars) == expected
123
-
124
-
125
- def test_nonstring_var_names():
126
- """Check that non-string variables are preserved"""
127
- mu = RandomVariableTestClass("mu")
128
- samples = np.random.randn(10)
129
- data = dict_to_dataset({mu: samples})
130
- assert _var_names([mu], data) == [mu]
131
-
132
-
133
- def test_var_names_filter_invalid_argument():
134
- """Check invalid argument raises."""
135
- samples = np.random.randn(10)
136
- data = dict_to_dataset({"alpha": samples})
137
- msg = r"^\'filter_vars\' can only be None, \'like\', or \'regex\', got: 'foo'$"
138
- with pytest.raises(ValueError, match=msg):
139
- assert _var_names(["alpha"], data, filter_vars="foo")
140
-
141
-
142
- def test_subset_list_negation_not_found():
143
- """Check there is a warning if negation pattern is ignored"""
144
- names = ["mu", "theta"]
145
- with pytest.warns(UserWarning, match=".+not.+found.+"):
146
- assert _subset_list("~tau", names) == names
147
-
148
-
149
- @pytest.fixture(scope="function")
150
- def utils_with_numba_import_fail(monkeypatch):
151
- """Patch numba in utils so when its imported it raises ImportError"""
152
- failed_import = Mock()
153
- failed_import.side_effect = ImportError
154
-
155
- from ... import utils
156
-
157
- monkeypatch.setattr(utils.importlib, "import_module", failed_import)
158
- return utils
159
-
160
-
161
- def test_conditional_jit_decorator_no_numba(utils_with_numba_import_fail):
162
- """Tests to see if Numba jit code block is skipped with Import Failure
163
-
164
- Test can be distinguished from test_conditional_jit__numba_decorator
165
- by use of debugger or coverage tool
166
- """
167
-
168
- @utils_with_numba_import_fail.conditional_jit
169
- def func():
170
- return "Numba not used"
171
-
172
- assert func() == "Numba not used"
173
-
174
-
175
- def test_conditional_vect_decorator_no_numba(utils_with_numba_import_fail):
176
- """Tests to see if Numba vectorize code block is skipped with Import Failure
177
-
178
- Test can be distinguished from test_conditional_vect__numba_decorator
179
- by use of debugger or coverage tool
180
- """
181
-
182
- @utils_with_numba_import_fail.conditional_vect
183
- def func():
184
- return "Numba not used"
185
-
186
- assert func() == "Numba not used"
187
-
188
-
189
- def test_conditional_jit_numba_decorator():
190
- """Tests to see if Numba is used.
191
-
192
- Test can be distinguished from test_conditional_jit_decorator_no_numba
193
- by use of debugger or coverage tool
194
- """
195
- from ... import utils
196
-
197
- @utils.conditional_jit
198
- def func():
199
- return True
200
-
201
- assert func()
202
-
203
-
204
- def test_conditional_vect_numba_decorator():
205
- """Tests to see if Numba is used.
206
-
207
- Test can be distinguished from test_conditional_jit_decorator_no_numba
208
- by use of debugger or coverage tool
209
- """
210
- from ... import utils
211
-
212
- @utils.conditional_vect
213
- def func(a_a, b_b):
214
- return a_a + b_b
215
-
216
- value_one = np.random.randn(10)
217
- value_two = np.random.randn(10)
218
- assert np.allclose(func(value_one, value_two), value_one + value_two)
219
-
220
-
221
- def test_conditional_vect_numba_decorator_keyword(monkeypatch):
222
- """Checks else statement and vect keyword argument"""
223
- from ... import utils
224
-
225
- # Mock import lib to return numba with hit method which returns a function that returns kwargs
226
- numba_mock = Mock()
227
- monkeypatch.setattr(utils.importlib, "import_module", lambda x: numba_mock)
228
-
229
- def vectorize(**kwargs):
230
- """overwrite numba.vectorize function"""
231
- return lambda x: (x(), kwargs)
232
-
233
- numba_mock.vectorize = vectorize
234
-
235
- @utils.conditional_vect(keyword_argument="A keyword argument")
236
- def placeholder_func():
237
- """This function does nothing"""
238
- return "output"
239
-
240
- # pylint: disable=unpacking-non-sequence
241
- function_results, wrapper_result = placeholder_func
242
- assert wrapper_result == {"keyword_argument": "A keyword argument"}
243
- assert function_results == "output"
244
-
245
-
246
- def test_stack():
247
- x = np.random.randn(10, 4, 6)
248
- y = np.random.randn(100, 4, 6)
249
- assert x.shape[1:] == y.shape[1:]
250
- assert np.allclose(np.vstack((x, y)), _stack(x, y))
251
- assert _stack
252
-
253
-
254
- @pytest.mark.parametrize("data", [np.random.randn(1000), np.random.randn(1000).tolist()])
255
- def test_two_de(data):
256
- """Test to check for custom atleast_2d. List added to test for a non ndarray case."""
257
- assert np.allclose(two_de(data), np.atleast_2d(data))
258
-
259
-
260
- @pytest.mark.parametrize("data", [np.random.randn(100), np.random.randn(100).tolist()])
261
- def test_one_de(data):
262
- """Test to check for custom atleast_1d. List added to test for a non ndarray case."""
263
- assert np.allclose(one_de(data), np.atleast_1d(data))
264
-
265
-
266
- @pytest.mark.parametrize("data", [np.random.randn(100), np.random.randn(100).tolist()])
267
- def test_expand_dims(data):
268
- """Test to check for custom expand_dims. List added to test for a non ndarray case."""
269
- assert np.allclose(expand_dims(data), np.expand_dims(data, 0))
270
-
271
-
272
- @pytest.mark.parametrize("var_names", [None, "mu", ["mu", "tau"]])
273
- @pytest.mark.parametrize(
274
- "groups", [None, "posterior_groups", "prior_groups", ["posterior", "sample_stats"]]
275
- )
276
- @pytest.mark.parametrize("dimensions", [None, "draw", ["chain", "draw"]])
277
- @pytest.mark.parametrize("group_info", [True, False])
278
- @pytest.mark.parametrize(
279
- "var_name_format", [None, "brackets", "underscore", "cds", ((",", "[", "]"), ("_", ""))]
280
- )
281
- @pytest.mark.parametrize("index_origin", [None, 0, 1])
282
- def test_flatten_inference_data_to_dict(
283
- inference_data, var_names, groups, dimensions, group_info, var_name_format, index_origin
284
- ):
285
- """Test flattening (stacking) inference data (subgroups) for dictionary."""
286
- res_dict = flatten_inference_data_to_dict(
287
- data=inference_data,
288
- var_names=var_names,
289
- groups=groups,
290
- dimensions=dimensions,
291
- group_info=group_info,
292
- var_name_format=var_name_format,
293
- index_origin=index_origin,
294
- )
295
- assert res_dict
296
- assert "draw" in res_dict
297
- assert any("mu" in item for item in res_dict)
298
- if group_info:
299
- if groups != "prior_groups":
300
- assert any("posterior" in item for item in res_dict)
301
- if var_names is None:
302
- assert any("sample_stats" in item for item in res_dict)
303
- else:
304
- assert any("prior" in item for item in res_dict)
305
- elif groups == "prior_groups":
306
- assert all("prior" not in item for item in res_dict)
307
-
308
- else:
309
- assert all("posterior" not in item for item in res_dict)
310
- if var_names is None:
311
- assert all("sample_stats" not in item for item in res_dict)
312
-
313
-
314
- @pytest.mark.parametrize("mean", [0, np.pi, 4 * np.pi, -2 * np.pi, -10 * np.pi])
315
- def test_circular_mean_scipy(mean):
316
- """Test our `_circular_mean()` function gives same result than Scipy version."""
317
- rvs = st.vonmises.rvs(loc=mean, kappa=1, size=1000)
318
- mean_az = _circular_mean(rvs)
319
- mean_sp = st.circmean(rvs, low=-np.pi, high=np.pi)
320
- np.testing.assert_almost_equal(mean_az, mean_sp)
321
-
322
-
323
- @pytest.mark.parametrize("mean", [0, np.pi, 4 * np.pi, -2 * np.pi, -10 * np.pi])
324
- def test_normalize_angle(mean):
325
- """Testing _normalize_angles() return values between expected bounds"""
326
- rvs = st.vonmises.rvs(loc=mean, kappa=1, size=1000)
327
- values = _normalize_angle(rvs, zero_centered=True)
328
- assert ((-np.pi <= values) & (values <= np.pi)).all()
329
-
330
- values = _normalize_angle(rvs, zero_centered=False)
331
- assert ((values >= 0) & (values <= 2 * np.pi)).all()
332
-
333
-
334
- @pytest.mark.parametrize("mean", [[0, 0], [1, 1]])
335
- @pytest.mark.parametrize(
336
- "cov",
337
- [
338
- np.diag([1, 1]),
339
- np.diag([0.5, 0.5]),
340
- np.diag([0.25, 1]),
341
- np.array([[0.4, 0.2], [0.2, 0.8]]),
342
- ],
343
- )
344
- @pytest.mark.parametrize("contour_sigma", [np.array([1, 2, 3])])
345
- def test_find_hdi_contours(mean, cov, contour_sigma):
346
- """Test `_find_hdi_contours()` against SciPy's multivariate normal distribution."""
347
- # Set up scipy distribution
348
- prob_dist = st.multivariate_normal(mean, cov)
349
-
350
- # Find standard deviations and eigenvectors
351
- eigenvals, eigenvecs = np.linalg.eig(cov)
352
- eigenvecs = eigenvecs.T
353
- stdevs = np.sqrt(eigenvals)
354
-
355
- # Find min and max for grid at 7-sigma contour
356
- extremes = np.empty((4, 2))
357
- for i in range(4):
358
- extremes[i] = mean + (-1) ** i * 7 * stdevs[i // 2] * eigenvecs[i // 2]
359
- x_min, y_min = np.amin(extremes, axis=0)
360
- x_max, y_max = np.amax(extremes, axis=0)
361
-
362
- # Create 256x256 grid
363
- x = np.linspace(x_min, x_max, 256)
364
- y = np.linspace(y_min, y_max, 256)
365
- grid = np.dstack(np.meshgrid(x, y))
366
-
367
- density = prob_dist.pdf(grid)
368
-
369
- contour_sp = np.empty(contour_sigma.shape)
370
- for idx, sigma in enumerate(contour_sigma):
371
- contour_sp[idx] = prob_dist.pdf(mean + sigma * stdevs[0] * eigenvecs[0])
372
-
373
- hdi_probs = 1 - np.exp(-0.5 * contour_sigma**2)
374
- contour_az = _find_hdi_contours(density, hdi_probs)
375
-
376
- np.testing.assert_allclose(contour_sp, contour_az, rtol=1e-2, atol=1e-4)
arviz/tests/base_tests/test_utils_numba.py DELETED
@@ -1,87 +0,0 @@
1
- # pylint: disable=redefined-outer-name, no-member
2
- """Tests for arviz.utils."""
3
- import importlib
4
- from unittest.mock import Mock
5
-
6
- import numpy as np
7
- import pytest
8
-
9
- from ...stats.stats_utils import stats_variance_2d as svar
10
- from ...utils import Numba, _numba_var, numba_check
11
- from ..helpers import importorskip
12
- from .test_utils import utils_with_numba_import_fail # pylint: disable=unused-import
13
-
14
- importorskip("numba")
15
-
16
-
17
- def test_utils_fixture(utils_with_numba_import_fail):
18
- """Test of utils fixture to ensure mock is applied correctly"""
19
-
20
- # If Numba doesn't exist in dev environment this will raise an ImportError
21
- import numba # pylint: disable=unused-import,W0612
22
-
23
- with pytest.raises(ImportError):
24
- utils_with_numba_import_fail.importlib.import_module("numba")
25
-
26
-
27
- def test_conditional_jit_numba_decorator_keyword(monkeypatch):
28
- """Checks else statement and JIT keyword argument"""
29
- from ... import utils
30
-
31
- # Mock import lib to return numba with hit method which returns a function that returns kwargs
32
- numba_mock = Mock()
33
- monkeypatch.setattr(utils.importlib, "import_module", lambda x: numba_mock)
34
-
35
- def jit(**kwargs):
36
- """overwrite numba.jit function"""
37
- return lambda fn: lambda: (fn(), kwargs)
38
-
39
- numba_mock.jit = jit
40
-
41
- @utils.conditional_jit(keyword_argument="A keyword argument")
42
- def placeholder_func():
43
- """This function does nothing"""
44
- return "output"
45
-
46
- # pylint: disable=unpacking-non-sequence
47
- function_results, wrapper_result = placeholder_func()
48
- assert wrapper_result == {"keyword_argument": "A keyword argument", "nopython": True}
49
- assert function_results == "output"
50
-
51
-
52
- def test_numba_check():
53
- """Test for numba_check"""
54
- numba = importlib.util.find_spec("numba")
55
- flag = numba is not None
56
- assert flag == numba_check()
57
-
58
-
59
- def test_numba_utils():
60
- """Test for class Numba."""
61
- flag = Numba.numba_flag
62
- assert flag == numba_check()
63
- Numba.disable_numba()
64
- val = Numba.numba_flag
65
- assert not val
66
- Numba.enable_numba()
67
- val = Numba.numba_flag
68
- assert val
69
- assert flag == Numba.numba_flag
70
-
71
-
72
- @pytest.mark.parametrize("axis", (0, 1))
73
- @pytest.mark.parametrize("ddof", (0, 1))
74
- def test_numba_var(axis, ddof):
75
- """Method to test numba_var."""
76
- flag = Numba.numba_flag
77
- data_1 = np.random.randn(100, 100)
78
- data_2 = np.random.rand(100)
79
- with_numba_1 = _numba_var(svar, np.var, data_1, axis=axis, ddof=ddof)
80
- with_numba_2 = _numba_var(svar, np.var, data_2, ddof=ddof)
81
- Numba.disable_numba()
82
- non_numba_1 = _numba_var(svar, np.var, data_1, axis=axis, ddof=ddof)
83
- non_numba_2 = _numba_var(svar, np.var, data_2, ddof=ddof)
84
- Numba.enable_numba()
85
- assert flag == Numba.numba_flag
86
- assert np.allclose(with_numba_1, non_numba_1)
87
- assert np.allclose(with_numba_2, non_numba_2)
arviz/tests/conftest.py DELETED
@@ -1,46 +0,0 @@
1
- # pylint: disable=redefined-outer-name
2
- """Configuration for test suite."""
3
- import logging
4
- import os
5
-
6
- import numpy as np
7
- import pytest
8
-
9
- _log = logging.getLogger(__name__)
10
-
11
-
12
- @pytest.fixture(autouse=True)
13
- def random_seed():
14
- """Reset numpy random seed generator."""
15
- np.random.seed(0)
16
-
17
-
18
- def pytest_addoption(parser):
19
- """Definition for command line option to save figures from tests."""
20
- parser.addoption("--save", nargs="?", const="test_images", help="Save images rendered by plot")
21
-
22
-
23
- @pytest.fixture(scope="session")
24
- def save_figs(request):
25
- """Enable command line switch for saving generation figures upon testing."""
26
- fig_dir = request.config.getoption("--save")
27
-
28
- if fig_dir is not None:
29
- # Try creating directory if it doesn't exist
30
- _log.info("Saving generated images in %s", fig_dir)
31
-
32
- os.makedirs(fig_dir, exist_ok=True)
33
- _log.info("Directory %s created", fig_dir)
34
-
35
- # Clear all files from the directory
36
- # Does not alter or delete directories
37
- for file in os.listdir(fig_dir):
38
- full_path = os.path.join(fig_dir, file)
39
-
40
- try:
41
- os.remove(full_path)
42
-
43
- except OSError:
44
- _log.info("Failed to remove %s", full_path)
45
-
46
- return fig_dir
arviz/tests/external_tests/__init__.py DELETED
@@ -1 +0,0 @@
1
- """Backend test suite."""
arviz/tests/external_tests/test_data_beanmachine.py DELETED
@@ -1,78 +0,0 @@
1
- # pylint: disable=no-member, invalid-name, redefined-outer-name
2
- import numpy as np
3
- import pytest
4
-
5
- from ...data.io_beanmachine import from_beanmachine # pylint: disable=wrong-import-position
6
- from ..helpers import ( # pylint: disable=unused-import, wrong-import-position
7
- chains,
8
- draws,
9
- eight_schools_params,
10
- importorskip,
11
- load_cached_models,
12
- )
13
-
14
- pytest.skip("Ignore beanmachine tests until it supports pytorch 2", allow_module_level=True)
15
-
16
- # Skip all tests if beanmachine or pytorch not installed
17
- torch = importorskip("torch")
18
- bm = importorskip("beanmachine.ppl")
19
- dist = torch.distributions
20
-
21
-
22
- class TestDataBeanMachine:
23
- @pytest.fixture(scope="class")
24
- def data(self, eight_schools_params, draws, chains):
25
- class Data:
26
- model, prior, obj = load_cached_models(
27
- eight_schools_params,
28
- draws,
29
- chains,
30
- "beanmachine",
31
- )["beanmachine"]
32
-
33
- return Data
34
-
35
- @pytest.fixture(scope="class")
36
- def predictions_data(self, data):
37
- """Generate predictions for predictions_params"""
38
- posterior_samples = data.obj
39
- model = data.model
40
- predictions = bm.inference.predictive.simulate([model.obs()], posterior_samples)
41
- return predictions
42
-
43
- def get_inference_data(self, eight_schools_params, predictions_data):
44
- predictions = predictions_data
45
- return from_beanmachine(
46
- sampler=predictions,
47
- coords={
48
- "school": np.arange(eight_schools_params["J"]),
49
- "school_pred": np.arange(eight_schools_params["J"]),
50
- },
51
- )
52
-
53
- def test_inference_data(self, data, eight_schools_params, predictions_data):
54
- inference_data = self.get_inference_data(eight_schools_params, predictions_data)
55
- model = data.model
56
- mu = model.mu()
57
- tau = model.tau()
58
- eta = model.eta()
59
- obs = model.obs()
60
-
61
- assert mu in inference_data.posterior
62
- assert tau in inference_data.posterior
63
- assert eta in inference_data.posterior
64
- assert obs in inference_data.posterior_predictive
65
-
66
- def test_inference_data_has_log_likelihood_and_observed_data(self, data):
67
- idata = from_beanmachine(data.obj)
68
- obs = data.model.obs()
69
-
70
- assert obs in idata.log_likelihood
71
- assert obs in idata.observed_data
72
-
73
- def test_inference_data_no_posterior(self, data):
74
- model = data.model
75
- # only prior
76
- inference_data = from_beanmachine(data.prior)
77
- assert not model.obs() in inference_data.posterior
78
- assert "observed_data" not in inference_data