arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
@@ -1,1679 +0,0 @@
- # pylint: disable=no-member, invalid-name, redefined-outer-name
- # pylint: disable=too-many-lines
-
- import importlib
- import os
- import warnings
- from collections import namedtuple
- from copy import deepcopy
- from html import escape
- from typing import Dict
- from tempfile import TemporaryDirectory
- from urllib.parse import urlunsplit
-
- import numpy as np
- import pytest
- import xarray as xr
- from xarray.core.options import OPTIONS
- from xarray.testing import assert_identical
-
- from ... import (
-     InferenceData,
-     clear_data_home,
-     concat,
-     convert_to_dataset,
-     convert_to_inference_data,
-     from_datatree,
-     from_dict,
-     from_json,
-     from_netcdf,
-     list_datasets,
-     load_arviz_data,
-     to_netcdf,
-     extract,
- )
-
- from ...data.base import (
-     dict_to_dataset,
-     generate_dims_coords,
-     infer_stan_dtypes,
-     make_attrs,
-     numpy_to_data_array,
- )
- from ...data.datasets import LOCAL_DATASETS, REMOTE_DATASETS, RemoteFileMetadata
- from ..helpers import ( # pylint: disable=unused-import
-     chains,
-     check_multiple_attrs,
-     create_data_random,
-     data_random,
-     draws,
-     eight_schools_params,
-     models,
- )
-
- # Check if dm-tree is installed
- dm_tree_installed = importlib.util.find_spec("tree") is not None # pylint: disable=invalid-name
- skip_tests = (not dm_tree_installed) and ("ARVIZ_REQUIRE_ALL_DEPS" not in os.environ)
-
-
- @pytest.fixture(autouse=True)
- def no_remote_data(monkeypatch, tmpdir):
-     """Delete all remote data and replace it with a local dataset."""
-     keys = list(REMOTE_DATASETS)
-     for key in keys:
-         monkeypatch.delitem(REMOTE_DATASETS, key)
-
-     centered = LOCAL_DATASETS["centered_eight"]
-     filename = os.path.join(str(tmpdir), os.path.basename(centered.filename))
-
-     url = urlunsplit(("file", "", centered.filename, "", ""))
-
-     monkeypatch.setitem(
-         REMOTE_DATASETS,
-         "test_remote",
-         RemoteFileMetadata(
-             name="test_remote",
-             filename=filename,
-             url=url,
-             checksum="8efc3abafe0c796eb9aea7b69490d4e2400a33c57504ef4932e1c7105849176f",
-             description=centered.description,
-         ),
-     )
-     monkeypatch.setitem(
-         REMOTE_DATASETS,
-         "bad_checksum",
-         RemoteFileMetadata(
-             name="bad_checksum",
-             filename=filename,
-             url=url,
-             checksum="bad!",
-             description=centered.description,
-         ),
-     )
-     UnknownFileMetaData = namedtuple(
-         "UnknownFileMetaData", ["filename", "url", "checksum", "description"]
-     )
-     monkeypatch.setitem(
-         REMOTE_DATASETS,
-         "test_unknown",
-         UnknownFileMetaData(
-             filename=filename,
-             url=url,
-             checksum="9ae00c83654b3f061d32c882ec0a270d10838fa36515ecb162b89a290e014849",
-             description="Test bad REMOTE_DATASET",
-         ),
-     )
-
-
- def test_load_local_arviz_data():
-     inference_data = load_arviz_data("centered_eight")
-     assert isinstance(inference_data, InferenceData)
-     assert set(inference_data.observed_data.obs.coords["school"].values) == {
-         "Hotchkiss",
-         "Mt. Hermon",
-         "Choate",
-         "Deerfield",
-         "Phillips Andover",
-         "St. Paul's",
-         "Lawrenceville",
-         "Phillips Exeter",
-     }
-     assert inference_data.posterior["theta"].dims == ("chain", "draw", "school")
-
-
- @pytest.mark.parametrize("fill_attrs", [True, False])
- def test_local_save(fill_attrs):
-     inference_data = load_arviz_data("centered_eight")
-     assert isinstance(inference_data, InferenceData)
-
-     if fill_attrs:
-         inference_data.attrs["test"] = 1
-     with TemporaryDirectory(prefix="arviz_tests_") as tmp_dir:
-         path = os.path.join(tmp_dir, "test_file.nc")
-         inference_data.to_netcdf(path)
-
-         inference_data2 = from_netcdf(path)
-         if fill_attrs:
-             assert "test" in inference_data2.attrs
-             assert inference_data2.attrs["test"] == 1
-         # pylint: disable=protected-access
-         assert all(group in inference_data2 for group in inference_data._groups_all)
-         # pylint: enable=protected-access
-
-
- def test_clear_data_home():
-     resource = REMOTE_DATASETS["test_remote"]
-     assert not os.path.exists(resource.filename)
-     load_arviz_data("test_remote")
-     assert os.path.exists(resource.filename)
-     clear_data_home(data_home=os.path.dirname(resource.filename))
-     assert not os.path.exists(resource.filename)
-
-
- def test_load_remote_arviz_data():
-     assert load_arviz_data("test_remote")
-
-
- def test_bad_checksum():
-     with pytest.raises(IOError):
-         load_arviz_data("bad_checksum")
-
-
- def test_missing_dataset():
-     with pytest.raises(ValueError):
-         load_arviz_data("does not exist")
-
-
- def test_list_datasets():
-     dataset_string = list_datasets()
-     # make sure all the names of the data sets are in the dataset description
-     for key in (
-         "centered_eight",
-         "non_centered_eight",
-         "test_remote",
-         "bad_checksum",
-         "test_unknown",
-     ):
-         assert key in dataset_string
-
-
- def test_dims_coords():
-     shape = 4, 20, 5
-     var_name = "x"
-     dims, coords = generate_dims_coords(shape, var_name)
-     assert "x_dim_0" in dims
-     assert "x_dim_1" in dims
-     assert "x_dim_2" in dims
-     assert len(coords["x_dim_0"]) == 4
-     assert len(coords["x_dim_1"]) == 20
-     assert len(coords["x_dim_2"]) == 5
-
-
- @pytest.mark.parametrize(
-     "in_dims", (["dim1", "dim2"], ["draw", "dim1", "dim2"], ["chain", "draw", "dim1", "dim2"])
- )
- def test_dims_coords_default_dims(in_dims):
-     shape = 4, 7
-     var_name = "x"
-     dims, coords = generate_dims_coords(
-         shape,
-         var_name,
-         dims=in_dims,
-         coords={"chain": ["a", "b", "c"]},
-         default_dims=["chain", "draw"],
-     )
-     assert "dim1" in dims
-     assert "dim2" in dims
-     assert ("chain" in dims) == ("chain" in in_dims)
-     assert ("draw" in dims) == ("draw" in in_dims)
-     assert len(coords["dim1"]) == 4
-     assert len(coords["dim2"]) == 7
-     assert len(coords["chain"]) == 3
-     assert "draw" not in coords
-
-
- def test_dims_coords_extra_dims():
-     shape = 4, 20
-     var_name = "x"
-     with pytest.warns(UserWarning):
-         dims, coords = generate_dims_coords(shape, var_name, dims=["xx", "xy", "xz"])
-     assert "xx" in dims
-     assert "xy" in dims
-     assert "xz" in dims
-     assert len(coords["xx"]) == 4
-     assert len(coords["xy"]) == 20
-
-
- @pytest.mark.parametrize("shape", [(4, 20), (4, 20, 1)])
- def test_dims_coords_skip_event_dims(shape):
-     coords = {"x": np.arange(4), "y": np.arange(20), "z": np.arange(5)}
-     dims, coords = generate_dims_coords(
-         shape, "name", dims=["x", "y", "z"], coords=coords, skip_event_dims=True
-     )
-     assert "x" in dims
-     assert "y" in dims
-     assert "z" not in dims
-     assert len(coords["x"]) == 4
-     assert len(coords["y"]) == 20
-     assert "z" not in coords
-
-
- @pytest.mark.parametrize("dims", [None, ["chain", "draw"], ["chain", "draw", None]])
- def test_numpy_to_data_array_with_dims(dims):
-     da = numpy_to_data_array(
-         np.empty((4, 500, 7)),
-         var_name="a",
-         dims=dims,
-         default_dims=["chain", "draw"],
-     )
-     assert list(da.dims) == ["chain", "draw", "a_dim_0"]
-
-
- def test_make_attrs():
-     extra_attrs = {"key": "Value"}
-     attrs = make_attrs(attrs=extra_attrs)
-     assert "key" in attrs
-     assert attrs["key"] == "Value"
-
-
- @pytest.mark.parametrize("copy", [True, False])
- @pytest.mark.parametrize("inplace", [True, False])
- @pytest.mark.parametrize("sequence", [True, False])
- def test_concat_group(copy, inplace, sequence):
-     idata1 = from_dict(
-         posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)}
-     )
-     if copy and inplace:
-         original_idata1_posterior_id = id(idata1.posterior)
-     idata2 = from_dict(prior={"C": np.random.randn(2, 10, 2), "D": np.random.randn(2, 10, 5, 2)})
-     idata3 = from_dict(observed_data={"E": np.random.randn(100), "F": np.random.randn(2, 100)})
-     # basic case
-     assert concat(idata1, idata2, copy=True, inplace=False) is not None
-     if sequence:
-         new_idata = concat((idata1, idata2, idata3), copy=copy, inplace=inplace)
-     else:
-         new_idata = concat(idata1, idata2, idata3, copy=copy, inplace=inplace)
-     if inplace:
-         assert new_idata is None
-         new_idata = idata1
-     assert new_idata is not None
-     test_dict = {"posterior": ["A", "B"], "prior": ["C", "D"], "observed_data": ["E", "F"]}
-     fails = check_multiple_attrs(test_dict, new_idata)
-     assert not fails
-     if copy:
-         if inplace:
-             assert id(new_idata.posterior) == original_idata1_posterior_id
-         else:
-             assert id(new_idata.posterior) != id(idata1.posterior)
-         assert id(new_idata.prior) != id(idata2.prior)
-         assert id(new_idata.observed_data) != id(idata3.observed_data)
-     else:
-         assert id(new_idata.posterior) == id(idata1.posterior)
-         assert id(new_idata.prior) == id(idata2.prior)
-         assert id(new_idata.observed_data) == id(idata3.observed_data)
-
-
- @pytest.mark.parametrize("dim", ["chain", "draw"])
- @pytest.mark.parametrize("copy", [True, False])
- @pytest.mark.parametrize("inplace", [True, False])
- @pytest.mark.parametrize("sequence", [True, False])
- @pytest.mark.parametrize("reset_dim", [True, False])
- def test_concat_dim(dim, copy, inplace, sequence, reset_dim):
-     idata1 = from_dict(
-         posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)},
-         observed_data={"C": np.random.randn(100), "D": np.random.randn(2, 100)},
-     )
-     if inplace:
-         original_idata1_id = id(idata1)
-     idata2 = from_dict(
-         posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)},
-         observed_data={"C": np.random.randn(100), "D": np.random.randn(2, 100)},
-     )
-     idata3 = from_dict(
-         posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)},
-         observed_data={"C": np.random.randn(100), "D": np.random.randn(2, 100)},
-     )
-     # basic case
-     assert (
-         concat(idata1, idata2, dim=dim, copy=copy, inplace=False, reset_dim=reset_dim) is not None
-     )
-     if sequence:
-         new_idata = concat(
-             (idata1, idata2, idata3), copy=copy, dim=dim, inplace=inplace, reset_dim=reset_dim
-         )
-     else:
-         new_idata = concat(
-             idata1, idata2, idata3, dim=dim, copy=copy, inplace=inplace, reset_dim=reset_dim
-         )
-     if inplace:
-         assert new_idata is None
-         new_idata = idata1
-     assert new_idata is not None
-     test_dict = {"posterior": ["A", "B"], "observed_data": ["C", "D"]}
-     fails = check_multiple_attrs(test_dict, new_idata)
-     assert not fails
-     if inplace:
-         assert id(new_idata) == original_idata1_id
-     else:
-         assert id(new_idata) != id(idata1)
-     assert getattr(new_idata.posterior, dim).size == 6 if dim == "chain" else 30
-     if reset_dim:
-         assert np.all(
-             getattr(new_idata.posterior, dim).values
-             == (np.arange(6) if dim == "chain" else np.arange(30))
-         )
-
-
- @pytest.mark.parametrize("copy", [True, False])
- @pytest.mark.parametrize("inplace", [True, False])
- @pytest.mark.parametrize("sequence", [True, False])
- def test_concat_edgecases(copy, inplace, sequence):
-     idata = from_dict(posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)})
-     empty = concat()
-     assert empty is not None
-     if sequence:
-         new_idata = concat([idata], copy=copy, inplace=inplace)
-     else:
-         new_idata = concat(idata, copy=copy, inplace=inplace)
-     if inplace:
-         assert new_idata is None
-         new_idata = idata
-     else:
-         assert new_idata is not None
-     test_dict = {"posterior": ["A", "B"]}
-     fails = check_multiple_attrs(test_dict, new_idata)
-     assert not fails
-     if copy and not inplace:
-         assert id(new_idata.posterior) != id(idata.posterior)
-     else:
-         assert id(new_idata.posterior) == id(idata.posterior)
-
-
- def test_concat_bad():
-     with pytest.raises(TypeError):
-         concat("hello", "hello")
-     idata = from_dict(posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)})
-     idata2 = from_dict(posterior={"A": np.random.randn(2, 10, 2)})
-     idata3 = from_dict(prior={"A": np.random.randn(2, 10, 2)})
-     with pytest.raises(TypeError):
-         concat(idata, np.array([1, 2, 3, 4, 5]))
-     with pytest.raises(TypeError):
-         concat(idata, idata, dim=None)
-     with pytest.raises(TypeError):
-         concat(idata, idata2, dim="chain")
-     with pytest.raises(TypeError):
-         concat(idata2, idata, dim="chain")
-     with pytest.raises(TypeError):
-         concat(idata, idata3, dim="chain")
-     with pytest.raises(TypeError):
-         concat(idata3, idata, dim="chain")
-
-
- def test_inference_concat_keeps_all_fields():
-     """From failures observed in issue #907"""
-     idata1 = from_dict(posterior={"A": [1, 2, 3, 4]}, sample_stats={"B": [2, 3, 4, 5]})
-     idata2 = from_dict(prior={"C": [1, 2, 3, 4]}, observed_data={"D": [2, 3, 4, 5]})
-
-     idata_c1 = concat(idata1, idata2)
-     idata_c2 = concat(idata2, idata1)
-
-     test_dict = {"posterior": ["A"], "sample_stats": ["B"], "prior": ["C"], "observed_data": ["D"]}
-
-     fails_c1 = check_multiple_attrs(test_dict, idata_c1)
-     assert not fails_c1
-     fails_c2 = check_multiple_attrs(test_dict, idata_c2)
-     assert not fails_c2
-
-
- @pytest.mark.parametrize(
-     "model_code,expected",
-     [
-         ("data {int y;} models {y ~ poisson(3);} generated quantities {int X;}", {"X": "int"}),
-         (
-             "data {real y;} models {y ~ normal(0,1);} generated quantities {int Y; real G;}",
-             {"Y": "int"},
-         ),
-     ],
- )
- def test_infer_stan_dtypes(model_code, expected):
-     """Test different examples for dtypes in Stan models."""
-     res = infer_stan_dtypes(model_code)
-     assert res == expected
-
-
- class TestInferenceData: # pylint: disable=too-many-public-methods
-     def test_addition(self):
-         idata1 = from_dict(
-             posterior={"A": np.random.randn(2, 10, 2), "B": np.random.randn(2, 10, 5, 2)}
-         )
-         idata2 = from_dict(
-             prior={"C": np.random.randn(2, 10, 2), "D": np.random.randn(2, 10, 5, 2)}
-         )
-         new_idata = idata1 + idata2
-         assert new_idata is not None
-         test_dict = {"posterior": ["A", "B"], "prior": ["C", "D"]}
-         fails = check_multiple_attrs(test_dict, new_idata)
-         assert not fails
-
-     def test_iter(self, models):
-         idata = models.model_1
-         for group in idata:
-             assert group in idata._groups_all # pylint: disable=protected-access
-
-     def test_groups(self, models):
-         idata = models.model_1
-         for group in idata.groups():
-             assert group in idata._groups_all # pylint: disable=protected-access
-
-     def test_values(self, models):
-         idata = models.model_1
-         datasets = idata.values()
-         for group in idata.groups():
-             assert group in idata._groups_all # pylint: disable=protected-access
-             dataset = getattr(idata, group)
-             assert dataset in datasets
-
-     def test_items(self, models):
-         idata = models.model_1
-         for group, dataset in idata.items():
-             assert group in idata._groups_all # pylint: disable=protected-access
-             assert dataset.equals(getattr(idata, group))
-
-     @pytest.mark.parametrize("inplace", [True, False])
-     def test_extend_xr_method(self, data_random, inplace):
-         idata = data_random
-         idata_copy = deepcopy(idata)
-         kwargs = {"groups": "posterior_groups"}
-         if inplace:
-             idata_copy.sum(dim="draw", inplace=inplace, **kwargs)
-         else:
-             idata2 = idata_copy.sum(dim="draw", inplace=inplace, **kwargs)
-             assert idata2 is not idata_copy
-             idata_copy = idata2
-         assert_identical(idata_copy.posterior, idata.posterior.sum(dim="draw"))
-         assert_identical(
-             idata_copy.posterior_predictive, idata.posterior_predictive.sum(dim="draw")
-         )
-         assert_identical(idata_copy.observed_data, idata.observed_data)
-
-     @pytest.mark.parametrize("inplace", [False, True])
-     def test_sel(self, data_random, inplace):
-         idata = data_random
-         original_groups = getattr(idata, "_groups")
-         ndraws = idata.posterior.draw.values.size
-         kwargs = {"draw": slice(200, None), "chain": slice(None, None, 2), "b_dim_0": [1, 2, 7]}
-         if inplace:
-             idata.sel(inplace=inplace, **kwargs)
-         else:
-             idata2 = idata.sel(inplace=inplace, **kwargs)
-             assert idata2 is not idata
-             idata = idata2
-         groups = getattr(idata, "_groups")
-         assert np.all(np.isin(groups, original_groups))
-         for group in groups:
-             dataset = getattr(idata, group)
-             assert "b_dim_0" in dataset.dims
-             assert np.all(dataset.b_dim_0.values == np.array(kwargs["b_dim_0"]))
-             if group != "observed_data":
-                 assert np.all(np.isin(["chain", "draw"], dataset.dims))
-                 assert np.all(dataset.chain.values == np.arange(0, 4, 2))
-                 assert np.all(dataset.draw.values == np.arange(200, ndraws))
-
-     def test_sel_chain_prior(self):
-         idata = load_arviz_data("centered_eight")
-         original_groups = getattr(idata, "_groups")
-         idata_subset = idata.sel(inplace=False, chain_prior=False, chain=[0, 1, 3])
-         groups = getattr(idata_subset, "_groups")
-         assert np.all(np.isin(groups, original_groups))
-         for group in groups:
-             dataset_subset = getattr(idata_subset, group)
-             dataset = getattr(idata, group)
-             if "chain" in dataset.dims:
-                 assert "chain" in dataset_subset.dims
-                 if "prior" not in group:
-                     assert np.all(dataset_subset.chain.values == np.array([0, 1, 3]))
-             else:
-                 assert "chain" not in dataset_subset.dims
-         with pytest.raises(KeyError):
-             idata.sel(inplace=False, chain_prior=True, chain=[0, 1, 3])
-
-     @pytest.mark.parametrize("use", ("del", "delattr", "delitem"))
-     def test_del(self, use):
-         # create inference data object
-         data = np.random.normal(size=(4, 500, 8))
-         idata = from_dict(
-             posterior={"a": data[..., 0], "b": data},
-             sample_stats={"a": data[..., 0], "b": data},
-             observed_data={"b": data[0, 0, :]},
-             posterior_predictive={"a": data[..., 0], "b": data},
-         )
-
-         # assert inference data object has all attributes
-         test_dict = {
-             "posterior": ("a", "b"),
-             "sample_stats": ("a", "b"),
-             "observed_data": ["b"],
-             "posterior_predictive": ("a", "b"),
-         }
-         fails = check_multiple_attrs(test_dict, idata)
-         assert not fails
-         # assert _groups attribute contains all groups
-         groups = getattr(idata, "_groups")
-         assert all((group in groups for group in test_dict))
-
-         # Use del method
-         if use == "del":
-             del idata.sample_stats
-         elif use == "delitem":
-             del idata["sample_stats"]
-         else:
-             delattr(idata, "sample_stats")
-
-         # assert attribute has been removed
-         test_dict.pop("sample_stats")
-         fails = check_multiple_attrs(test_dict, idata)
-         assert not fails
-         assert not hasattr(idata, "sample_stats")
-         # assert _groups attribute has been updated
-         assert "sample_stats" not in getattr(idata, "_groups")
-
-     @pytest.mark.parametrize(
-         "args_res",
-         (
-             ([("posterior", "sample_stats")], ("posterior", "sample_stats")),
-             (["posterior", "like"], ("posterior", "warmup_posterior", "posterior_predictive")),
-             (["^posterior", "regex"], ("posterior", "posterior_predictive")),
-             (
-                 [("~^warmup", "~^obs"), "regex"],
-                 ("posterior", "sample_stats", "posterior_predictive"),
-             ),
-             (
-                 ["~observed_vars"],
-                 ("posterior", "sample_stats", "warmup_posterior", "warmup_sample_stats"),
-             ),
-         ),
-     )
-     def test_group_names(self, args_res):
-         args, result = args_res
-         ds = dict_to_dataset({"a": np.random.normal(size=(3, 10))})
-         idata = InferenceData(
-             posterior=(ds, ds),
-             sample_stats=(ds, ds),
-             observed_data=ds,
-             posterior_predictive=ds,
-         )
-         group_names = idata._group_names(*args) # pylint: disable=protected-access
-         assert np.all([name in result for name in group_names])
-
-     def test_group_names_invalid_args(self):
-         ds = dict_to_dataset({"a": np.random.normal(size=(3, 10))})
-         idata = InferenceData(posterior=(ds, ds))
-         msg = r"^\'filter_groups\' can only be None, \'like\', or \'regex\', got: 'foo'$"
-         with pytest.raises(ValueError, match=msg):
-             idata._group_names( # pylint: disable=protected-access
-                 ("posterior",), filter_groups="foo"
-             )
-
-     @pytest.mark.parametrize("inplace", [False, True])
-     def test_isel(self, data_random, inplace):
-         idata = data_random
-         original_groups = getattr(idata, "_groups")
-         ndraws = idata.posterior.draw.values.size
-         kwargs = {"draw": slice(200, None), "chain": slice(None, None, 2), "b_dim_0": [1, 2, 7]}
-         if inplace:
-             idata.isel(inplace=inplace, **kwargs)
-         else:
-             idata2 = idata.isel(inplace=inplace, **kwargs)
-             assert idata2 is not idata
-             idata = idata2
-         groups = getattr(idata, "_groups")
-         assert np.all(np.isin(groups, original_groups))
-         for group in groups:
-             dataset = getattr(idata, group)
-             assert "b_dim_0" in dataset.dims
-             assert np.all(dataset.b_dim_0.values == np.array(kwargs["b_dim_0"]))
-             if group != "observed_data":
-                 assert np.all(np.isin(["chain", "draw"], dataset.dims))
-                 assert np.all(dataset.chain.values == np.arange(0, 4, 2))
-                 assert np.all(dataset.draw.values == np.arange(200, ndraws))
-
-     def test_rename(self, data_random):
-         idata = data_random
-         original_groups = getattr(idata, "_groups")
-         renamed_idata = idata.rename({"b": "b_new"})
-         for group in original_groups:
-             xr_data = getattr(renamed_idata, group)
-             assert "b_new" in list(xr_data.data_vars)
-             assert "b" not in list(xr_data.data_vars)
-
-         renamed_idata = idata.rename({"b_dim_0": "b_new"})
-         for group in original_groups:
-             xr_data = getattr(renamed_idata, group)
-             assert "b_new" in list(xr_data.dims)
-             assert "b_dim_0" not in list(xr_data.dims)
-
-     def test_rename_vars(self, data_random):
-         idata = data_random
-         original_groups = getattr(idata, "_groups")
-         renamed_idata = idata.rename_vars({"b": "b_new"})
-         for group in original_groups:
-             xr_data = getattr(renamed_idata, group)
-             assert "b_new" in list(xr_data.data_vars)
-             assert "b" not in list(xr_data.data_vars)
-
-         renamed_idata = idata.rename_vars({"b_dim_0": "b_new"})
-         for group in original_groups:
-             xr_data = getattr(renamed_idata, group)
-             assert "b_new" not in list(xr_data.dims)
-             assert "b_dim_0" in list(xr_data.dims)
-
-     def test_rename_dims(self, data_random):
-         idata = data_random
-         original_groups = getattr(idata, "_groups")
-         renamed_idata = idata.rename_dims({"b_dim_0": "b_new"})
-         for group in original_groups:
-             xr_data = getattr(renamed_idata, group)
-             assert "b_new" in list(xr_data.dims)
-             assert "b_dim_0" not in list(xr_data.dims)
-
-         renamed_idata = idata.rename_dims({"b": "b_new"})
-         for group in original_groups:
-             xr_data = getattr(renamed_idata, group)
-             assert "b_new" not in list(xr_data.data_vars)
-             assert "b" in list(xr_data.data_vars)
-
-     def test_stack_unstack(self):
-         datadict = {
-             "a": np.random.randn(100),
-             "b": np.random.randn(1, 100, 10),
-             "c": np.random.randn(1, 100, 3, 4),
-         }
-         coords = {
-             "c1": np.arange(3),
-             "c99": np.arange(4),
-             "b1": np.arange(10),
-         }
-         dims = {"c": ["c1", "c99"], "b": ["b1"]}
-         dataset = from_dict(posterior=datadict, coords=coords, dims=dims)
-         assert_identical(
-             dataset.stack(z=["c1", "c99"]).posterior, dataset.posterior.stack(z=["c1", "c99"])
-         )
-         assert_identical(dataset.stack(z=["c1", "c99"]).unstack().posterior, dataset.posterior)
-         assert_identical(
-             dataset.stack(z=["c1", "c99"]).unstack(dim="z").posterior, dataset.posterior
-         )
-
-     def test_stack_bool(self):
-         datadict = {
-             "a": np.random.randn(100),
-             "b": np.random.randn(1, 100, 10),
-             "c": np.random.randn(1, 100, 3, 4),
-         }
-         coords = {
-             "c1": np.arange(3),
-             "c99": np.arange(4),
-             "b1": np.arange(10),
-         }
-         dims = {"c": ["c1", "c99"], "b": ["b1"]}
-         dataset = from_dict(posterior=datadict, coords=coords, dims=dims)
-         assert_identical(
-             dataset.stack(z=["c1", "c99"], create_index=False).posterior,
-             dataset.posterior.stack(z=["c1", "c99"], create_index=False),
-         )
-
-     def test_to_dict(self, models):
-         idata = models.model_1
-         test_data = from_dict(**idata.to_dict())
-         assert test_data
-         for group in idata._groups_all: # pylint: disable=protected-access
-             xr_data = getattr(idata, group)
-             test_xr_data = getattr(test_data, group)
-             assert xr_data.equals(test_xr_data)
-
-     def test_to_dict_warmup(self):
-         idata = create_data_random(
-             groups=[
-                 "posterior",
-                 "sample_stats",
-                 "observed_data",
-                 "warmup_posterior",
-                 "warmup_posterior_predictive",
-             ]
-         )
-         test_data = from_dict(**idata.to_dict(), save_warmup=True)
-         assert test_data
-         for group in idata._groups_all: # pylint: disable=protected-access
-             xr_data = getattr(idata, group)
-             test_xr_data = getattr(test_data, group)
-             assert xr_data.equals(test_xr_data)
-
-     @pytest.mark.parametrize(
-         "kwargs",
-         (
-             {
-                 "groups": "posterior",
-                 "include_coords": True,
-                 "include_index": True,
-                 "index_origin": 0,
-             },
-             {
-                 "groups": ["posterior", "sample_stats"],
-                 "include_coords": False,
-                 "include_index": True,
-                 "index_origin": 0,
-             },
-             {
-                 "groups": "posterior_groups",
-                 "include_coords": True,
-                 "include_index": False,
-                 "index_origin": 1,
-             },
-         ),
-     )
-     def test_to_dataframe(self, kwargs):
-         idata = from_dict(
-             posterior={"a": np.random.randn(4, 100, 3, 4, 5), "b": np.random.randn(4, 100)},
-             sample_stats={"a": np.random.randn(4, 100, 3, 4, 5), "b": np.random.randn(4, 100)},
-             observed_data={"a": np.random.randn(3, 4, 5), "b": np.random.randn(4)},
-         )
-         test_data = idata.to_dataframe(**kwargs)
-         assert not test_data.empty
-         groups = kwargs.get("groups", idata._groups_all) # pylint: disable=protected-access
-         for group in idata._groups_all: # pylint: disable=protected-access
-             if "data" in group:
-                 continue
-             assert test_data.shape == (
-                 (4 * 100, 3 * 4 * 5 + 1 + 2)
-                 if groups == "posterior"
-                 else (4 * 100, (3 * 4 * 5 + 1) * 2 + 2)
-             )
-         if groups == "posterior":
-             if kwargs.get("include_coords", True) and kwargs.get("include_index", True):
-                 assert any(
-                     f"[{kwargs.get('index_origin', 0)}," in item[0]
-                     for item in test_data.columns
-                     if isinstance(item, tuple)
-                 )
-             if kwargs.get("include_coords", True):
-                 assert any(isinstance(item, tuple) for item in test_data.columns)
-             else:
-                 assert not any(isinstance(item, tuple) for item in test_data.columns)
-         else:
-             if not kwargs.get("include_index", True):
-                 assert all(
-                     item in test_data.columns
-                     for item in (("posterior", "a", 1, 1, 1), ("posterior", "b"))
-                 )
-         assert all(item in test_data.columns for item in ("chain", "draw"))
-
-     @pytest.mark.parametrize(
-         "kwargs",
-         (
-             {
-                 "var_names": ["parameter_1", "parameter_2", "variable_1", "variable_2"],
-                 "filter_vars": None,
-                 "var_results": [
-                     ("posterior", "parameter_1"),
-                     ("posterior", "parameter_2"),
-                     ("prior", "parameter_1"),
-                     ("prior", "parameter_2"),
-                     ("posterior", "variable_1"),
-                     ("posterior", "variable_2"),
-                 ],
-             },
-             {
-                 "var_names": "parameter",
-                 "filter_vars": "like",
-                 "groups": "posterior",
-                 "var_results": ["parameter_1", "parameter_2"],
-             },
-             {
-                 "var_names": "~parameter",
-                 "filter_vars": "like",
-                 "groups": "posterior",
-                 "var_results": ["variable_1", "variable_2", "custom_name"],
-             },
-             {
-                 "var_names": [".+_2$", "custom_name"],
-                 "filter_vars": "regex",
-                 "groups": "posterior",
-                 "var_results": ["parameter_2", "variable_2", "custom_name"],
-             },
-             {
-                 "var_names": ["lp"],
-                 "filter_vars": "regex",
-                 "groups": "sample_stats",
-                 "var_results": ["lp"],
-             },
-         ),
-     )
-     def test_to_dataframe_selection(self, kwargs):
-         results = kwargs.pop("var_results")
-         idata = from_dict(
-             posterior={
-                 "parameter_1": np.random.randn(4, 100),
-                 "parameter_2": np.random.randn(4, 100),
-                 "variable_1": np.random.randn(4, 100),
-                 "variable_2": np.random.randn(4, 100),
-                 "custom_name": np.random.randn(4, 100),
-             },
-             prior={
-                 "parameter_1": np.random.randn(4, 100),
-                 "parameter_2": np.random.randn(4, 100),
-             },
-             sample_stats={
-                 "lp": np.random.randn(4, 100),
-             },
-         )
-         test_data = idata.to_dataframe(**kwargs)
-         assert not test_data.empty
-         assert set(test_data.columns).symmetric_difference(results) == set(["chain", "draw"])
-
-     def test_to_dataframe_bad(self):
-         idata = from_dict(
-             posterior={"a": np.random.randn(4, 100, 3, 4, 5), "b": np.random.randn(4, 100)},
-             sample_stats={"a": np.random.randn(4, 100, 3, 4, 5), "b": np.random.randn(4, 100)},
-             observed_data={"a": np.random.randn(3, 4, 5), "b": np.random.randn(4)},
-         )
-         with pytest.raises(TypeError):
-             idata.to_dataframe(index_origin=2)
-
-         with pytest.raises(TypeError):
-             idata.to_dataframe(include_coords=False, include_index=False)
-
-         with pytest.raises(TypeError):
-             idata.to_dataframe(groups=["observed_data"])
-
-         with pytest.raises(KeyError):
-             idata.to_dataframe(groups=["invalid_group"])
-
-         with pytest.raises(ValueError):
-             idata.to_dataframe(var_names=["c"])
-
-     @pytest.mark.parametrize("use", (None, "args", "kwargs"))
-     def test_map(self, use):
-         idata = load_arviz_data("centered_eight")
-         args = []
-         kwargs = {}
-         if use is None:
-             fun = lambda x: x + 3
-         elif use == "args":
-             fun = lambda x, a: x + a
-             args = [3]
-         else:
-             fun = lambda x, a: x + a
-             kwargs = {"a": 3}
-         groups = ("observed_data", "posterior_predictive")
-         idata_map = idata.map(fun, groups, args=args, **kwargs)
-         groups_map = idata_map._groups # pylint: disable=protected-access
-         assert groups_map == idata._groups # pylint: disable=protected-access
-         assert np.allclose(
-             idata_map.observed_data.obs, fun(idata.observed_data.obs, *args, **kwargs)
-         )
-         assert np.allclose(
-             idata_map.posterior_predictive.obs, fun(idata.posterior_predictive.obs, *args, **kwargs)
-         )
-         assert np.allclose(idata_map.posterior.mu, idata.posterior.mu)
-
-     def test_repr_html(self):
-         """Test if the function _repr_html is generating html."""
-         idata = load_arviz_data("centered_eight")
-         display_style = OPTIONS["display_style"]
-         xr.set_options(display_style="html")
-         html = idata._repr_html_() # pylint: disable=protected-access
-
-         assert html is not None
-         assert "<div" in html
-         for group in idata._groups: # pylint: disable=protected-access
-             assert group in html
-             xr_data = getattr(idata, group)
-             for item, _ in xr_data.items():
-                 assert item in html
-         specific_style = ".xr-wrap{width:700px!important;}"
-         assert specific_style in html
-
-         xr.set_options(display_style="text")
-         html = idata._repr_html_() # pylint: disable=protected-access
-         assert escape(repr(idata)) in html
-         xr.set_options(display_style=display_style)
-
-     def test_setitem(self, data_random):
-         data_random["new_group"] = data_random.posterior
-         assert "new_group" in data_random.groups()
-         assert hasattr(data_random, "new_group")
-
-     def test_add_groups(self, data_random):
-         data = np.random.normal(size=(4, 500, 8))
-         idata = data_random
-         idata.add_groups({"prior": {"a": data[..., 0], "b": data}})
-         assert "prior" in idata._groups # pylint: disable=protected-access
-         assert isinstance(idata.prior, xr.Dataset)
-         assert hasattr(idata, "prior")
-
-         idata.add_groups(warmup_posterior={"a": data[..., 0], "b": data})
-         assert "warmup_posterior" in idata._groups_all # pylint: disable=protected-access
-         assert isinstance(idata.warmup_posterior, xr.Dataset)
-         assert hasattr(idata, "warmup_posterior")
-
-     def test_add_groups_warning(self, data_random):
-         data = np.random.normal(size=(4, 500, 8))
-         idata = data_random
-         with pytest.warns(UserWarning, match="The group.+not defined in the InferenceData scheme"):
-             idata.add_groups({"new_group": idata.posterior}, warn_on_custom_groups=True)
-         with pytest.warns(UserWarning, match="the default dims.+will be added automatically"):
-             idata.add_groups(constant_data={"a": data[..., 0], "b": data})
-         assert idata.new_group.equals(idata.posterior)
-
-     def test_add_groups_error(self, data_random):
-         idata = data_random
-         with pytest.raises(ValueError, match="One of.+must be provided."):
-             idata.add_groups()
-         with pytest.raises(ValueError, match="Arguments.+xr.Dataset, xr.Dataarray or dicts"):
-             idata.add_groups({"new_group": "new_group"})
-         with pytest.raises(ValueError, match="group.+already exists"):
-             idata.add_groups({"posterior": idata.posterior})
-
-     def test_extend(self, data_random):
-         idata = data_random
-         idata2 = create_data_random(
-             groups=["prior", "prior_predictive", "observed_data", "warmup_posterior"], seed=7
-         )
-         idata.extend(idata2)
-         assert "prior" in idata._groups_all # pylint: disable=protected-access
-         assert "warmup_posterior" in idata._groups_all # pylint: disable=protected-access
-         assert hasattr(idata, "prior")
-         assert hasattr(idata, "prior_predictive")
-         assert idata.prior.equals(idata2.prior)
-         assert not idata.observed_data.equals(idata2.observed_data)
-         assert idata.prior_predictive.equals(idata2.prior_predictive)
-
-         idata.extend(idata2, join="right")
-         assert idata.prior.equals(idata2.prior)
-         assert idata.observed_data.equals(idata2.observed_data)
-         assert idata.prior_predictive.equals(idata2.prior_predictive)
-
-     def test_extend_errors_warnings(self, data_random):
-         idata = data_random
-         idata2 = create_data_random(groups=["prior", "prior_predictive", "observed_data"], seed=7)
-         with pytest.raises(ValueError, match="Extending.+InferenceData objects only."):
-             idata.extend("something")
-         with pytest.raises(ValueError, match="join must be either"):
-             idata.extend(idata2, join="outer")
-         idata2.add_groups(new_group=idata2.prior)
-         with pytest.warns(UserWarning, match="new_group"):
-             idata.extend(idata2, warn_on_custom_groups=True)
-
-
- class TestNumpyToDataArray:
-     def test_1d_dataset(self):
-         size = 100
-         dataset = convert_to_dataset(np.random.randn(size))
-         assert len(dataset.data_vars) == 1
-
-         assert set(dataset.coords) == {"chain", "draw"}
-         assert dataset.chain.shape == (1,)
-         assert dataset.draw.shape == (size,)
-
-     def test_warns_bad_shape(self):
-         # Shape should be (chain, draw, *shape)
-         with pytest.warns(UserWarning):
-             convert_to_dataset(np.random.randn(100, 4))
-
-     def test_nd_to_dataset(self):
-         shape = (1, 2, 3, 4, 5)
-         dataset = convert_to_dataset(np.random.randn(*shape))
-         assert len(dataset.data_vars) == 1
-         var_name = list(dataset.data_vars)[0]
-
-         assert len(dataset.coords) == len(shape)
-         assert dataset.chain.shape == shape[:1]
-         assert dataset.draw.shape == shape[1:2]
-         assert dataset[var_name].shape == shape
-
-     def test_nd_to_inference_data(self):
-         shape = (1, 2, 3, 4, 5)
-         inference_data = convert_to_inference_data(np.random.randn(*shape), group="prior")
-         assert hasattr(inference_data, "prior")
-         assert len(inference_data.prior.data_vars) == 1
-         var_name = list(inference_data.prior.data_vars)[0]
-
-         assert len(inference_data.prior.coords) == len(shape)
-         assert inference_data.prior.chain.shape == shape[:1]
-         assert inference_data.prior.draw.shape == shape[1:2]
-         assert inference_data.prior[var_name].shape == shape
-         assert repr(inference_data).startswith("Inference data with groups")
-
-     def test_more_chains_than_draws(self):
-         shape = (10, 4)
-         with pytest.warns(UserWarning):
-             inference_data = convert_to_inference_data(np.random.randn(*shape), group="prior")
-         assert hasattr(inference_data, "prior")
-         assert len(inference_data.prior.data_vars) == 1
-         var_name = list(inference_data.prior.data_vars)[0]
-
-         assert len(inference_data.prior.coords) == len(shape)
-         assert inference_data.prior.chain.shape == shape[:1]
-         assert inference_data.prior.draw.shape == shape[1:2]
-         assert inference_data.prior[var_name].shape == shape
-
-
- class TestConvertToDataset:
-     @pytest.fixture(scope="class")
-     def data(self):
-         # pylint: disable=attribute-defined-outside-init
-         class Data:
-             datadict = {
-                 "a": np.random.randn(100),
-                 "b": np.random.randn(1, 100, 10),
-                 "c": np.random.randn(1, 100, 3, 4),
-             }
-             coords = {"c1": np.arange(3), "c2": np.arange(4), "b1": np.arange(10)}
-             dims = {"b": ["b1"], "c": ["c1", "c2"]}
-
-         return Data
-
-     def test_use_all(self, data):
-         dataset = convert_to_dataset(data.datadict, coords=data.coords, dims=data.dims)
-         assert set(dataset.data_vars) == {"a", "b", "c"}
-         assert set(dataset.coords) == {"chain", "draw", "c1", "c2", "b1"}
-
-         assert set(dataset.a.coords) == {"chain", "draw"}
-         assert set(dataset.b.coords) == {"chain", "draw", "b1"}
-         assert set(dataset.c.coords) == {"chain", "draw", "c1", "c2"}
-
-     def test_missing_coords(self, data):
-         dataset = convert_to_dataset(data.datadict, coords=None, dims=data.dims)
-         assert set(dataset.data_vars) == {"a", "b", "c"}
-         assert set(dataset.coords) == {"chain", "draw", "c1", "c2", "b1"}
-
-         assert set(dataset.a.coords) == {"chain", "draw"}
-         assert set(dataset.b.coords) == {"chain", "draw", "b1"}
-         assert set(dataset.c.coords) == {"chain", "draw", "c1", "c2"}
-
-     def test_missing_dims(self, data):
-         # missing dims
-         coords = {"c_dim_0": np.arange(3), "c_dim_1": np.arange(4), "b_dim_0": np.arange(10)}
-         dataset = convert_to_dataset(data.datadict, coords=coords, dims=None)
-         assert set(dataset.data_vars) == {"a", "b", "c"}
-         assert set(dataset.coords) == {"chain", "draw", "c_dim_0", "c_dim_1", "b_dim_0"}
-
-         assert set(dataset.a.coords) == {"chain", "draw"}
-         assert set(dataset.b.coords) == {"chain", "draw", "b_dim_0"}
-         assert set(dataset.c.coords) == {"chain", "draw", "c_dim_0", "c_dim_1"}
-
-     def test_skip_dim_0(self, data):
-         dims = {"c": [None, "c2"]}
-         coords = {"c_dim_0": np.arange(3), "c2": np.arange(4), "b_dim_0": np.arange(10)}
-         dataset = convert_to_dataset(data.datadict, coords=coords, dims=dims)
-         assert set(dataset.data_vars) == {"a", "b", "c"}
-         assert set(dataset.coords) == {"chain", "draw", "c_dim_0", "c2", "b_dim_0"}
-
-         assert set(dataset.a.coords) == {"chain", "draw"}
-         assert set(dataset.b.coords) == {"chain", "draw", "b_dim_0"}
-         assert set(dataset.c.coords) == {"chain", "draw", "c_dim_0", "c2"}
-
-
- def test_dict_to_dataset():
-     datadict = {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)}
-     dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={"b": ["c"]})
-     assert set(dataset.data_vars) == {"a", "b"}
-     assert set(dataset.coords) == {"chain", "draw", "c"}
-
-     assert set(dataset.a.coords) == {"chain", "draw"}
-     assert set(dataset.b.coords) == {"chain", "draw", "c"}
-
-
- @pytest.mark.skipif(skip_tests, reason="test requires dm-tree which is not installed")
- def test_nested_dict_to_dataset():
-     datadict = {
-         "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
-         "d": np.random.randn(100),
-     }
-     dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={("top", "b"): ["c"]})
-     assert set(dataset.data_vars) == {("top", "a"), ("top", "b"), "d"}
-     assert set(dataset.coords) == {"chain", "draw", "c"}
-
-     assert set(dataset[("top", "a")].coords) == {"chain", "draw"}
-     assert set(dataset[("top", "b")].coords) == {"chain", "draw", "c"}
-     assert set(dataset.d.coords) == {"chain", "draw"}
-
-
- def test_dict_to_dataset_event_dims_error():
-     datadict = {"a": np.random.randn(1, 100, 10)}
-     coords = {"b": np.arange(10), "c": ["x", "y", "z"]}
-     msg = "different number of dimensions on data and dims"
-     with pytest.raises(ValueError, match=msg):
-         convert_to_dataset(datadict, coords=coords, dims={"a": ["b", "c"]})
-
-
- def test_dict_to_dataset_with_tuple_coord():
-     datadict = {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)}
-     dataset = convert_to_dataset(datadict, coords={"c": tuple(range(10))}, dims={"b": ["c"]})
-     assert set(dataset.data_vars) == {"a", "b"}
-     assert set(dataset.coords) == {"chain", "draw", "c"}
-
-     assert set(dataset.a.coords) == {"chain", "draw"}
-     assert set(dataset.b.coords) == {"chain", "draw", "c"}
-
-
- def test_convert_to_dataset_idempotent():
-     first = convert_to_dataset(np.random.randn(100))
-     second = convert_to_dataset(first)
-     assert first.equals(second)
-
-
- def test_convert_to_inference_data_idempotent():
-     first = convert_to_inference_data(np.random.randn(100), group="prior")
-     second = convert_to_inference_data(first)
-     assert first.prior is second.prior
-
-
- def test_convert_to_inference_data_from_file(tmpdir):
-     first = convert_to_inference_data(np.random.randn(100), group="prior")
-     filename = str(tmpdir.join("test_file.nc"))
-     first.to_netcdf(filename)
-     second = convert_to_inference_data(filename)
-     assert first.prior.equals(second.prior)
-
-
- def test_convert_to_inference_data_bad():
-     with pytest.raises(ValueError):
-         convert_to_inference_data(1)
-
-
- def test_convert_to_dataset_bad(tmpdir):
-     first = convert_to_inference_data(np.random.randn(100), group="prior")
-     filename = str(tmpdir.join("test_file.nc"))
-     first.to_netcdf(filename)
-     with pytest.raises(ValueError):
-         convert_to_dataset(filename, group="bar")
-
-
- def test_bad_inference_data():
-     with pytest.raises(ValueError):
-         InferenceData(posterior=[1, 2, 3])
-
-
- @pytest.mark.parametrize("warn", [True, False])
- def test_inference_data_other_groups(warn):
-     datadict = {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)}
-     dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={"b": ["c"]})
-     if warn:
-         with pytest.warns(UserWarning, match="not.+in.+InferenceData scheme"):
-             idata = InferenceData(other_group=dataset, warn_on_custom_groups=True)
-     else:
-         with warnings.catch_warnings():
-             warnings.simplefilter("error")
-             idata = InferenceData(other_group=dataset, warn_on_custom_groups=False)
-     fails = check_multiple_attrs({"other_group": ["a", "b"]}, idata)
-     assert not fails
-
-
- class TestDataConvert:
-     @pytest.fixture(scope="class")
-     def data(self, draws, chains):
-         class Data:
-             # fake 8-school output
-             obj = {}
-             for key, shape in {"mu": [], "tau": [], "eta": [8], "theta": [8]}.items():
-                 obj[key] = np.random.randn(chains, draws, *shape)
-
-         return Data
-
-     def get_inference_data(self, data):
-         return convert_to_inference_data(
-             data.obj,
-             group="posterior",
-             coords={"school": np.arange(8)},
-             dims={"theta": ["school"], "eta": ["school"]},
-         )
-
-     def check_var_names_coords_dims(self, dataset):
-         assert set(dataset.data_vars) == {"mu", "tau", "eta", "theta"}
-         assert set(dataset.coords) == {"chain", "draw", "school"}
-
-     def test_convert_to_inference_data(self, data):
-         inference_data = self.get_inference_data(data)
-         assert hasattr(inference_data, "posterior")
-         self.check_var_names_coords_dims(inference_data.posterior)
-
-     def test_convert_to_dataset(self, draws, chains, data):
-         dataset = convert_to_dataset(
-             data.obj,
-             group="posterior",
-             coords={"school": np.arange(8)},
-             dims={"theta": ["school"], "eta": ["school"]},
-         )
-         assert dataset.draw.shape == (draws,)
-         assert dataset.chain.shape == (chains,)
-         assert dataset.school.shape == (8,)
-         assert dataset.theta.shape == (chains, draws, 8)
-
-
- class TestDataDict:
-     @pytest.fixture(scope="class")
-     def data(self, draws, chains):
-         class Data:
-             # fake 8-school output
-             obj = {}
-             for key, shape in {"mu": [], "tau": [], "eta": [8], "theta": [8]}.items():
-                 obj[key] = np.random.randn(chains, draws, *shape)
-
-         return Data
-
-     def check_var_names_coords_dims(self, dataset):
-         assert set(dataset.data_vars) == {"mu", "tau", "eta", "theta"}
-         assert set(dataset.coords) == {"chain", "draw", "school"}
-
-     def get_inference_data(self, data, eight_schools_params, save_warmup=False):
-         return from_dict(
-             posterior=data.obj,
-             posterior_predictive=data.obj,
-             sample_stats=data.obj,
-             prior=data.obj,
-             prior_predictive=data.obj,
-             sample_stats_prior=data.obj,
-             warmup_posterior=data.obj,
-             warmup_posterior_predictive=data.obj,
-             predictions=data.obj,
-             observed_data=eight_schools_params,
-             coords={
-                 "school": np.arange(8),
-             },
-             pred_coords={
-                 "school_pred": np.arange(8),
-             },
-             dims={"theta": ["school"], "eta": ["school"]},
-             pred_dims={"theta": ["school_pred"], "eta": ["school_pred"]},
-             save_warmup=save_warmup,
-         )
-
-     def test_inference_data(self, data, eight_schools_params):
-         inference_data = self.get_inference_data(data, eight_schools_params)
-         test_dict = {
-             "posterior": [],
-             "prior": [],
-             "sample_stats": [],
-             "posterior_predictive": [],
-             "prior_predictive": [],
-             "sample_stats_prior": [],
-             "observed_data": ["J", "y", "sigma"],
-         }
-         fails = check_multiple_attrs(test_dict, inference_data)
-         assert not fails
-         self.check_var_names_coords_dims(inference_data.posterior)
-         self.check_var_names_coords_dims(inference_data.posterior_predictive)
-         self.check_var_names_coords_dims(inference_data.sample_stats)
-         self.check_var_names_coords_dims(inference_data.prior)
-         self.check_var_names_coords_dims(inference_data.prior_predictive)
-         self.check_var_names_coords_dims(inference_data.sample_stats_prior)
-
-         pred_dims = inference_data.predictions.sizes["school_pred"]
-         assert pred_dims == 8
-
-     def test_inference_data_warmup(self, data, eight_schools_params):
-         inference_data = self.get_inference_data(data, eight_schools_params, save_warmup=True)
-         test_dict = {
-             "posterior": [],
-             "prior": [],
-             "sample_stats": [],
-             "posterior_predictive": [],
-             "prior_predictive": [],
-             "sample_stats_prior": [],
-             "observed_data": ["J", "y", "sigma"],
-             "warmup_posterior_predictive": [],
-             "warmup_posterior": [],
-         }
-         fails = check_multiple_attrs(test_dict, inference_data)
-         assert not fails
-         self.check_var_names_coords_dims(inference_data.posterior)
-         self.check_var_names_coords_dims(inference_data.posterior_predictive)
-         self.check_var_names_coords_dims(inference_data.sample_stats)
-         self.check_var_names_coords_dims(inference_data.prior)
-         self.check_var_names_coords_dims(inference_data.prior_predictive)
-         self.check_var_names_coords_dims(inference_data.sample_stats_prior)
-         self.check_var_names_coords_dims(inference_data.warmup_posterior)
-         self.check_var_names_coords_dims(inference_data.warmup_posterior_predictive)
-
-     def test_inference_data_edge_cases(self):
-         # create data
-         log_likelihood = {
-             "y": np.random.randn(4, 100),
-             "log_likelihood": np.random.randn(4, 100, 8),
-         }
-
-         # log_likelihood to posterior
-         with pytest.warns(UserWarning, match="log_likelihood.+in posterior"):
-             assert from_dict(posterior=log_likelihood) is not None
-
-         # dims == None
-         assert from_dict(observed_data=log_likelihood, dims=None) is not None
-
-     def test_inference_data_bad(self):
-         # create data
-         x = np.random.randn(4, 100)
-
-         # input ndarray
-         with pytest.raises(TypeError):
-             from_dict(posterior=x)
-         with pytest.raises(TypeError):
-             from_dict(posterior_predictive=x)
-         with pytest.raises(TypeError):
-             from_dict(sample_stats=x)
-         with pytest.raises(TypeError):
-             from_dict(prior=x)
-         with pytest.raises(TypeError):
-             from_dict(prior_predictive=x)
-         with pytest.raises(TypeError):
-             from_dict(sample_stats_prior=x)
-         with pytest.raises(TypeError):
-             from_dict(observed_data=x)
-
-     def test_from_dict_warning(self):
-         bad_posterior_dict = {"log_likelihood": np.ones((5, 1000, 2))}
-         with pytest.warns(UserWarning):
-             from_dict(posterior=bad_posterior_dict)
-
-
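Note: TestDataDict covers from_dict, including the predictions-specific pred_coords/pred_dims arguments and save_warmup. A reduced sketch of that signature against the removed 0.23.x API:

    import numpy as np
    import arviz as az

    draws_dict = {"theta": np.random.randn(4, 500, 8)}
    idata = az.from_dict(
        posterior=draws_dict,
        predictions=draws_dict,
        coords={"school": np.arange(8)},
        pred_coords={"school_pred": np.arange(8)},  # applies only to predictions groups
        dims={"theta": ["school"]},
        pred_dims={"theta": ["school_pred"]},
        save_warmup=False,
    )
    assert idata.predictions.sizes["school_pred"] == 8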
- class TestDataNetCDF:
-     @pytest.fixture(scope="class")
-     def data(self, draws, chains):
-         class Data:
-             # fake 8-school output
-             obj = {}
-             for key, shape in {"mu": [], "tau": [], "eta": [8], "theta": [8]}.items():
-                 obj[key] = np.random.randn(chains, draws, *shape)
-
-         return Data
-
-     def get_inference_data(self, data, eight_schools_params):
-         return from_dict(
-             posterior=data.obj,
-             posterior_predictive=data.obj,
-             sample_stats=data.obj,
-             prior=data.obj,
-             prior_predictive=data.obj,
-             sample_stats_prior=data.obj,
-             observed_data=eight_schools_params,
-             coords={"school": np.array(["a" * i for i in range(8)], dtype="U")},
-             dims={"theta": ["school"], "eta": ["school"]},
-         )
-
-     def test_io_function(self, data, eight_schools_params):
-         # create inference data and assert all attributes are present
-         inference_data = self.get_inference_data(  # pylint: disable=W0612
-             data, eight_schools_params
-         )
-         test_dict = {
-             "posterior": ["eta", "theta", "mu", "tau"],
-             "posterior_predictive": ["eta", "theta", "mu", "tau"],
-             "sample_stats": ["eta", "theta", "mu", "tau"],
-             "prior": ["eta", "theta", "mu", "tau"],
-             "prior_predictive": ["eta", "theta", "mu", "tau"],
-             "sample_stats_prior": ["eta", "theta", "mu", "tau"],
-             "observed_data": ["J", "y", "sigma"],
-         }
-         fails = check_multiple_attrs(test_dict, inference_data)
-         assert not fails
-
-         # check filename does not exist and save InferenceData
-         here = os.path.dirname(os.path.abspath(__file__))
-         data_directory = os.path.join(here, "..", "saved_models")
-         filepath = os.path.join(data_directory, "io_function_testfile.nc")
-         # az -function
-         to_netcdf(inference_data, filepath)
-
-         # Assert InferenceData has been saved correctly
-         assert os.path.exists(filepath)
-         assert os.path.getsize(filepath) > 0
-         inference_data2 = from_netcdf(filepath)
-         fails = check_multiple_attrs(test_dict, inference_data2)
-         assert not fails
-         os.remove(filepath)
-         assert not os.path.exists(filepath)
-
-     @pytest.mark.parametrize("base_group", ["/", "test_group", "group/subgroup"])
-     @pytest.mark.parametrize("groups_arg", [False, True])
-     @pytest.mark.parametrize("compress", [True, False])
-     @pytest.mark.parametrize("engine", ["h5netcdf", "netcdf4"])
-     def test_io_method(self, data, eight_schools_params, groups_arg, base_group, compress, engine):
-         # create InferenceData and check it has been properly created
-         inference_data = self.get_inference_data(  # pylint: disable=W0612
-             data, eight_schools_params
-         )
-         if engine == "h5netcdf":
-             try:
-                 import h5netcdf  # pylint: disable=unused-import
-             except ImportError:
-                 pytest.skip("h5netcdf not installed")
-         elif engine == "netcdf4":
-             try:
-                 import netCDF4  # pylint: disable=unused-import
-             except ImportError:
-                 pytest.skip("netcdf4 not installed")
-         test_dict = {
-             "posterior": ["eta", "theta", "mu", "tau"],
-             "posterior_predictive": ["eta", "theta", "mu", "tau"],
-             "sample_stats": ["eta", "theta", "mu", "tau"],
-             "prior": ["eta", "theta", "mu", "tau"],
-             "prior_predictive": ["eta", "theta", "mu", "tau"],
-             "sample_stats_prior": ["eta", "theta", "mu", "tau"],
-             "observed_data": ["J", "y", "sigma"],
-         }
-         fails = check_multiple_attrs(test_dict, inference_data)
-         assert not fails
-
-         # check filename does not exist and use to_netcdf method
-         here = os.path.dirname(os.path.abspath(__file__))
-         data_directory = os.path.join(here, "..", "saved_models")
-         filepath = os.path.join(data_directory, "io_method_testfile.nc")
-         assert not os.path.exists(filepath)
-         # InferenceData method
-         inference_data.to_netcdf(
-             filepath,
-             groups=("posterior", "observed_data") if groups_arg else None,
-             compress=compress,
-             base_group=base_group,
-         )
-
-         # assert file has been saved correctly
-         assert os.path.exists(filepath)
-         assert os.path.getsize(filepath) > 0
-         inference_data2 = InferenceData.from_netcdf(filepath, base_group=base_group)
-         if groups_arg:  # if groups arg, update test dict to contain only saved groups
-             test_dict = {
-                 "posterior": ["eta", "theta", "mu", "tau"],
-                 "observed_data": ["J", "y", "sigma"],
-             }
-             assert not hasattr(inference_data2, "sample_stats")
-         fails = check_multiple_attrs(test_dict, inference_data2)
-         assert not fails
-
-         os.remove(filepath)
-         assert not os.path.exists(filepath)
-
-     def test_empty_inference_data_object(self):
-         inference_data = InferenceData()
-         here = os.path.dirname(os.path.abspath(__file__))
-         data_directory = os.path.join(here, "..", "saved_models")
-         filepath = os.path.join(data_directory, "empty_test_file.nc")
-         assert not os.path.exists(filepath)
-         inference_data.to_netcdf(filepath)
-         assert os.path.exists(filepath)
-         os.remove(filepath)
-         assert not os.path.exists(filepath)
-
-
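Note: TestDataNetCDF round-trips through both the standalone to_netcdf/from_netcdf functions and the InferenceData methods, parametrized over groups, compress, base_group, and engine. A minimal round-trip sketch against the removed 0.23.x API (file name illustrative):

    import arviz as az

    idata = az.load_arviz_data("centered_eight")
    # The method form can write a subset of groups; from_netcdf restores
    # only what was saved.
    idata.to_netcdf("subset.nc", groups=["posterior", "observed_data"])
    idata2 = az.from_netcdf("subset.nc")
    assert hasattr(idata2, "posterior") and not hasattr(idata2, "sample_stats")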
- class TestJSON:
-     def test_json_converters(self, models):
-         idata = models.model_1
-
-         filepath = os.path.realpath("test.json")
-         idata.to_json(filepath)
-
-         idata_copy = from_json(filepath)
-         for group in idata._groups_all:  # pylint: disable=protected-access
-             xr_data = getattr(idata, group)
-             test_xr_data = getattr(idata_copy, group)
-             assert xr_data.equals(test_xr_data)
-
-         os.remove(filepath)
-         assert not os.path.exists(filepath)
-
-
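Note: TestJSON checks that to_json/from_json preserve every group. A minimal sketch against the removed 0.23.x API (file name illustrative):

    import arviz as az

    idata = az.load_arviz_data("centered_eight")
    idata.to_json("centered_eight.json")
    idata_copy = az.from_json("centered_eight.json")
    assert idata.posterior.equals(idata_copy.posterior)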
- class TestDataTree:
-     def test_datatree(self):
-         idata = load_arviz_data("centered_eight")
-         dt = idata.to_datatree()
-         idata_back = from_datatree(dt)
-         for group, ds in idata.items():
-             assert_identical(ds, idata_back[group])
-         assert all(group in dt.children for group in idata.groups())
-
-     def test_datatree_attrs(self):
-         idata = load_arviz_data("centered_eight")
-         idata.attrs = {"not": "empty"}
-         assert idata.attrs
-         dt = idata.to_datatree()
-         idata_back = from_datatree(dt)
-         assert dt.attrs == idata.attrs
-         assert idata_back.attrs == idata.attrs
-
-
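Note: TestDataTree checks the DataTree bridge: each InferenceData group becomes a child node, and attrs survive the round trip. A sketch against the removed 0.23.x API (assuming an xarray version with DataTree support):

    import arviz as az
    from xarray.testing import assert_identical

    idata = az.load_arviz_data("centered_eight")
    dt = idata.to_datatree()            # one child node per group
    idata_back = az.from_datatree(dt)
    assert_identical(idata.posterior, idata_back.posterior)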
- class TestConversions:
-     def test_id_conversion_idempotent(self):
-         stored = load_arviz_data("centered_eight")
-         inference_data = convert_to_inference_data(stored)
-         assert isinstance(inference_data, InferenceData)
-         assert set(inference_data.observed_data.obs.coords["school"].values) == {
-             "Hotchkiss",
-             "Mt. Hermon",
-             "Choate",
-             "Deerfield",
-             "Phillips Andover",
-             "St. Paul's",
-             "Lawrenceville",
-             "Phillips Exeter",
-         }
-         assert inference_data.posterior["theta"].dims == ("chain", "draw", "school")
-
-     def test_dataset_conversion_idempotent(self):
-         inference_data = load_arviz_data("centered_eight")
-         data_set = convert_to_dataset(inference_data.posterior)
-         assert isinstance(data_set, xr.Dataset)
-         assert set(data_set.coords["school"].values) == {
-             "Hotchkiss",
-             "Mt. Hermon",
-             "Choate",
-             "Deerfield",
-             "Phillips Andover",
-             "St. Paul's",
-             "Lawrenceville",
-             "Phillips Exeter",
-         }
-         assert data_set["theta"].dims == ("chain", "draw", "school")
-
-     def test_id_conversion_args(self):
-         stored = load_arviz_data("centered_eight")
-         IVIES = ["Yale", "Harvard", "MIT", "Princeton", "Cornell", "Dartmouth", "Columbia", "Brown"]
-         # test dictionary argument...
-         # I reverse engineered a dictionary out of the centered_eight
-         # data. That's what this block of code does.
-         d = stored.posterior.to_dict()
-         d = d["data_vars"]
-         test_dict = {}  # type: Dict[str, np.ndarray]
-         for var_name in d:
-             data = d[var_name]["data"]
-             # this is a list of chains that is a list of samples...
-             chain_arrs = []
-             for chain in data:  # list of samples
-                 chain_arrs.append(np.array(chain))
-             data_arr = np.stack(chain_arrs)
-             test_dict[var_name] = data_arr
-
-         inference_data = convert_to_inference_data(
-             test_dict, dims={"theta": ["Ivies"]}, coords={"Ivies": IVIES}
-         )
-
-         assert isinstance(inference_data, InferenceData)
-         assert set(inference_data.posterior.coords["Ivies"].values) == set(IVIES)
-         assert inference_data.posterior["theta"].dims == ("chain", "draw", "Ivies")
-
-
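Note: TestConversions asserts that conversion is idempotent: feeding an InferenceData (or one of its Datasets) back through the converters preserves variables, coords, and dims. A sketch against the removed 0.23.x API:

    import arviz as az

    idata = az.load_arviz_data("centered_eight")
    # Both calls leave already-converted inputs structurally unchanged.
    idata_again = az.convert_to_inference_data(idata)
    posterior_ds = az.convert_to_dataset(idata.posterior)
    assert idata_again.posterior["theta"].dims == ("chain", "draw", "school")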
- class TestDataArrayToDataset:
-     def test_1d_dataset(self):
-         size = 100
-         dataset = convert_to_dataset(
-             xr.DataArray(np.random.randn(1, size), name="plot", dims=("chain", "draw"))
-         )
-         assert len(dataset.data_vars) == 1
-         assert "plot" in dataset.data_vars
-         assert dataset.chain.shape == (1,)
-         assert dataset.draw.shape == (size,)
-
-     def test_nd_to_dataset(self):
-         shape = (1, 2, 3, 4, 5)
-         dataset = convert_to_dataset(
-             xr.DataArray(np.random.randn(*shape), dims=("chain", "draw", "dim_0", "dim_1", "dim_2"))
-         )
-         var_name = list(dataset.data_vars)[0]
-
-         assert len(dataset.data_vars) == 1
-         assert dataset.chain.shape == shape[:1]
-         assert dataset.draw.shape == shape[1:2]
-         assert dataset[var_name].shape == shape
-
-     def test_nd_to_inference_data(self):
-         shape = (1, 2, 3, 4, 5)
-         inference_data = convert_to_inference_data(
-             xr.DataArray(
-                 np.random.randn(*shape), dims=("chain", "draw", "dim_0", "dim_1", "dim_2")
-             ),
-             group="prior",
-         )
-         var_name = list(inference_data.prior.data_vars)[0]
-
-         assert hasattr(inference_data, "prior")
-         assert len(inference_data.prior.data_vars) == 1
-         assert inference_data.prior.chain.shape == shape[:1]
-         assert inference_data.prior.draw.shape == shape[1:2]
-         assert inference_data.prior[var_name].shape == shape
-
-
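Note: TestDataArrayToDataset covers DataArray inputs, where the first two dims are interpreted as chain and draw. A sketch against the removed 0.23.x API:

    import numpy as np
    import xarray as xr
    import arviz as az

    da = xr.DataArray(np.random.randn(1, 100), name="plot", dims=("chain", "draw"))
    ds = az.convert_to_dataset(da)  # one-variable Dataset named after the DataArray
    assert set(ds.data_vars) == {"plot"}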
- class TestExtractDataset:
-     def test_default(self):
-         idata = load_arviz_data("centered_eight")
-         post = extract(idata)
-         assert isinstance(post, xr.Dataset)
-         assert "sample" in post.dims
-         assert post.theta.size == (4 * 500 * 8)
-
-     def test_seed(self):
-         idata = load_arviz_data("centered_eight")
-         post = extract(idata, rng=7)
-         post_pred = extract(idata, group="posterior_predictive", rng=7)
-         assert all(post.sample == post_pred.sample)
-
-     def test_no_combine(self):
-         idata = load_arviz_data("centered_eight")
-         post = extract(idata, combined=False)
-         assert "sample" not in post.dims
-         assert post.sizes["chain"] == 4
-         assert post.sizes["draw"] == 500
-
-     def test_var_name_group(self):
-         idata = load_arviz_data("centered_eight")
-         prior = extract(idata, group="prior", var_names="the", filter_vars="like")
-         assert {} == prior.attrs
-         assert "theta" in prior.name
-
-     def test_keep_dataset(self):
-         idata = load_arviz_data("centered_eight")
-         prior = extract(
-             idata, group="prior", var_names="the", filter_vars="like", keep_dataset=True
-         )
-         assert prior.attrs == idata.prior.attrs
-         assert "theta" in prior.data_vars
-         assert "mu" not in prior.data_vars
-
-     def test_subset_samples(self):
-         idata = load_arviz_data("centered_eight")
-         post = extract(idata, num_samples=10)
-         assert post.sizes["sample"] == 10
-         assert post.attrs == idata.posterior.attrs
-
-
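Note: TestExtractDataset covers extract, which stacks chain and draw into a single "sample" dimension unless combined=False, and supports variable filtering, subsampling, and seeding. A sketch against the removed 0.23.x API:

    import arviz as az

    idata = az.load_arviz_data("centered_eight")
    post = az.extract(idata, num_samples=10, rng=7)  # reproducible subsample
    assert post.sizes["sample"] == 10
    theta = az.extract(idata, var_names="theta", combined=False)  # keep chain/draw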
- def test_convert_to_inference_data_with_array_like():
-     class ArrayLike:
-         def __init__(self, data):
-             self._data = np.asarray(data)
-
-         def __array__(self):
-             return self._data
-
-     array_like = ArrayLike(np.random.randn(4, 100))
-     idata = convert_to_inference_data(array_like, group="posterior")
-
-     assert hasattr(idata, "posterior")
-     assert "x" in idata.posterior.data_vars
-     assert idata.posterior["x"].shape == (4, 100)
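Note: the final test shows that any object implementing the __array__ protocol is accepted and, lacking a name, stored under the default variable name "x". A standalone sketch against the removed 0.23.x API:

    import numpy as np
    import arviz as az

    class Wrapper:
        """Minimal array-like: numpy coerces it through __array__."""

        def __init__(self, values):
            self._values = np.asarray(values)

        def __array__(self):
            return self._values

    idata = az.convert_to_inference_data(Wrapper(np.random.randn(4, 100)))
    assert idata.posterior["x"].shape == (4, 100)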